use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; use crate::core::error::{LoreError, Result}; pub struct OllamaConfig { pub base_url: String, pub model: String, pub timeout_secs: u64, } impl Default for OllamaConfig { fn default() -> Self { Self { base_url: "http://localhost:11434".to_string(), model: "nomic-embed-text".to_string(), timeout_secs: 60, } } } pub struct OllamaClient { client: Client, config: OllamaConfig, } #[derive(Serialize)] struct EmbedRequest { model: String, input: Vec, } #[derive(Deserialize)] struct EmbedResponse { #[allow(dead_code)] model: String, embeddings: Vec>, } #[derive(Deserialize)] struct TagsResponse { models: Vec, } #[derive(Deserialize)] struct ModelInfo { name: String, } impl OllamaClient { pub fn new(config: OllamaConfig) -> Self { let client = Client::builder() .timeout(Duration::from_secs(config.timeout_secs)) .build() .expect("Failed to create HTTP client"); Self { client, config } } pub async fn health_check(&self) -> Result<()> { let url = format!("{}/api/tags", self.config.base_url); let response = self.client .get(&url) .send() .await .map_err(|e| LoreError::OllamaUnavailable { base_url: self.config.base_url.clone(), source: Some(e), })?; let tags: TagsResponse = response .json() .await .map_err(|e| LoreError::OllamaUnavailable { base_url: self.config.base_url.clone(), source: Some(e), })?; let model_found = tags .models .iter() .any(|m| m.name.starts_with(&self.config.model)); if !model_found { return Err(LoreError::OllamaModelNotFound { model: self.config.model.clone(), }); } Ok(()) } pub async fn embed_batch(&self, texts: Vec) -> Result>> { let url = format!("{}/api/embed", self.config.base_url); let request = EmbedRequest { model: self.config.model.clone(), input: texts, }; let response = self .client .post(&url) .json(&request) .send() .await .map_err(|e| LoreError::OllamaUnavailable { base_url: self.config.base_url.clone(), source: Some(e), })?; let status = response.status(); if !status.is_success() { let body = response.text().await.unwrap_or_default(); return Err(LoreError::EmbeddingFailed { document_id: 0, reason: format!("HTTP {}: {}", status, body), }); } let embed_response: EmbedResponse = response .json() .await .map_err(|e| LoreError::EmbeddingFailed { document_id: 0, reason: format!("Failed to parse embed response: {}", e), })?; Ok(embed_response.embeddings) } } pub async fn check_ollama_health(base_url: &str) -> bool { let client = Client::builder() .timeout(Duration::from_secs(5)) .build() .ok(); let Some(client) = client else { return false; }; let url = format!("{base_url}/api/tags"); client.get(&url).send().await.is_ok() } #[cfg(test)] mod tests { use super::*; #[test] fn test_config_defaults() { let config = OllamaConfig::default(); assert_eq!(config.base_url, "http://localhost:11434"); assert_eq!(config.model, "nomic-embed-text"); assert_eq!(config.timeout_secs, 60); } #[test] fn test_health_check_model_starts_with() { let model = "nomic-embed-text"; let tag_name = "nomic-embed-text:latest"; assert!(tag_name.starts_with(model)); let wrong_model = "llama2"; assert!(!tag_name.starts_with(wrong_model)); } #[test] fn test_embed_request_serialization() { let request = EmbedRequest { model: "nomic-embed-text".to_string(), input: vec!["hello".to_string(), "world".to_string()], }; let json = serde_json::to_string(&request).unwrap(); assert!(json.contains("\"model\":\"nomic-embed-text\"")); assert!(json.contains("\"input\":[\"hello\",\"world\"]")); } #[test] fn test_embed_response_deserialization() { let json = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3],[0.4,0.5,0.6]]}"#; let response: EmbedResponse = serde_json::from_str(json).unwrap(); assert_eq!(response.embeddings.len(), 2); assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]); assert_eq!(response.embeddings[1], vec![0.4, 0.5, 0.6]); } }