use std::collections::{HashMap, HashSet}; use std::sync::Arc; use futures::future::join_all; use rusqlite::Connection; use sha2::{Digest, Sha256}; use tracing::{debug, info, instrument, warn}; use crate::core::error::Result; use crate::core::shutdown::ShutdownSignal; use crate::embedding::change_detector::{count_pending_documents, find_pending_documents}; use crate::embedding::chunks::{ CHUNK_MAX_BYTES, CHUNK_ROWID_MULTIPLIER, EXPECTED_DIMS, encode_rowid, split_into_chunks, }; use crate::embedding::ollama::OllamaClient; const BATCH_SIZE: usize = 32; const DB_PAGE_SIZE: usize = 500; pub const DEFAULT_EMBED_CONCURRENCY: usize = 4; #[derive(Debug, Default)] pub struct EmbedResult { pub chunks_embedded: usize, pub docs_embedded: usize, pub failed: usize, pub skipped: usize, } struct ChunkWork { doc_id: i64, chunk_index: usize, total_chunks: usize, doc_hash: Arc, chunk_hash: String, text: String, } #[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))] pub async fn embed_documents( conn: &Connection, client: &OllamaClient, model_name: &str, concurrency: usize, progress_callback: Option>, signal: &ShutdownSignal, ) -> Result { let total = count_pending_documents(conn, model_name)? as usize; let mut result = EmbedResult::default(); let mut last_id: i64 = 0; let mut processed: usize = 0; if total == 0 { return Ok(result); } info!(total, "Starting embedding pipeline"); loop { if signal.is_cancelled() { info!("Shutdown requested, stopping embedding pipeline"); break; } info!(last_id, "Querying pending documents..."); let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?; if pending.is_empty() { break; } info!( count = pending.len(), "Found pending documents, starting page" ); conn.execute_batch("SAVEPOINT embed_page")?; let page_result = embed_page( conn, client, model_name, concurrency, &pending, &mut result, &mut last_id, &mut processed, total, &progress_callback, signal, ) .await; match page_result { Ok(()) if signal.is_cancelled() => { let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page"); info!("Rolled back incomplete page to preserve data integrity"); } Ok(()) => { conn.execute_batch("RELEASE embed_page")?; let _ = conn.execute_batch("PRAGMA wal_checkpoint(PASSIVE)"); info!( chunks_embedded = result.chunks_embedded, failed = result.failed, skipped = result.skipped, total, "Page complete" ); } Err(e) => { let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page"); return Err(e); } } } info!( chunks_embedded = result.chunks_embedded, failed = result.failed, skipped = result.skipped, "Embedding pipeline complete" ); tracing::Span::current().record("items_processed", result.chunks_embedded); tracing::Span::current().record("items_skipped", result.skipped); tracing::Span::current().record("errors", result.failed); Ok(result) } #[allow(clippy::too_many_arguments)] async fn embed_page( conn: &Connection, client: &OllamaClient, model_name: &str, concurrency: usize, pending: &[crate::embedding::change_detector::PendingDocument], result: &mut EmbedResult, last_id: &mut i64, processed: &mut usize, total: usize, progress_callback: &Option>, signal: &ShutdownSignal, ) -> Result<()> { let mut all_chunks: Vec = Vec::with_capacity(pending.len() * 3); let mut page_normal_docs: usize = 0; let mut chunks_needed: HashMap = HashMap::with_capacity(pending.len()); let mut chunks_stored: HashMap = HashMap::with_capacity(pending.len()); debug!(count = pending.len(), "Starting chunking loop"); for doc in pending { *last_id = doc.document_id; if doc.content_text.is_empty() { record_embedding_error( conn, doc.document_id, 0, &doc.content_hash, "empty", model_name, "Document has empty content", )?; result.skipped += 1; *processed += 1; continue; } if page_normal_docs != 0 && page_normal_docs.is_multiple_of(50) { debug!( doc_id = doc.document_id, doc_num = page_normal_docs, content_bytes = doc.content_text.len(), "Chunking document" ); } if page_normal_docs != 0 && page_normal_docs.is_multiple_of(100) { info!( doc_id = doc.document_id, content_bytes = doc.content_text.len(), docs_so_far = page_normal_docs, "Chunking document" ); } let chunks = split_into_chunks(&doc.content_text); debug!( doc_id = doc.document_id, chunk_count = chunks.len(), "Chunked" ); let total_chunks = chunks.len(); if total_chunks as i64 > CHUNK_ROWID_MULTIPLIER { warn!( doc_id = doc.document_id, chunk_count = total_chunks, max = CHUNK_ROWID_MULTIPLIER, "Document produces too many chunks, skipping to prevent rowid collision" ); record_embedding_error( conn, doc.document_id, 0, &doc.content_hash, "overflow-sentinel", model_name, &format!( "Document produces {} chunks, exceeding max {}", total_chunks, CHUNK_ROWID_MULTIPLIER ), )?; result.skipped += 1; *processed += 1; if let Some(cb) = progress_callback { cb(*processed, total); } continue; } chunks_needed.insert(doc.document_id, total_chunks); let doc_hash: Arc = Arc::from(doc.content_hash.as_str()); for (chunk_index, text) in chunks { all_chunks.push(ChunkWork { doc_id: doc.document_id, chunk_index, total_chunks, doc_hash: Arc::clone(&doc_hash), chunk_hash: sha256_hash(&text), text, }); } page_normal_docs += 1; } debug!(total_chunks = all_chunks.len(), "Chunking loop done"); let mut cleared_docs: HashSet = HashSet::with_capacity(pending.len()); let mut embed_buf: Vec = Vec::with_capacity(EXPECTED_DIMS * 4); // Split chunks into batches, then process batches in concurrent groups let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect(); debug!( batches = batches.len(), concurrency, "Starting Ollama requests" ); info!( chunks = all_chunks.len(), batches = batches.len(), docs = page_normal_docs, "Chunking complete, starting Ollama requests" ); info!( batches = batches.len(), concurrency, "About to start Ollama request loop" ); for (group_idx, concurrent_group) in batches.chunks(concurrency).enumerate() { debug!(group_idx, "Starting concurrent group"); if signal.is_cancelled() { info!("Shutdown requested during embedding, stopping mid-page"); break; } // Phase 1: Collect texts (must outlive the futures) let batch_texts: Vec> = concurrent_group .iter() .map(|batch| batch.iter().map(|c| c.text.as_str()).collect()) .collect(); // Phase 2: Fire concurrent HTTP requests to Ollama let futures: Vec<_> = batch_texts .iter() .map(|texts| client.embed_batch(texts)) .collect(); let api_results = join_all(futures).await; debug!( group_idx, results = api_results.len(), "Ollama group complete" ); // Phase 3: Serial DB writes for each batch result for (batch, api_result) in concurrent_group.iter().zip(api_results) { match api_result { Ok(embeddings) => { for (i, embedding) in embeddings.iter().enumerate() { if i >= batch.len() { break; } let chunk = &batch[i]; if embedding.len() != EXPECTED_DIMS { warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, got_dims = embedding.len(), expected = EXPECTED_DIMS, "Dimension mismatch, skipping" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &format!( "Dimension mismatch: got {}, expected {}", embedding.len(), EXPECTED_DIMS ), )?; result.failed += 1; continue; } if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, embedding, chunk.total_chunks, &mut embed_buf, )?; result.chunks_embedded += 1; *chunks_stored.entry(chunk.doc_id).or_insert(0) += 1; } // Record errors for chunks that Ollama silently dropped if embeddings.len() < batch.len() { warn!( returned = embeddings.len(), expected = batch.len(), "Ollama returned fewer embeddings than inputs" ); for chunk in &batch[embeddings.len()..] { record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &format!( "Batch mismatch: got {} of {} embeddings", embeddings.len(), batch.len() ), )?; result.failed += 1; } } } Err(e) => { let err_str = e.to_string(); let err_lower = err_str.to_lowercase(); let is_context_error = err_lower.contains("context length") || err_lower.contains("too long") || err_lower.contains("maximum context") || err_lower.contains("token limit") || err_lower.contains("exceeds") || (err_lower.contains("413") && err_lower.contains("http")); if is_context_error && batch.len() > 1 { warn!( "Batch failed with context length error, retrying chunks individually" ); for chunk in *batch { match client.embed_batch(&[chunk.text.as_str()]).await { Ok(embeddings) if !embeddings.is_empty() && embeddings[0].len() == EXPECTED_DIMS => { if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &embeddings[0], chunk.total_chunks, &mut embed_buf, )?; result.chunks_embedded += 1; *chunks_stored.entry(chunk.doc_id).or_insert(0) += 1; } Ok(embeddings) => { let got_dims = embeddings.first().map_or(0, std::vec::Vec::len); let reason = format!( "Retry failed: got {} embeddings, first has {} dims (expected {})", embeddings.len(), got_dims, EXPECTED_DIMS ); warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, %reason, "Chunk retry returned unexpected result" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &reason, )?; result.failed += 1; } Err(retry_err) => { let reason = format!("Retry failed: {}", retry_err); warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, chunk_bytes = chunk.text.len(), error = %retry_err, "Chunk retry request failed" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &reason, )?; result.failed += 1; } } } } else { warn!(error = %e, "Batch embedding failed"); for chunk in *batch { record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &e.to_string(), )?; result.failed += 1; } } } } } } // Count docs where all chunks were successfully stored for (doc_id, needed) in &chunks_needed { if chunks_stored.get(doc_id).copied().unwrap_or(0) == *needed { result.docs_embedded += 1; } } *processed += page_normal_docs; if let Some(cb) = progress_callback { cb(*processed, total); } Ok(()) } fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> { conn.prepare_cached("DELETE FROM embedding_metadata WHERE document_id = ?1")? .execute([document_id])?; let start_rowid = encode_rowid(document_id, 0); let end_rowid = encode_rowid(document_id + 1, 0); conn.prepare_cached("DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2")? .execute(rusqlite::params![start_rowid, end_rowid])?; Ok(()) } #[allow(clippy::too_many_arguments)] fn store_embedding( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, embedding: &[f32], total_chunks: usize, embed_buf: &mut Vec, ) -> Result<()> { let rowid = encode_rowid(doc_id, chunk_index as i64); embed_buf.clear(); for f in embedding { embed_buf.extend_from_slice(&f.to_le_bytes()); } conn.prepare_cached("INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)")? .execute(rusqlite::params![rowid, &embed_buf[..]])?; let chunk_count: Option = if chunk_index == 0 { Some(total_chunks as i64) } else { None }; let now = chrono::Utc::now().timestamp_millis(); conn.prepare_cached( "INSERT OR REPLACE INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, chunk_max_bytes, chunk_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL, ?8, ?9)", )? .execute(rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, CHUNK_MAX_BYTES as i64, chunk_count ])?; Ok(()) } pub(crate) fn record_embedding_error( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, error: &str, ) -> Result<()> { let now = chrono::Utc::now().timestamp_millis(); conn.prepare_cached( "INSERT INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, last_attempt_at, chunk_max_bytes) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7, ?9) ON CONFLICT(document_id, chunk_index) DO UPDATE SET attempt_count = embedding_metadata.attempt_count + 1, last_error = ?8, last_attempt_at = ?7, chunk_max_bytes = ?9", )? .execute(rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, error, CHUNK_MAX_BYTES as i64 ])?; Ok(()) } fn sha256_hash(input: &str) -> String { let mut hasher = Sha256::new(); hasher.update(input.as_bytes()); format!("{:x}", hasher.finalize()) } #[derive(Debug, Default)] pub struct EmbedForIdsResult { pub chunks_embedded: usize, pub docs_embedded: usize, pub failed: usize, pub skipped: usize, } /// Embed only the documents with the given IDs, skipping any that are /// already embedded with matching config (model, dims, chunk size, hash). pub async fn embed_documents_by_ids( conn: &Connection, client: &OllamaClient, model_name: &str, concurrency: usize, document_ids: &[i64], signal: &ShutdownSignal, ) -> Result { let mut result = EmbedForIdsResult::default(); if document_ids.is_empty() { return Ok(result); } if signal.is_cancelled() { return Ok(result); } // Load documents for the specified IDs, filtering out already-embedded let pending = find_documents_by_ids(conn, document_ids, model_name)?; if pending.is_empty() { result.skipped = document_ids.len(); return Ok(result); } let skipped_count = document_ids.len() - pending.len(); result.skipped = skipped_count; info!( requested = document_ids.len(), pending = pending.len(), skipped = skipped_count, "Scoped embedding: processing documents by ID" ); // Use the same SAVEPOINT + embed_page pattern as the main pipeline let mut last_id: i64 = 0; let mut processed: usize = 0; let total = pending.len(); let mut page_stats = EmbedResult::default(); conn.execute_batch("SAVEPOINT embed_by_ids")?; let page_result = embed_page( conn, client, model_name, concurrency, &pending, &mut page_stats, &mut last_id, &mut processed, total, &None, signal, ) .await; match page_result { Ok(()) if signal.is_cancelled() => { let _ = conn.execute_batch("ROLLBACK TO embed_by_ids; RELEASE embed_by_ids"); info!("Rolled back scoped embed page due to cancellation"); } Ok(()) => { conn.execute_batch("RELEASE embed_by_ids")?; // Count actual results from DB let (chunks, docs) = count_embedded_results(conn, &pending)?; result.chunks_embedded = chunks; result.docs_embedded = docs; result.failed = page_stats.failed; } Err(e) => { let _ = conn.execute_batch("ROLLBACK TO embed_by_ids; RELEASE embed_by_ids"); return Err(e); } } info!( chunks_embedded = result.chunks_embedded, docs_embedded = result.docs_embedded, failed = result.failed, skipped = result.skipped, "Scoped embedding complete" ); Ok(result) } /// Load documents by specific IDs, filtering out those already embedded /// with matching config (same logic as `find_pending_documents` but scoped). fn find_documents_by_ids( conn: &Connection, document_ids: &[i64], model_name: &str, ) -> Result> { use crate::embedding::chunks::{CHUNK_MAX_BYTES, EXPECTED_DIMS}; if document_ids.is_empty() { return Ok(Vec::new()); } // Build IN clause with placeholders let placeholders: Vec = (0..document_ids.len()) .map(|i| format!("?{}", i + 1)) .collect(); let in_clause = placeholders.join(", "); let sql = format!( r#" SELECT d.id, d.content_text, d.content_hash FROM documents d LEFT JOIN embedding_metadata em ON em.document_id = d.id AND em.chunk_index = 0 WHERE d.id IN ({in_clause}) AND ( em.document_id IS NULL OR em.document_hash != d.content_hash OR em.chunk_max_bytes IS NULL OR em.chunk_max_bytes != ?{chunk_bytes_idx} OR em.model != ?{model_idx} OR em.dims != ?{dims_idx} ) ORDER BY d.id "#, in_clause = in_clause, chunk_bytes_idx = document_ids.len() + 1, model_idx = document_ids.len() + 2, dims_idx = document_ids.len() + 3, ); let mut stmt = conn.prepare(&sql)?; // Build params: document_ids... then chunk_max_bytes, model, dims let mut params: Vec> = Vec::new(); for id in document_ids { params.push(Box::new(*id)); } params.push(Box::new(CHUNK_MAX_BYTES as i64)); params.push(Box::new(model_name.to_string())); params.push(Box::new(EXPECTED_DIMS as i64)); let param_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect(); let rows = stmt .query_map(param_refs.as_slice(), |row| { Ok(crate::embedding::change_detector::PendingDocument { document_id: row.get(0)?, content_text: row.get(1)?, content_hash: row.get(2)?, }) })? .collect::, _>>()?; Ok(rows) } /// Count how many chunks and complete docs were embedded for the given pending docs. fn count_embedded_results( conn: &Connection, pending: &[crate::embedding::change_detector::PendingDocument], ) -> Result<(usize, usize)> { let mut total_chunks: usize = 0; let mut total_docs: usize = 0; for doc in pending { let chunk_count: i64 = conn.query_row( "SELECT COUNT(*) FROM embedding_metadata WHERE document_id = ?1 AND last_error IS NULL", [doc.document_id], |row| row.get(0), )?; if chunk_count > 0 { total_chunks += chunk_count as usize; // Check if all expected chunks are present (chunk_count metadata on chunk_index=0) let expected: Option = conn.query_row( "SELECT chunk_count FROM embedding_metadata WHERE document_id = ?1 AND chunk_index = 0", [doc.document_id], |row| row.get(0), )?; if let Some(expected_count) = expected && chunk_count >= expected_count { total_docs += 1; } } } Ok((total_chunks, total_docs)) } #[cfg(test)] #[path = "pipeline_tests.rs"] mod tests;