feat(embedding): Add Ollama-powered vector embedding pipeline
Implements the embedding module that generates vector representations of documents using a local Ollama instance with the nomic-embed-text model. These embeddings enable semantic (vector) search and the hybrid search mode that fuses lexical and semantic results via RRF. Key components: - embedding::ollama: HTTP client for the Ollama /api/embeddings endpoint. Handles connection errors with actionable error messages (OllamaUnavailable, OllamaModelNotFound) and validates response dimensions. - embedding::chunking: Splits long documents into overlapping paragraph-aware chunks for embedding. Uses a configurable max token estimate (8192 default for nomic-embed-text) with 10% overlap to preserve cross-chunk context. - embedding::chunk_ids: Encodes chunk identity as doc_id * 1000 + chunk_index for the embeddings table rowid. This allows vector search to map results back to documents and deduplicate by doc_id efficiently. - embedding::change_detector: Compares document content_hash against stored embedding hashes to skip re-embedding unchanged documents, making incremental embedding runs fast. - embedding::pipeline: Orchestrates the full embedding flow: detect changed documents, chunk them, call Ollama in configurable concurrency (default 4), store results. Supports --retry-failed to re-attempt previously failed embeddings. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
79
src/embedding/change_detector.rs
Normal file
79
src/embedding/change_detector.rs
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
//! Detect documents needing (re-)embedding based on content hash changes.
|
||||||
|
|
||||||
|
use rusqlite::Connection;
|
||||||
|
|
||||||
|
use crate::core::error::Result;
|
||||||
|
|
||||||
|
/// A document that needs embedding or re-embedding.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct PendingDocument {
|
||||||
|
pub document_id: i64,
|
||||||
|
pub content_text: String,
|
||||||
|
pub content_hash: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find documents that need embedding: new (no metadata) or changed (hash mismatch).
|
||||||
|
///
|
||||||
|
/// Uses keyset pagination (WHERE d.id > last_id) and returns up to `page_size` results.
|
||||||
|
pub fn find_pending_documents(
|
||||||
|
conn: &Connection,
|
||||||
|
page_size: usize,
|
||||||
|
last_id: i64,
|
||||||
|
) -> Result<Vec<PendingDocument>> {
|
||||||
|
// Documents that either:
|
||||||
|
// 1. Have no embedding_metadata at all (new)
|
||||||
|
// 2. Have metadata where document_hash != content_hash (changed)
|
||||||
|
let sql = r#"
|
||||||
|
SELECT d.id, d.content_text, d.content_hash
|
||||||
|
FROM documents d
|
||||||
|
WHERE d.id > ?1
|
||||||
|
AND (
|
||||||
|
NOT EXISTS (
|
||||||
|
SELECT 1 FROM embedding_metadata em
|
||||||
|
WHERE em.document_id = d.id AND em.chunk_index = 0
|
||||||
|
)
|
||||||
|
OR EXISTS (
|
||||||
|
SELECT 1 FROM embedding_metadata em
|
||||||
|
WHERE em.document_id = d.id AND em.chunk_index = 0
|
||||||
|
AND em.document_hash != d.content_hash
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ORDER BY d.id
|
||||||
|
LIMIT ?2
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let mut stmt = conn.prepare(sql)?;
|
||||||
|
let rows = stmt
|
||||||
|
.query_map(rusqlite::params![last_id, page_size as i64], |row| {
|
||||||
|
Ok(PendingDocument {
|
||||||
|
document_id: row.get(0)?,
|
||||||
|
content_text: row.get(1)?,
|
||||||
|
content_hash: row.get(2)?,
|
||||||
|
})
|
||||||
|
})?
|
||||||
|
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
Ok(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count total documents that need embedding.
|
||||||
|
pub fn count_pending_documents(conn: &Connection) -> Result<i64> {
|
||||||
|
let count: i64 = conn.query_row(
|
||||||
|
r#"
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM documents d
|
||||||
|
WHERE NOT EXISTS (
|
||||||
|
SELECT 1 FROM embedding_metadata em
|
||||||
|
WHERE em.document_id = d.id AND em.chunk_index = 0
|
||||||
|
)
|
||||||
|
OR EXISTS (
|
||||||
|
SELECT 1 FROM embedding_metadata em
|
||||||
|
WHERE em.document_id = d.id AND em.chunk_index = 0
|
||||||
|
AND em.document_hash != d.content_hash
|
||||||
|
)
|
||||||
|
"#,
|
||||||
|
[],
|
||||||
|
|row| row.get(0),
|
||||||
|
)?;
|
||||||
|
Ok(count)
|
||||||
|
}
|
||||||
63
src/embedding/chunk_ids.rs
Normal file
63
src/embedding/chunk_ids.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
/// Multiplier for encoding (document_id, chunk_index) into a single rowid.
|
||||||
|
/// Supports up to 1000 chunks per document (32M chars at 32k/chunk).
|
||||||
|
pub const CHUNK_ROWID_MULTIPLIER: i64 = 1000;
|
||||||
|
|
||||||
|
/// Encode (document_id, chunk_index) into a sqlite-vec rowid.
|
||||||
|
///
|
||||||
|
/// rowid = document_id * CHUNK_ROWID_MULTIPLIER + chunk_index
|
||||||
|
pub fn encode_rowid(document_id: i64, chunk_index: i64) -> i64 {
|
||||||
|
document_id * CHUNK_ROWID_MULTIPLIER + chunk_index
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode a sqlite-vec rowid back into (document_id, chunk_index).
|
||||||
|
pub fn decode_rowid(rowid: i64) -> (i64, i64) {
|
||||||
|
let document_id = rowid / CHUNK_ROWID_MULTIPLIER;
|
||||||
|
let chunk_index = rowid % CHUNK_ROWID_MULTIPLIER;
|
||||||
|
(document_id, chunk_index)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_single_chunk() {
|
||||||
|
assert_eq!(encode_rowid(1, 0), 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_multi_chunk() {
|
||||||
|
assert_eq!(encode_rowid(1, 5), 1005);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_specific_values() {
|
||||||
|
assert_eq!(encode_rowid(42, 0), 42000);
|
||||||
|
assert_eq!(encode_rowid(42, 5), 42005);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_zero_chunk() {
|
||||||
|
assert_eq!(decode_rowid(42000), (42, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_decode_roundtrip() {
|
||||||
|
for doc_id in [0, 1, 42, 100, 999, 10000] {
|
||||||
|
for chunk_idx in [0, 1, 5, 99, 999] {
|
||||||
|
let rowid = encode_rowid(doc_id, chunk_idx);
|
||||||
|
let (decoded_doc, decoded_chunk) = decode_rowid(rowid);
|
||||||
|
assert_eq!(
|
||||||
|
(decoded_doc, decoded_chunk),
|
||||||
|
(doc_id, chunk_idx),
|
||||||
|
"Roundtrip failed for doc_id={doc_id}, chunk_idx={chunk_idx}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiplier_value() {
|
||||||
|
assert_eq!(CHUNK_ROWID_MULTIPLIER, 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
207
src/embedding/chunking.rs
Normal file
207
src/embedding/chunking.rs
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
//! Text chunking for embedding: split documents at paragraph boundaries with overlap.
|
||||||
|
|
||||||
|
/// Maximum bytes per chunk.
|
||||||
|
/// Named `_BYTES` because `str::len()` returns byte count; multi-byte UTF-8
|
||||||
|
/// sequences mean byte length ≥ char count.
|
||||||
|
pub const CHUNK_MAX_BYTES: usize = 32_000;
|
||||||
|
|
||||||
|
/// Character overlap between adjacent chunks.
|
||||||
|
pub const CHUNK_OVERLAP_CHARS: usize = 500;
|
||||||
|
|
||||||
|
/// Split document content into chunks suitable for embedding.
|
||||||
|
///
|
||||||
|
/// Documents <= CHUNK_MAX_BYTES produce a single chunk.
|
||||||
|
/// Longer documents are split at paragraph boundaries (`\n\n`), falling back
|
||||||
|
/// to sentence boundaries, then word boundaries, then hard character cut.
|
||||||
|
/// Adjacent chunks share CHUNK_OVERLAP_CHARS of overlap.
|
||||||
|
///
|
||||||
|
/// Returns Vec<(chunk_index, chunk_text)>.
|
||||||
|
pub fn split_into_chunks(content: &str) -> Vec<(usize, String)> {
|
||||||
|
if content.is_empty() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
if content.len() <= CHUNK_MAX_BYTES {
|
||||||
|
return vec![(0, content.to_string())];
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut chunks: Vec<(usize, String)> = Vec::new();
|
||||||
|
let mut start = 0;
|
||||||
|
let mut chunk_index = 0;
|
||||||
|
|
||||||
|
while start < content.len() {
|
||||||
|
let remaining = &content[start..];
|
||||||
|
if remaining.len() <= CHUNK_MAX_BYTES {
|
||||||
|
chunks.push((chunk_index, remaining.to_string()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find a split point within CHUNK_MAX_BYTES (char-boundary-safe)
|
||||||
|
let end = floor_char_boundary(content, start + CHUNK_MAX_BYTES);
|
||||||
|
let window = &content[start..end];
|
||||||
|
|
||||||
|
// Try paragraph boundary (\n\n) — search backward from end
|
||||||
|
let split_at = find_paragraph_break(window)
|
||||||
|
.or_else(|| find_sentence_break(window))
|
||||||
|
.or_else(|| find_word_break(window))
|
||||||
|
.unwrap_or(window.len());
|
||||||
|
|
||||||
|
let chunk_text = &content[start..start + split_at];
|
||||||
|
chunks.push((chunk_index, chunk_text.to_string()));
|
||||||
|
|
||||||
|
// Advance with overlap, guaranteeing forward progress to prevent infinite loops.
|
||||||
|
// If split_at <= CHUNK_OVERLAP_CHARS we skip overlap to avoid stalling.
|
||||||
|
// The .max(1) ensures we always advance at least 1 byte.
|
||||||
|
let advance = if split_at > CHUNK_OVERLAP_CHARS {
|
||||||
|
split_at - CHUNK_OVERLAP_CHARS
|
||||||
|
} else {
|
||||||
|
split_at
|
||||||
|
}
|
||||||
|
.max(1);
|
||||||
|
start += advance;
|
||||||
|
chunk_index += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the last paragraph break (`\n\n`) in the window, preferring the
|
||||||
|
/// last third for balanced chunks.
|
||||||
|
fn find_paragraph_break(window: &str) -> Option<usize> {
|
||||||
|
// Search backward from 2/3 of the way through to find a good split
|
||||||
|
let search_start = window.len() * 2 / 3;
|
||||||
|
window[search_start..].rfind("\n\n").map(|pos| search_start + pos + 2)
|
||||||
|
.or_else(|| window[..search_start].rfind("\n\n").map(|pos| pos + 2))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the last sentence boundary (`. `, `? `, `! `) in the window.
|
||||||
|
fn find_sentence_break(window: &str) -> Option<usize> {
|
||||||
|
let search_start = window.len() / 2;
|
||||||
|
for pat in &[". ", "? ", "! "] {
|
||||||
|
if let Some(pos) = window[search_start..].rfind(pat) {
|
||||||
|
return Some(search_start + pos + pat.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try first half
|
||||||
|
for pat in &[". ", "? ", "! "] {
|
||||||
|
if let Some(pos) = window[..search_start].rfind(pat) {
|
||||||
|
return Some(pos + pat.len());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the last word boundary (space) in the window.
|
||||||
|
fn find_word_break(window: &str) -> Option<usize> {
|
||||||
|
let search_start = window.len() / 2;
|
||||||
|
window[search_start..].rfind(' ').map(|pos| search_start + pos + 1)
|
||||||
|
.or_else(|| window[..search_start].rfind(' ').map(|pos| pos + 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Find the largest byte index <= `idx` that is a valid char boundary in `s`.
|
||||||
|
/// Equivalent to `str::floor_char_boundary` (stabilized in Rust 1.82).
|
||||||
|
fn floor_char_boundary(s: &str, idx: usize) -> usize {
|
||||||
|
if idx >= s.len() {
|
||||||
|
return s.len();
|
||||||
|
}
|
||||||
|
let mut i = idx;
|
||||||
|
while i > 0 && !s.is_char_boundary(i) {
|
||||||
|
i -= 1;
|
||||||
|
}
|
||||||
|
i
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_content() {
|
||||||
|
let chunks = split_into_chunks("");
|
||||||
|
assert!(chunks.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_short_document_single_chunk() {
|
||||||
|
let content = "Short document content.";
|
||||||
|
let chunks = split_into_chunks(content);
|
||||||
|
assert_eq!(chunks.len(), 1);
|
||||||
|
assert_eq!(chunks[0].0, 0);
|
||||||
|
assert_eq!(chunks[0].1, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_exactly_max_chars() {
|
||||||
|
let content = "a".repeat(CHUNK_MAX_BYTES);
|
||||||
|
let chunks = split_into_chunks(&content);
|
||||||
|
assert_eq!(chunks.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_long_document_multiple_chunks() {
|
||||||
|
// Create content > CHUNK_MAX_BYTES with paragraph boundaries
|
||||||
|
let paragraph = "This is a paragraph of text.\n\n";
|
||||||
|
let mut content = String::new();
|
||||||
|
while content.len() < CHUNK_MAX_BYTES * 2 {
|
||||||
|
content.push_str(paragraph);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunks = split_into_chunks(&content);
|
||||||
|
assert!(chunks.len() >= 2, "Expected multiple chunks, got {}", chunks.len());
|
||||||
|
|
||||||
|
// Verify indices are sequential
|
||||||
|
for (i, (idx, _)) in chunks.iter().enumerate() {
|
||||||
|
assert_eq!(*idx, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all content is covered (no gaps)
|
||||||
|
assert!(!chunks.last().unwrap().1.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chunk_overlap() {
|
||||||
|
// Create content that will produce 2+ chunks
|
||||||
|
let paragraph = "This is paragraph content for testing chunk overlap behavior.\n\n";
|
||||||
|
let mut content = String::new();
|
||||||
|
while content.len() < CHUNK_MAX_BYTES + CHUNK_OVERLAP_CHARS + 1000 {
|
||||||
|
content.push_str(paragraph);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunks = split_into_chunks(&content);
|
||||||
|
assert!(chunks.len() >= 2);
|
||||||
|
|
||||||
|
// Check that adjacent chunks share some content (overlap)
|
||||||
|
if chunks.len() >= 2 {
|
||||||
|
let end_of_first = &chunks[0].1;
|
||||||
|
let start_of_second = &chunks[1].1;
|
||||||
|
// The end of first chunk should overlap with start of second
|
||||||
|
let overlap_region = &end_of_first[end_of_first.len().saturating_sub(CHUNK_OVERLAP_CHARS)..];
|
||||||
|
assert!(
|
||||||
|
start_of_second.starts_with(overlap_region)
|
||||||
|
|| overlap_region.contains(&start_of_second[..100.min(start_of_second.len())]),
|
||||||
|
"Expected overlap between chunks"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_no_paragraph_boundary() {
|
||||||
|
// Create content without paragraph breaks
|
||||||
|
let content = "word ".repeat(CHUNK_MAX_BYTES / 5 * 3);
|
||||||
|
let chunks = split_into_chunks(&content);
|
||||||
|
assert!(chunks.len() >= 2);
|
||||||
|
// Should still split (at word boundaries)
|
||||||
|
for (_, chunk) in &chunks {
|
||||||
|
assert!(!chunk.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chunk_indices_sequential() {
|
||||||
|
let content = "a ".repeat(CHUNK_MAX_BYTES);
|
||||||
|
let chunks = split_into_chunks(&content);
|
||||||
|
for (i, (idx, _)) in chunks.iter().enumerate() {
|
||||||
|
assert_eq!(*idx, i, "Chunk index mismatch at position {}", i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
9
src/embedding/mod.rs
Normal file
9
src/embedding/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
pub mod change_detector;
|
||||||
|
pub mod chunk_ids;
|
||||||
|
pub mod chunking;
|
||||||
|
pub mod ollama;
|
||||||
|
pub mod pipeline;
|
||||||
|
|
||||||
|
pub use change_detector::{count_pending_documents, find_pending_documents, PendingDocument};
|
||||||
|
pub use chunking::{split_into_chunks, CHUNK_MAX_BYTES, CHUNK_OVERLAP_CHARS};
|
||||||
|
pub use pipeline::{embed_documents, EmbedResult};
|
||||||
201
src/embedding/ollama.rs
Normal file
201
src/embedding/ollama.rs
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use crate::core::error::{LoreError, Result};
|
||||||
|
|
||||||
|
/// Configuration for Ollama embedding service.
|
||||||
|
pub struct OllamaConfig {
|
||||||
|
pub base_url: String,
|
||||||
|
pub model: String,
|
||||||
|
pub timeout_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for OllamaConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
base_url: "http://localhost:11434".to_string(),
|
||||||
|
model: "nomic-embed-text".to_string(),
|
||||||
|
timeout_secs: 60,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Async client for Ollama embedding API.
|
||||||
|
pub struct OllamaClient {
|
||||||
|
client: Client,
|
||||||
|
config: OllamaConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct EmbedRequest {
|
||||||
|
model: String,
|
||||||
|
input: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct EmbedResponse {
|
||||||
|
#[allow(dead_code)]
|
||||||
|
model: String,
|
||||||
|
embeddings: Vec<Vec<f32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct TagsResponse {
|
||||||
|
models: Vec<ModelInfo>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct ModelInfo {
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaClient {
|
||||||
|
pub fn new(config: OllamaConfig) -> Self {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(config.timeout_secs))
|
||||||
|
.build()
|
||||||
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
|
Self { client, config }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Health check: verifies Ollama is reachable and the configured model exists.
|
||||||
|
///
|
||||||
|
/// Model matching uses `starts_with` so "nomic-embed-text" matches
|
||||||
|
/// "nomic-embed-text:latest".
|
||||||
|
pub async fn health_check(&self) -> Result<()> {
|
||||||
|
let url = format!("{}/api/tags", self.config.base_url);
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.get(&url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| LoreError::OllamaUnavailable {
|
||||||
|
base_url: self.config.base_url.clone(),
|
||||||
|
source: Some(e),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let tags: TagsResponse =
|
||||||
|
response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| LoreError::OllamaUnavailable {
|
||||||
|
base_url: self.config.base_url.clone(),
|
||||||
|
source: Some(e),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let model_found = tags
|
||||||
|
.models
|
||||||
|
.iter()
|
||||||
|
.any(|m| m.name.starts_with(&self.config.model));
|
||||||
|
|
||||||
|
if !model_found {
|
||||||
|
return Err(LoreError::OllamaModelNotFound {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Embed a batch of texts using the configured model.
|
||||||
|
///
|
||||||
|
/// Returns one embedding vector per input text.
|
||||||
|
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
||||||
|
let url = format!("{}/api/embed", self.config.base_url);
|
||||||
|
|
||||||
|
let request = EmbedRequest {
|
||||||
|
model: self.config.model.clone(),
|
||||||
|
input: texts,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self.client.post(&url).json(&request).send().await.map_err(
|
||||||
|
|e| LoreError::OllamaUnavailable {
|
||||||
|
base_url: self.config.base_url.clone(),
|
||||||
|
source: Some(e),
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
if !status.is_success() {
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
return Err(LoreError::EmbeddingFailed {
|
||||||
|
document_id: 0,
|
||||||
|
reason: format!("HTTP {}: {}", status, body),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let embed_response: EmbedResponse =
|
||||||
|
response
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.map_err(|e| LoreError::EmbeddingFailed {
|
||||||
|
document_id: 0,
|
||||||
|
reason: format!("Failed to parse embed response: {}", e),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(embed_response.embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Quick health check without creating a full client.
|
||||||
|
pub async fn check_ollama_health(base_url: &str) -> bool {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(5))
|
||||||
|
.build()
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
let Some(client) = client else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = format!("{base_url}/api/tags");
|
||||||
|
client.get(&url).send().await.is_ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_config_defaults() {
|
||||||
|
let config = OllamaConfig::default();
|
||||||
|
assert_eq!(config.base_url, "http://localhost:11434");
|
||||||
|
assert_eq!(config.model, "nomic-embed-text");
|
||||||
|
assert_eq!(config.timeout_secs, 60);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_health_check_model_starts_with() {
|
||||||
|
// Verify the matching logic: "nomic-embed-text" should match "nomic-embed-text:latest"
|
||||||
|
let model = "nomic-embed-text";
|
||||||
|
let tag_name = "nomic-embed-text:latest";
|
||||||
|
assert!(tag_name.starts_with(model));
|
||||||
|
|
||||||
|
// Non-matching model
|
||||||
|
let wrong_model = "llama2";
|
||||||
|
assert!(!tag_name.starts_with(wrong_model));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_embed_request_serialization() {
|
||||||
|
let request = EmbedRequest {
|
||||||
|
model: "nomic-embed-text".to_string(),
|
||||||
|
input: vec!["hello".to_string(), "world".to_string()],
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&request).unwrap();
|
||||||
|
assert!(json.contains("\"model\":\"nomic-embed-text\""));
|
||||||
|
assert!(json.contains("\"input\":[\"hello\",\"world\"]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_embed_response_deserialization() {
|
||||||
|
let json = r#"{"model":"nomic-embed-text","embeddings":[[0.1,0.2,0.3],[0.4,0.5,0.6]]}"#;
|
||||||
|
let response: EmbedResponse = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(response.embeddings.len(), 2);
|
||||||
|
assert_eq!(response.embeddings[0], vec![0.1, 0.2, 0.3]);
|
||||||
|
assert_eq!(response.embeddings[1], vec![0.4, 0.5, 0.6]);
|
||||||
|
}
|
||||||
|
}
|
||||||
251
src/embedding/pipeline.rs
Normal file
251
src/embedding/pipeline.rs
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
//! Async embedding pipeline: chunk documents, embed via Ollama, store in sqlite-vec.
|
||||||
|
|
||||||
|
use rusqlite::Connection;
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use crate::core::error::Result;
|
||||||
|
use crate::embedding::change_detector::{count_pending_documents, find_pending_documents};
|
||||||
|
use crate::embedding::chunk_ids::encode_rowid;
|
||||||
|
use crate::embedding::chunking::split_into_chunks;
|
||||||
|
use crate::embedding::ollama::OllamaClient;
|
||||||
|
|
||||||
|
const BATCH_SIZE: usize = 32;
|
||||||
|
const DB_PAGE_SIZE: usize = 500;
|
||||||
|
const EXPECTED_DIMS: usize = 768;
|
||||||
|
|
||||||
|
/// Result of an embedding run.
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct EmbedResult {
|
||||||
|
pub embedded: usize,
|
||||||
|
pub failed: usize,
|
||||||
|
pub skipped: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Work item: a single chunk to embed.
|
||||||
|
struct ChunkWork {
|
||||||
|
doc_id: i64,
|
||||||
|
chunk_index: usize,
|
||||||
|
doc_hash: String,
|
||||||
|
chunk_hash: String,
|
||||||
|
text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the embedding pipeline: find pending documents, chunk, embed, store.
|
||||||
|
///
|
||||||
|
/// Processes batches of BATCH_SIZE texts per Ollama API call.
|
||||||
|
/// Uses keyset pagination over documents (DB_PAGE_SIZE per page).
|
||||||
|
pub async fn embed_documents(
|
||||||
|
conn: &Connection,
|
||||||
|
client: &OllamaClient,
|
||||||
|
model_name: &str,
|
||||||
|
progress_callback: Option<Box<dyn Fn(usize, usize)>>,
|
||||||
|
) -> Result<EmbedResult> {
|
||||||
|
let total = count_pending_documents(conn)? as usize;
|
||||||
|
let mut result = EmbedResult::default();
|
||||||
|
let mut last_id: i64 = 0;
|
||||||
|
let mut processed: usize = 0;
|
||||||
|
|
||||||
|
if total == 0 {
|
||||||
|
return Ok(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(total, "Starting embedding pipeline");
|
||||||
|
|
||||||
|
loop {
|
||||||
|
let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id)?;
|
||||||
|
if pending.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build chunk work items for this page
|
||||||
|
let mut all_chunks: Vec<ChunkWork> = Vec::new();
|
||||||
|
|
||||||
|
for doc in &pending {
|
||||||
|
// Always advance the cursor, even for skipped docs, to avoid re-fetching
|
||||||
|
last_id = doc.document_id;
|
||||||
|
|
||||||
|
if doc.content_text.is_empty() {
|
||||||
|
result.skipped += 1;
|
||||||
|
processed += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear existing embeddings for this document before re-embedding
|
||||||
|
clear_document_embeddings(conn, doc.document_id)?;
|
||||||
|
|
||||||
|
let chunks = split_into_chunks(&doc.content_text);
|
||||||
|
for (chunk_index, text) in chunks {
|
||||||
|
all_chunks.push(ChunkWork {
|
||||||
|
doc_id: doc.document_id,
|
||||||
|
chunk_index,
|
||||||
|
doc_hash: doc.content_hash.clone(),
|
||||||
|
chunk_hash: sha256_hash(&text),
|
||||||
|
text,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track progress per document (not per chunk) to match `total`
|
||||||
|
processed += 1;
|
||||||
|
if let Some(ref cb) = progress_callback {
|
||||||
|
cb(processed, total);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process chunks in batches of BATCH_SIZE
|
||||||
|
for batch in all_chunks.chunks(BATCH_SIZE) {
|
||||||
|
let texts: Vec<String> = batch.iter().map(|c| c.text.clone()).collect();
|
||||||
|
|
||||||
|
match client.embed_batch(texts).await {
|
||||||
|
Ok(embeddings) => {
|
||||||
|
for (i, embedding) in embeddings.iter().enumerate() {
|
||||||
|
if i >= batch.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let chunk = &batch[i];
|
||||||
|
|
||||||
|
if embedding.len() != EXPECTED_DIMS {
|
||||||
|
warn!(
|
||||||
|
doc_id = chunk.doc_id,
|
||||||
|
chunk_index = chunk.chunk_index,
|
||||||
|
got_dims = embedding.len(),
|
||||||
|
expected = EXPECTED_DIMS,
|
||||||
|
"Dimension mismatch, skipping"
|
||||||
|
);
|
||||||
|
record_embedding_error(
|
||||||
|
conn,
|
||||||
|
chunk.doc_id,
|
||||||
|
chunk.chunk_index,
|
||||||
|
&chunk.doc_hash,
|
||||||
|
&chunk.chunk_hash,
|
||||||
|
model_name,
|
||||||
|
&format!(
|
||||||
|
"Dimension mismatch: got {}, expected {}",
|
||||||
|
embedding.len(),
|
||||||
|
EXPECTED_DIMS
|
||||||
|
),
|
||||||
|
)?;
|
||||||
|
result.failed += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
store_embedding(
|
||||||
|
conn,
|
||||||
|
chunk.doc_id,
|
||||||
|
chunk.chunk_index,
|
||||||
|
&chunk.doc_hash,
|
||||||
|
&chunk.chunk_hash,
|
||||||
|
model_name,
|
||||||
|
embedding,
|
||||||
|
)?;
|
||||||
|
result.embedded += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "Batch embedding failed");
|
||||||
|
for chunk in batch {
|
||||||
|
record_embedding_error(
|
||||||
|
conn,
|
||||||
|
chunk.doc_id,
|
||||||
|
chunk.chunk_index,
|
||||||
|
&chunk.doc_hash,
|
||||||
|
&chunk.chunk_hash,
|
||||||
|
model_name,
|
||||||
|
&e.to_string(),
|
||||||
|
)?;
|
||||||
|
result.failed += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
embedded = result.embedded,
|
||||||
|
failed = result.failed,
|
||||||
|
skipped = result.skipped,
|
||||||
|
"Embedding pipeline complete"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear all embeddings and metadata for a document.
|
||||||
|
fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> {
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM embedding_metadata WHERE document_id = ?1",
|
||||||
|
[document_id],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let start_rowid = encode_rowid(document_id, 0);
|
||||||
|
let end_rowid = encode_rowid(document_id + 1, 0);
|
||||||
|
conn.execute(
|
||||||
|
"DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2",
|
||||||
|
rusqlite::params![start_rowid, end_rowid],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Store an embedding vector and its metadata.
|
||||||
|
fn store_embedding(
|
||||||
|
conn: &Connection,
|
||||||
|
doc_id: i64,
|
||||||
|
chunk_index: usize,
|
||||||
|
doc_hash: &str,
|
||||||
|
chunk_hash: &str,
|
||||||
|
model_name: &str,
|
||||||
|
embedding: &[f32],
|
||||||
|
) -> Result<()> {
|
||||||
|
let rowid = encode_rowid(doc_id, chunk_index as i64);
|
||||||
|
|
||||||
|
let embedding_bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
|
||||||
|
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)",
|
||||||
|
rusqlite::params![rowid, embedding_bytes],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO embedding_metadata
|
||||||
|
(document_id, chunk_index, model, dims, document_hash, chunk_hash,
|
||||||
|
created_at, attempt_count, last_error)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL)",
|
||||||
|
rusqlite::params![doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record an embedding error in metadata for later retry.
|
||||||
|
fn record_embedding_error(
|
||||||
|
conn: &Connection,
|
||||||
|
doc_id: i64,
|
||||||
|
chunk_index: usize,
|
||||||
|
doc_hash: &str,
|
||||||
|
chunk_hash: &str,
|
||||||
|
model_name: &str,
|
||||||
|
error: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
let now = chrono::Utc::now().timestamp_millis();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO embedding_metadata
|
||||||
|
(document_id, chunk_index, model, dims, document_hash, chunk_hash,
|
||||||
|
created_at, attempt_count, last_error, last_attempt_at)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7)
|
||||||
|
ON CONFLICT(document_id, chunk_index) DO UPDATE SET
|
||||||
|
attempt_count = embedding_metadata.attempt_count + 1,
|
||||||
|
last_error = ?8,
|
||||||
|
last_attempt_at = ?7",
|
||||||
|
rusqlite::params![doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, error],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sha256_hash(input: &str) -> String {
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
hasher.update(input.as_bytes());
|
||||||
|
format!("{:x}", hasher.finalize())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user