use std::collections::{HashMap, HashSet}; use futures::future::join_all; use rusqlite::Connection; use sha2::{Digest, Sha256}; use tracing::{debug, info, instrument, warn}; use crate::core::error::Result; use crate::core::shutdown::ShutdownSignal; use crate::embedding::change_detector::{count_pending_documents, find_pending_documents}; use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid}; use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks}; use crate::embedding::ollama::OllamaClient; const BATCH_SIZE: usize = 32; const DB_PAGE_SIZE: usize = 500; pub const DEFAULT_EMBED_CONCURRENCY: usize = 4; #[derive(Debug, Default)] pub struct EmbedResult { pub chunks_embedded: usize, pub docs_embedded: usize, pub failed: usize, pub skipped: usize, } struct ChunkWork { doc_id: i64, chunk_index: usize, total_chunks: usize, doc_hash: String, chunk_hash: String, text: String, } #[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))] pub async fn embed_documents( conn: &Connection, client: &OllamaClient, model_name: &str, concurrency: usize, progress_callback: Option>, signal: &ShutdownSignal, ) -> Result { let total = count_pending_documents(conn, model_name)? as usize; let mut result = EmbedResult::default(); let mut last_id: i64 = 0; let mut processed: usize = 0; if total == 0 { return Ok(result); } info!(total, "Starting embedding pipeline"); loop { if signal.is_cancelled() { info!("Shutdown requested, stopping embedding pipeline"); break; } info!(last_id, "Querying pending documents..."); let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?; if pending.is_empty() { break; } info!( count = pending.len(), "Found pending documents, starting page" ); conn.execute_batch("SAVEPOINT embed_page")?; let page_result = embed_page( conn, client, model_name, concurrency, &pending, &mut result, &mut last_id, &mut processed, total, &progress_callback, signal, ) .await; match page_result { Ok(()) if signal.is_cancelled() => { let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page"); info!("Rolled back incomplete page to preserve data integrity"); } Ok(()) => { conn.execute_batch("RELEASE embed_page")?; let _ = conn.execute_batch("PRAGMA wal_checkpoint(PASSIVE)"); info!( chunks_embedded = result.chunks_embedded, failed = result.failed, skipped = result.skipped, total, "Page complete" ); } Err(e) => { let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page"); return Err(e); } } } info!( chunks_embedded = result.chunks_embedded, failed = result.failed, skipped = result.skipped, "Embedding pipeline complete" ); tracing::Span::current().record("items_processed", result.chunks_embedded); tracing::Span::current().record("items_skipped", result.skipped); tracing::Span::current().record("errors", result.failed); Ok(result) } #[allow(clippy::too_many_arguments)] async fn embed_page( conn: &Connection, client: &OllamaClient, model_name: &str, concurrency: usize, pending: &[crate::embedding::change_detector::PendingDocument], result: &mut EmbedResult, last_id: &mut i64, processed: &mut usize, total: usize, progress_callback: &Option>, signal: &ShutdownSignal, ) -> Result<()> { let mut all_chunks: Vec = Vec::with_capacity(pending.len() * 3); let mut page_normal_docs: usize = 0; let mut chunks_needed: HashMap = HashMap::with_capacity(pending.len()); let mut chunks_stored: HashMap = HashMap::with_capacity(pending.len()); debug!(count = pending.len(), "Starting chunking loop"); for doc in pending { *last_id = doc.document_id; if doc.content_text.is_empty() { record_embedding_error( conn, doc.document_id, 0, &doc.content_hash, "empty", model_name, "Document has empty content", )?; result.skipped += 1; *processed += 1; continue; } if page_normal_docs.is_multiple_of(50) { debug!( doc_id = doc.document_id, doc_num = page_normal_docs, content_bytes = doc.content_text.len(), "Chunking document" ); } if page_normal_docs.is_multiple_of(100) { info!( doc_id = doc.document_id, content_bytes = doc.content_text.len(), docs_so_far = page_normal_docs, "Chunking document" ); } let chunks = split_into_chunks(&doc.content_text); debug!( doc_id = doc.document_id, chunk_count = chunks.len(), "Chunked" ); let total_chunks = chunks.len(); if total_chunks as i64 > CHUNK_ROWID_MULTIPLIER { warn!( doc_id = doc.document_id, chunk_count = total_chunks, max = CHUNK_ROWID_MULTIPLIER, "Document produces too many chunks, skipping to prevent rowid collision" ); record_embedding_error( conn, doc.document_id, 0, &doc.content_hash, "overflow-sentinel", model_name, &format!( "Document produces {} chunks, exceeding max {}", total_chunks, CHUNK_ROWID_MULTIPLIER ), )?; result.skipped += 1; *processed += 1; if let Some(cb) = progress_callback { cb(*processed, total); } continue; } chunks_needed.insert(doc.document_id, total_chunks); for (chunk_index, text) in chunks { all_chunks.push(ChunkWork { doc_id: doc.document_id, chunk_index, total_chunks, doc_hash: doc.content_hash.clone(), chunk_hash: sha256_hash(&text), text, }); } page_normal_docs += 1; } debug!(total_chunks = all_chunks.len(), "Chunking loop done"); let mut cleared_docs: HashSet = HashSet::with_capacity(pending.len()); let mut embed_buf: Vec = Vec::with_capacity(EXPECTED_DIMS * 4); // Split chunks into batches, then process batches in concurrent groups let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect(); debug!( batches = batches.len(), concurrency, "Starting Ollama requests" ); info!( chunks = all_chunks.len(), batches = batches.len(), docs = page_normal_docs, "Chunking complete, starting Ollama requests" ); info!( batches = batches.len(), concurrency, "About to start Ollama request loop" ); for (group_idx, concurrent_group) in batches.chunks(concurrency).enumerate() { debug!(group_idx, "Starting concurrent group"); if signal.is_cancelled() { info!("Shutdown requested during embedding, stopping mid-page"); break; } // Phase 1: Collect texts (must outlive the futures) let batch_texts: Vec> = concurrent_group .iter() .map(|batch| batch.iter().map(|c| c.text.as_str()).collect()) .collect(); // Phase 2: Fire concurrent HTTP requests to Ollama let futures: Vec<_> = batch_texts .iter() .map(|texts| client.embed_batch(texts)) .collect(); let api_results = join_all(futures).await; debug!( group_idx, results = api_results.len(), "Ollama group complete" ); // Phase 3: Serial DB writes for each batch result for (batch, api_result) in concurrent_group.iter().zip(api_results) { match api_result { Ok(embeddings) => { for (i, embedding) in embeddings.iter().enumerate() { if i >= batch.len() { break; } let chunk = &batch[i]; if embedding.len() != EXPECTED_DIMS { warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, got_dims = embedding.len(), expected = EXPECTED_DIMS, "Dimension mismatch, skipping" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &format!( "Dimension mismatch: got {}, expected {}", embedding.len(), EXPECTED_DIMS ), )?; result.failed += 1; continue; } if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, embedding, chunk.total_chunks, &mut embed_buf, )?; result.chunks_embedded += 1; *chunks_stored.entry(chunk.doc_id).or_insert(0) += 1; } // Record errors for chunks that Ollama silently dropped if embeddings.len() < batch.len() { warn!( returned = embeddings.len(), expected = batch.len(), "Ollama returned fewer embeddings than inputs" ); for chunk in &batch[embeddings.len()..] { record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &format!( "Batch mismatch: got {} of {} embeddings", embeddings.len(), batch.len() ), )?; result.failed += 1; } } } Err(e) => { let err_str = e.to_string(); let err_lower = err_str.to_lowercase(); let is_context_error = err_lower.contains("context length") || err_lower.contains("too long") || err_lower.contains("maximum context") || err_lower.contains("token limit") || err_lower.contains("exceeds") || (err_lower.contains("413") && err_lower.contains("http")); if is_context_error && batch.len() > 1 { warn!( "Batch failed with context length error, retrying chunks individually" ); for chunk in *batch { match client.embed_batch(&[chunk.text.as_str()]).await { Ok(embeddings) if !embeddings.is_empty() && embeddings[0].len() == EXPECTED_DIMS => { if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &embeddings[0], chunk.total_chunks, &mut embed_buf, )?; result.chunks_embedded += 1; *chunks_stored.entry(chunk.doc_id).or_insert(0) += 1; } Ok(embeddings) => { let got_dims = embeddings.first().map_or(0, std::vec::Vec::len); let reason = format!( "Retry failed: got {} embeddings, first has {} dims (expected {})", embeddings.len(), got_dims, EXPECTED_DIMS ); warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, %reason, "Chunk retry returned unexpected result" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &reason, )?; result.failed += 1; } Err(retry_err) => { let reason = format!("Retry failed: {}", retry_err); warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, chunk_bytes = chunk.text.len(), error = %retry_err, "Chunk retry request failed" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &reason, )?; result.failed += 1; } } } } else { warn!(error = %e, "Batch embedding failed"); for chunk in *batch { record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &e.to_string(), )?; result.failed += 1; } } } } } } // Count docs where all chunks were successfully stored for (doc_id, needed) in &chunks_needed { if chunks_stored.get(doc_id).copied().unwrap_or(0) == *needed { result.docs_embedded += 1; } } *processed += page_normal_docs; if let Some(cb) = progress_callback { cb(*processed, total); } Ok(()) } fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> { conn.prepare_cached("DELETE FROM embedding_metadata WHERE document_id = ?1")? .execute([document_id])?; let start_rowid = encode_rowid(document_id, 0); let end_rowid = encode_rowid(document_id + 1, 0); conn.prepare_cached("DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2")? .execute(rusqlite::params![start_rowid, end_rowid])?; Ok(()) } #[allow(clippy::too_many_arguments)] fn store_embedding( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, embedding: &[f32], total_chunks: usize, embed_buf: &mut Vec, ) -> Result<()> { let rowid = encode_rowid(doc_id, chunk_index as i64); embed_buf.clear(); embed_buf.reserve(embedding.len() * 4); for f in embedding { embed_buf.extend_from_slice(&f.to_le_bytes()); } conn.prepare_cached("INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)")? .execute(rusqlite::params![rowid, &embed_buf[..]])?; let chunk_count: Option = if chunk_index == 0 { Some(total_chunks as i64) } else { None }; let now = chrono::Utc::now().timestamp_millis(); conn.prepare_cached( "INSERT OR REPLACE INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, chunk_max_bytes, chunk_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL, ?8, ?9)", )? .execute(rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, CHUNK_MAX_BYTES as i64, chunk_count ])?; Ok(()) } pub(crate) fn record_embedding_error( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, error: &str, ) -> Result<()> { let now = chrono::Utc::now().timestamp_millis(); conn.prepare_cached( "INSERT INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, last_attempt_at, chunk_max_bytes) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7, ?9) ON CONFLICT(document_id, chunk_index) DO UPDATE SET attempt_count = embedding_metadata.attempt_count + 1, last_error = ?8, last_attempt_at = ?7, chunk_max_bytes = ?9", )? .execute(rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, error, CHUNK_MAX_BYTES as i64 ])?; Ok(()) } fn sha256_hash(input: &str) -> String { let mut hasher = Sha256::new(); hasher.update(input.as_bytes()); format!("{:x}", hasher.finalize()) }