use serde::{Deserialize, Serialize}; use std::time::Duration; use crate::core::error::{LoreError, Result}; use crate::http::Client; pub struct OllamaConfig { pub base_url: String, pub model: String, pub timeout_secs: u64, } impl Default for OllamaConfig { fn default() -> Self { Self { base_url: "http://localhost:11434".to_string(), model: "nomic-embed-text".to_string(), timeout_secs: 60, } } } pub struct OllamaClient { client: Client, config: OllamaConfig, } #[derive(Serialize)] struct EmbedRequest<'a> { model: &'a str, input: Vec<&'a str>, } #[derive(Deserialize)] struct EmbedResponse { #[allow(dead_code)] model: String, embeddings: Vec>, } #[derive(Deserialize)] struct TagsResponse { models: Vec, } #[derive(Deserialize)] struct ModelInfo { name: String, } impl OllamaClient { pub fn new(config: OllamaConfig) -> Self { let client = Client::with_timeout(Duration::from_secs(config.timeout_secs)); Self { client, config } } pub async fn health_check(&self) -> Result<()> { let url = format!("{}/api/tags", self.config.base_url); let response = self.client .get(&url, &[]) .await .map_err(|e| LoreError::OllamaUnavailable { base_url: self.config.base_url.clone(), detail: Some(format!("{e:?}")), })?; let tags: TagsResponse = response.json().map_err(|e| LoreError::OllamaUnavailable { base_url: self.config.base_url.clone(), detail: Some(format!("{e:?}")), })?; let model_found = tags.models.iter().any(|m| { m.name == self.config.model || m.name.starts_with(&format!("{}:", self.config.model)) }); if !model_found { return Err(LoreError::OllamaModelNotFound { model: self.config.model.clone(), }); } Ok(()) } pub async fn embed_batch(&self, texts: &[&str]) -> Result>> { let url = format!("{}/api/embed", self.config.base_url); let request = EmbedRequest { model: &self.config.model, input: texts.to_vec(), }; let response = self .client .post_json(&url, &[], &request) .await .map_err(|e| LoreError::OllamaUnavailable { base_url: self.config.base_url.clone(), detail: Some(format!("{e:?}")), })?; if !response.is_success() { let status = response.status; let body = response.text().unwrap_or_default(); return Err(LoreError::EmbeddingFailed { document_id: 0, reason: format!("HTTP {status}: {body}"), }); } let embed_response: EmbedResponse = response.json().map_err(|e| LoreError::EmbeddingFailed { document_id: 0, reason: format!("Failed to parse embed response: {e}"), })?; Ok(embed_response.embeddings) } } pub async fn check_ollama_health(base_url: &str) -> bool { let client = Client::with_timeout(Duration::from_secs(5)); let url = format!("{base_url}/api/tags"); client.get(&url, &[]).await.is_ok_and(|r| r.is_success()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_config_defaults() { let config = OllamaConfig::default(); assert_eq!(config.base_url, "http://localhost:11434"); assert_eq!(config.model, "nomic-embed-text"); assert_eq!(config.timeout_secs, 60); } #[test] fn test_health_check_model_matching() { let model = "nomic-embed-text"; let tag_name = "nomic-embed-text:latest"; assert!( tag_name == model || tag_name.starts_with(&format!("{model}:")), "should match model with tag" ); let exact_name = "nomic-embed-text"; assert!( exact_name == model || exact_name.starts_with(&format!("{model}:")), "should match exact model name" ); let wrong_model = "llama2:latest"; assert!( !(wrong_model == model || wrong_model.starts_with(&format!("{model}:"))), "should not match wrong model" ); let similar_model = "nomic-embed-text-v2:latest"; assert!( !(similar_model == model || similar_model.starts_with(&format!("{model}:"))), "should not false-positive on model name prefix" ); } #[test] fn test_embed_request_serialization() { let request = EmbedRequest { model: "nomic-embed-text", input: vec!["hello", "world"], }; let json = serde_json::to_string(&request).unwrap(); assert!(json.contains("\"model\":\"nomic-embed-text\"")); assert!(json.contains("\"input\":[\"hello\",\"world\"]")); } #[test] fn test_embed_response_deserialization() { let json = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3],[0.4,0.5,0.6]]}"#; let response: EmbedResponse = serde_json::from_str(json).unwrap(); assert_eq!(response.embeddings.len(), 2); assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]); assert_eq!(response.embeddings[1], vec![0.4, 0.5, 0.6]); } }