Files
gitlore/src/embedding/ollama.rs
Taylor Eernisse 45126f04a6 fix: document upsert project_id, truncation budget, and Ollama model matching
- regenerator: Include project_id in the ON CONFLICT UPDATE clause for
  document upserts. Previously, if a document moved between projects
  (e.g., during re-ingestion), the project_id would remain stale.

- truncation: Compute the omission marker ("N notes omitted") before
  checking whether first+last notes fit in the budget. The old order
  computed the marker after the budget check, meaning the marker's byte
  cost was unaccounted for and could cause over-budget output.

- ollama: Tighten model name matching to require either an exact match
  or a colon-delimited tag prefix (model == name or name starts with
  "model:"). The prior starts_with check would false-positive on
  "nomic-embed-text-v2" when looking for "nomic-embed-text". Tests
  updated to cover exact match, tagged, wrong model, and prefix
  false-positive cases.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 10:16:14 -05:00

219 lines
5.9 KiB
Rust

use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::warn;
use crate::core::error::{LoreError, Result};
/// Connection settings for an Ollama embedding server.
///
/// `Debug` and `Clone` are derived so the configuration can be logged and
/// handed to more than one consumer without rebuilding it field by field.
#[derive(Debug, Clone)]
pub struct OllamaConfig {
    /// Root URL of the Ollama HTTP API; endpoint paths such as `/api/tags`
    /// are appended to it.
    pub base_url: String,
    /// Name of the embedding model to request.
    pub model: String,
    /// Request timeout, in seconds, applied to the HTTP client.
    pub timeout_secs: u64,
}
impl Default for OllamaConfig {
    /// Returns the stock configuration: an Ollama server on
    /// `http://localhost:11434`, the `nomic-embed-text` model, and a
    /// 60-second request timeout.
    fn default() -> Self {
        const DEFAULT_BASE_URL: &str = "http://localhost:11434";
        const DEFAULT_MODEL: &str = "nomic-embed-text";
        const DEFAULT_TIMEOUT_SECS: u64 = 60;

        Self {
            base_url: String::from(DEFAULT_BASE_URL),
            model: String::from(DEFAULT_MODEL),
            timeout_secs: DEFAULT_TIMEOUT_SECS,
        }
    }
}
/// HTTP client for an Ollama server.
///
/// Pairs a `reqwest` client with the `OllamaConfig` it was created from.
pub struct OllamaClient {
/// Underlying HTTP client used for every request.
client: Client,
/// Base URL, model name, and timeout this client was created with.
config: OllamaConfig,
}
/// JSON body posted to the `/api/embed` endpoint.
#[derive(Serialize)]
struct EmbedRequest<'a> {
/// Name of the model to embed with.
model: &'a str,
/// Texts to embed, sent together as one batch.
input: Vec<&'a str>,
}
/// JSON body decoded from a successful `/api/embed` response.
#[derive(Deserialize)]
struct EmbedResponse {
// The response's `model` field is deserialized but not read by any code in
// this file, hence the lint allowance.
#[allow(dead_code)]
model: String,
/// Embedding vectors handed back to the caller of `embed_batch`.
embeddings: Vec<Vec<f32>>,
}
/// JSON body decoded from the `/api/tags` response during `health_check`.
#[derive(Deserialize)]
struct TagsResponse {
/// Models listed in the response.
models: Vec<ModelInfo>,
}
/// One entry of the `models` list in a `TagsResponse`; only `name` is read.
#[derive(Deserialize)]
struct ModelInfo {
/// Model name (for example `nomic-embed-text:latest`), compared against the
/// configured model during the health check.
name: String,
}
impl OllamaClient {
    /// Creates a client that talks to the Ollama server described by `config`.
    ///
    /// The HTTP client is built with a request timeout of
    /// `config.timeout_secs` seconds. If that client cannot be built, a
    /// warning is logged and a default `reqwest` client is used instead.
    pub fn new(config: OllamaConfig) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(config.timeout_secs))
            .build()
            .unwrap_or_else(|e| {
                warn!(
                    error = %e,
                    "Failed to build configured Ollama HTTP client; falling back to default client"
                );
                Client::new()
            });
        Self { client, config }
    }

    /// Joins `path` (which starts with `/`) onto the configured base URL.
    ///
    /// Trailing slashes on `base_url` are dropped so that a value such as
    /// `http://localhost:11434/` does not yield `//api/...`.
    fn endpoint(&self, path: &str) -> String {
        format!("{}{}", self.config.base_url.trim_end_matches('/'), path)
    }

    /// Reports whether a model name listed by the server satisfies the
    /// configured name.
    ///
    /// `available` matches when it equals `wanted` exactly or is `wanted`
    /// followed by a colon-delimited tag (`wanted:latest`). A bare prefix such
    /// as `wanted-v2` does not match.
    fn model_matches(available: &str, wanted: &str) -> bool {
        matches!(
            available.strip_prefix(wanted),
            Some(rest) if rest.is_empty() || rest.starts_with(':')
        )
    }

    /// Verifies that the server is reachable and lists the configured model.
    ///
    /// # Errors
    ///
    /// * [`LoreError::OllamaUnavailable`] if the request to `/api/tags` fails,
    ///   the server answers with a non-success status, or the body cannot be
    ///   decoded.
    /// * [`LoreError::OllamaModelNotFound`] if no listed model matches the
    ///   configured model name.
    pub async fn health_check(&self) -> Result<()> {
        let url = self.endpoint("/api/tags");
        let response = self
            .client
            .get(&url)
            .send()
            .await
            // Surface an HTTP error status as the failure cause instead of the
            // JSON decode error that would otherwise follow from its body.
            .and_then(|response| response.error_for_status())
            .map_err(|e| LoreError::OllamaUnavailable {
                base_url: self.config.base_url.clone(),
                source: Some(e),
            })?;
        let tags: TagsResponse =
            response
                .json()
                .await
                .map_err(|e| LoreError::OllamaUnavailable {
                    base_url: self.config.base_url.clone(),
                    source: Some(e),
                })?;

        let wanted = self.config.model.as_str();
        let model_found = tags
            .models
            .iter()
            .any(|m| Self::model_matches(&m.name, wanted));
        if !model_found {
            return Err(LoreError::OllamaModelNotFound {
                model: self.config.model.clone(),
            });
        }
        Ok(())
    }

    /// Embeds `texts` in a single request and returns one vector per input,
    /// in input order.
    ///
    /// An empty `texts` slice returns an empty vector without contacting the
    /// server.
    ///
    /// # Errors
    ///
    /// * [`LoreError::OllamaUnavailable`] if the request cannot be sent.
    /// * [`LoreError::EmbeddingFailed`] if the server answers with a
    ///   non-success status, the body cannot be decoded, or the number of
    ///   returned embeddings differs from the number of inputs.
    pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        // Nothing to embed, so skip the round trip.
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let url = self.endpoint("/api/embed");
        let request = EmbedRequest {
            model: &self.config.model,
            input: texts.to_vec(),
        };
        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LoreError::OllamaUnavailable {
                base_url: self.config.base_url.clone(),
                source: Some(e),
            })?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body still reports the status.
            let body = response.text().await.unwrap_or_default();
            return Err(LoreError::EmbeddingFailed {
                document_id: 0,
                reason: format!("HTTP {}: {}", status, body),
            });
        }

        let embed_response: EmbedResponse =
            response
                .json()
                .await
                .map_err(|e| LoreError::EmbeddingFailed {
                    document_id: 0,
                    reason: format!("Failed to parse embed response: {}", e),
                })?;

        // Callers pair embeddings with inputs by position, so a count mismatch
        // must be an error rather than a silently misaligned result.
        let embeddings = embed_response.embeddings;
        if embeddings.len() != texts.len() {
            return Err(LoreError::EmbeddingFailed {
                document_id: 0,
                reason: format!(
                    "Expected {} embeddings, got {}",
                    texts.len(),
                    embeddings.len()
                ),
            });
        }
        Ok(embeddings)
    }
}
/// Quick probe for an Ollama server at `base_url`.
///
/// Sends `GET {base_url}/api/tags` with a 5-second timeout and returns `true`
/// only when the server answers with a success status. Returns `false` when
/// the HTTP client cannot be built, the request fails, or the server answers
/// with a non-success status (for example 404 or 5xx).
///
/// Trailing slashes on `base_url` are ignored.
pub async fn check_ollama_health(base_url: &str) -> bool {
    let Ok(client) = Client::builder().timeout(Duration::from_secs(5)).build() else {
        return false;
    };
    let url = format!("{}/api/tags", base_url.trim_end_matches('/'));
    match client.get(&url).send().await {
        // Any response proves something is listening, but only a success
        // status shows the tags endpoint is actually being served.
        Ok(response) => response.status().is_success(),
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the expected URL, model, and timeout.
    #[test]
    fn test_config_defaults() {
        let OllamaConfig {
            base_url,
            model,
            timeout_secs,
        } = OllamaConfig::default();
        assert_eq!(base_url, "http://localhost:11434");
        assert_eq!(model, "nomic-embed-text");
        assert_eq!(timeout_secs, 60);
    }

    /// Checks the model-matching rule (exact name, or name followed by a
    /// `:tag`) against representative names. The rule is restated locally
    /// rather than invoked through `health_check`.
    #[test]
    fn test_health_check_model_matching() {
        let model = "nomic-embed-text";
        let tagged_prefix = format!("{model}:");
        let is_match =
            |candidate: &str| candidate == model || candidate.starts_with(tagged_prefix.as_str());

        let cases = [
            ("nomic-embed-text:latest", true, "should match model with tag"),
            ("nomic-embed-text", true, "should match exact model name"),
            ("llama2:latest", false, "should not match wrong model"),
            (
                "nomic-embed-text-v2:latest",
                false,
                "should not false-positive on model name prefix",
            ),
        ];
        for (candidate, expected, message) in cases {
            assert!(is_match(candidate) == expected, "{}", message);
        }
    }

    /// The request body serializes to the `model` / `input` JSON fields.
    #[test]
    fn test_embed_request_serialization() {
        let payload = serde_json::to_string(&EmbedRequest {
            model: "nomic-embed-text",
            input: vec!["hello", "world"],
        })
        .unwrap();
        for fragment in [
            "\"model\":\"nomic-embed-text\"",
            "\"input\":[\"hello\",\"world\"]",
        ] {
            assert!(payload.contains(fragment));
        }
    }

    /// A response body with two embeddings decodes into two vectors, in order.
    #[test]
    fn test_embed_response_deserialization() {
        let body = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3],[0.4,0.5,0.6]]}"#;
        let parsed: EmbedResponse = serde_json::from_str(body).unwrap();
        let expected: [Vec<f32>; 2] = [vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        assert_eq!(parsed.embeddings.len(), expected.len());
        assert_eq!(parsed.embeddings[0], expected[0]);
        assert_eq!(parsed.embeddings[1], expected[1]);
    }
}