//! Async embedding pipeline: chunk documents, embed via Ollama, store in sqlite-vec. use std::collections::HashSet; use rusqlite::Connection; use sha2::{Digest, Sha256}; use tracing::{info, instrument, warn}; use crate::core::error::Result; use crate::embedding::change_detector::{count_pending_documents, find_pending_documents}; use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid}; use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks}; use crate::embedding::ollama::OllamaClient; const BATCH_SIZE: usize = 32; const DB_PAGE_SIZE: usize = 500; /// Result of an embedding run. #[derive(Debug, Default)] pub struct EmbedResult { pub embedded: usize, pub failed: usize, pub skipped: usize, } /// Work item: a single chunk to embed. struct ChunkWork { doc_id: i64, chunk_index: usize, total_chunks: usize, doc_hash: String, chunk_hash: String, text: String, } /// Run the embedding pipeline: find pending documents, chunk, embed, store. /// /// Processes batches of BATCH_SIZE texts per Ollama API call. /// Uses keyset pagination over documents (DB_PAGE_SIZE per page). #[instrument(skip(conn, client, progress_callback), fields(%model_name, items_processed, items_skipped, errors))] pub async fn embed_documents( conn: &Connection, client: &OllamaClient, model_name: &str, progress_callback: Option>, ) -> Result { let total = count_pending_documents(conn, model_name)? as usize; let mut result = EmbedResult::default(); let mut last_id: i64 = 0; let mut processed: usize = 0; if total == 0 { return Ok(result); } info!(total, "Starting embedding pipeline"); loop { let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?; if pending.is_empty() { break; } // Wrap all DB writes for this page in a savepoint so that // clear_document_embeddings + store_embedding are atomic. If the // process crashes mid-page, the savepoint is never released and // SQLite rolls back — preventing partial document states where old // embeddings are cleared but new ones haven't been written yet. conn.execute_batch("SAVEPOINT embed_page")?; // Build chunk work items for this page let mut all_chunks: Vec = Vec::new(); let mut page_normal_docs: usize = 0; for doc in &pending { // Always advance the cursor, even for skipped docs, to avoid re-fetching last_id = doc.document_id; if doc.content_text.is_empty() { result.skipped += 1; processed += 1; continue; } let chunks = split_into_chunks(&doc.content_text); let total_chunks = chunks.len(); // Overflow guard: skip documents that produce too many chunks. // Must run BEFORE clear_document_embeddings so existing embeddings // are preserved when we skip. if total_chunks as i64 > CHUNK_ROWID_MULTIPLIER { warn!( doc_id = doc.document_id, chunk_count = total_chunks, max = CHUNK_ROWID_MULTIPLIER, "Document produces too many chunks, skipping to prevent rowid collision" ); // Record a sentinel error so the document is not re-detected as // pending on subsequent runs (prevents infinite re-processing). record_embedding_error( conn, doc.document_id, 0, // sentinel chunk_index &doc.content_hash, "overflow-sentinel", model_name, &format!( "Document produces {} chunks, exceeding max {}", total_chunks, CHUNK_ROWID_MULTIPLIER ), )?; result.skipped += 1; processed += 1; if let Some(ref cb) = progress_callback { cb(processed, total); } continue; } // Don't clear existing embeddings here — defer until the first // successful chunk embedding so that if ALL chunks for a document // fail, old embeddings survive instead of leaving zero data. for (chunk_index, text) in chunks { all_chunks.push(ChunkWork { doc_id: doc.document_id, chunk_index, total_chunks, doc_hash: doc.content_hash.clone(), chunk_hash: sha256_hash(&text), text, }); } page_normal_docs += 1; // Don't fire progress here — wait until embedding completes below. } // Track documents whose old embeddings have been cleared. // We defer clearing until the first successful chunk embedding so // that if ALL chunks for a document fail, old embeddings survive. let mut cleared_docs: HashSet = HashSet::new(); // Process chunks in batches of BATCH_SIZE for batch in all_chunks.chunks(BATCH_SIZE) { let texts: Vec = batch.iter().map(|c| c.text.clone()).collect(); match client.embed_batch(texts).await { Ok(embeddings) => { for (i, embedding) in embeddings.iter().enumerate() { if i >= batch.len() { break; } let chunk = &batch[i]; if embedding.len() != EXPECTED_DIMS { warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, got_dims = embedding.len(), expected = EXPECTED_DIMS, "Dimension mismatch, skipping" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &format!( "Dimension mismatch: got {}, expected {}", embedding.len(), EXPECTED_DIMS ), )?; result.failed += 1; continue; } // Clear old embeddings on first successful chunk for this document if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, embedding, chunk.total_chunks, )?; result.embedded += 1; } } Err(e) => { // Batch failed — retry each chunk individually so one // oversized chunk doesn't poison the entire batch. let err_str = e.to_string(); let err_lower = err_str.to_lowercase(); // Ollama error messages vary across versions. Match broadly // against known patterns to detect context-window overflow. let is_context_error = err_lower.contains("context length") || err_lower.contains("too long") || err_lower.contains("maximum context") || err_lower.contains("token limit") || err_lower.contains("exceeds") || (err_lower.contains("413") && err_lower.contains("http")); if is_context_error && batch.len() > 1 { warn!( "Batch failed with context length error, retrying chunks individually" ); for chunk in batch { match client.embed_batch(vec![chunk.text.clone()]).await { Ok(embeddings) if !embeddings.is_empty() && embeddings[0].len() == EXPECTED_DIMS => { // Clear old embeddings on first successful chunk if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &embeddings[0], chunk.total_chunks, )?; result.embedded += 1; } _ => { warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, chunk_bytes = chunk.text.len(), "Chunk too large for model context window" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, "Chunk exceeds model context window", )?; result.failed += 1; } } } } else { warn!(error = %e, "Batch embedding failed"); for chunk in batch { record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &e.to_string(), )?; result.failed += 1; } } } } } // Fire progress for all normal documents after embedding completes. // This ensures progress reflects actual embedding work, not just chunking. processed += page_normal_docs; if let Some(ref cb) = progress_callback { cb(processed, total); } // Commit all DB writes for this page atomically. conn.execute_batch("RELEASE embed_page")?; } info!( embedded = result.embedded, failed = result.failed, skipped = result.skipped, "Embedding pipeline complete" ); tracing::Span::current().record("items_processed", result.embedded); tracing::Span::current().record("items_skipped", result.skipped); tracing::Span::current().record("errors", result.failed); Ok(result) } /// Clear all embeddings and metadata for a document. fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> { conn.execute( "DELETE FROM embedding_metadata WHERE document_id = ?1", [document_id], )?; let start_rowid = encode_rowid(document_id, 0); let end_rowid = encode_rowid(document_id + 1, 0); conn.execute( "DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2", rusqlite::params![start_rowid, end_rowid], )?; Ok(()) } /// Store an embedding vector and its metadata. #[allow(clippy::too_many_arguments)] fn store_embedding( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, embedding: &[f32], total_chunks: usize, ) -> Result<()> { let rowid = encode_rowid(doc_id, chunk_index as i64); let mut embedding_bytes = Vec::with_capacity(embedding.len() * 4); for f in embedding { embedding_bytes.extend_from_slice(&f.to_le_bytes()); } conn.execute( "INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)", rusqlite::params![rowid, embedding_bytes], )?; // Only store chunk_count on the sentinel row (chunk_index=0) let chunk_count: Option = if chunk_index == 0 { Some(total_chunks as i64) } else { None }; let now = chrono::Utc::now().timestamp_millis(); conn.execute( "INSERT OR REPLACE INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, chunk_max_bytes, chunk_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL, ?8, ?9)", rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, CHUNK_MAX_BYTES as i64, chunk_count ], )?; Ok(()) } /// Record an embedding error in metadata for later retry. fn record_embedding_error( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, error: &str, ) -> Result<()> { let now = chrono::Utc::now().timestamp_millis(); conn.execute( "INSERT INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, last_attempt_at, chunk_max_bytes) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7, ?9) ON CONFLICT(document_id, chunk_index) DO UPDATE SET attempt_count = embedding_metadata.attempt_count + 1, last_error = ?8, last_attempt_at = ?7, chunk_max_bytes = ?9", rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, error, CHUNK_MAX_BYTES as i64 ], )?; Ok(()) } fn sha256_hash(input: &str) -> String { let mut hasher = Sha256::new(); hasher.update(input.as_bytes()); format!("{:x}", hasher.finalize()) }