use std::collections::HashSet; use rusqlite::Connection; use sha2::{Digest, Sha256}; use tracing::{info, instrument, warn}; use crate::core::error::Result; use crate::embedding::change_detector::{count_pending_documents, find_pending_documents}; use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid}; use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks}; use crate::embedding::ollama::OllamaClient; const BATCH_SIZE: usize = 32; const DB_PAGE_SIZE: usize = 500; #[derive(Debug, Default)] pub struct EmbedResult { pub embedded: usize, pub failed: usize, pub skipped: usize, } struct ChunkWork { doc_id: i64, chunk_index: usize, total_chunks: usize, doc_hash: String, chunk_hash: String, text: String, } #[instrument(skip(conn, client, progress_callback), fields(%model_name, items_processed, items_skipped, errors))] pub async fn embed_documents( conn: &Connection, client: &OllamaClient, model_name: &str, progress_callback: Option>, ) -> Result { let total = count_pending_documents(conn, model_name)? as usize; let mut result = EmbedResult::default(); let mut last_id: i64 = 0; let mut processed: usize = 0; if total == 0 { return Ok(result); } info!(total, "Starting embedding pipeline"); loop { let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?; if pending.is_empty() { break; } conn.execute_batch("SAVEPOINT embed_page")?; let page_result = embed_page( conn, client, model_name, &pending, &mut result, &mut last_id, &mut processed, total, &progress_callback, ) .await; match page_result { Ok(()) => { conn.execute_batch("RELEASE embed_page")?; } Err(e) => { let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page"); return Err(e); } } } info!( embedded = result.embedded, failed = result.failed, skipped = result.skipped, "Embedding pipeline complete" ); tracing::Span::current().record("items_processed", result.embedded); tracing::Span::current().record("items_skipped", result.skipped); tracing::Span::current().record("errors", result.failed); Ok(result) } #[allow(clippy::too_many_arguments)] async fn embed_page( conn: &Connection, client: &OllamaClient, model_name: &str, pending: &[crate::embedding::change_detector::PendingDocument], result: &mut EmbedResult, last_id: &mut i64, processed: &mut usize, total: usize, progress_callback: &Option>, ) -> Result<()> { let mut all_chunks: Vec = Vec::with_capacity(pending.len() * 3); let mut page_normal_docs: usize = 0; for doc in pending { *last_id = doc.document_id; if doc.content_text.is_empty() { result.skipped += 1; *processed += 1; continue; } let chunks = split_into_chunks(&doc.content_text); let total_chunks = chunks.len(); if total_chunks as i64 > CHUNK_ROWID_MULTIPLIER { warn!( doc_id = doc.document_id, chunk_count = total_chunks, max = CHUNK_ROWID_MULTIPLIER, "Document produces too many chunks, skipping to prevent rowid collision" ); record_embedding_error( conn, doc.document_id, 0, &doc.content_hash, "overflow-sentinel", model_name, &format!( "Document produces {} chunks, exceeding max {}", total_chunks, CHUNK_ROWID_MULTIPLIER ), )?; result.skipped += 1; *processed += 1; if let Some(cb) = progress_callback { cb(*processed, total); } continue; } for (chunk_index, text) in chunks { all_chunks.push(ChunkWork { doc_id: doc.document_id, chunk_index, total_chunks, doc_hash: doc.content_hash.clone(), chunk_hash: sha256_hash(&text), text, }); } page_normal_docs += 1; } let mut cleared_docs: HashSet = HashSet::with_capacity(pending.len()); for batch in all_chunks.chunks(BATCH_SIZE) { let texts: Vec<&str> = batch.iter().map(|c| c.text.as_str()).collect(); match client.embed_batch(&texts).await { Ok(embeddings) => { for (i, embedding) in embeddings.iter().enumerate() { if i >= batch.len() { break; } let chunk = &batch[i]; if embedding.len() != EXPECTED_DIMS { warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, got_dims = embedding.len(), expected = EXPECTED_DIMS, "Dimension mismatch, skipping" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &format!( "Dimension mismatch: got {}, expected {}", embedding.len(), EXPECTED_DIMS ), )?; result.failed += 1; continue; } if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, embedding, chunk.total_chunks, )?; result.embedded += 1; } } Err(e) => { let err_str = e.to_string(); let err_lower = err_str.to_lowercase(); let is_context_error = err_lower.contains("context length") || err_lower.contains("too long") || err_lower.contains("maximum context") || err_lower.contains("token limit") || err_lower.contains("exceeds") || (err_lower.contains("413") && err_lower.contains("http")); if is_context_error && batch.len() > 1 { warn!("Batch failed with context length error, retrying chunks individually"); for chunk in batch { match client.embed_batch(&[chunk.text.as_str()]).await { Ok(embeddings) if !embeddings.is_empty() && embeddings[0].len() == EXPECTED_DIMS => { if !cleared_docs.contains(&chunk.doc_id) { clear_document_embeddings(conn, chunk.doc_id)?; cleared_docs.insert(chunk.doc_id); } store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &embeddings[0], chunk.total_chunks, )?; result.embedded += 1; } _ => { warn!( doc_id = chunk.doc_id, chunk_index = chunk.chunk_index, chunk_bytes = chunk.text.len(), "Chunk too large for model context window" ); record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, "Chunk exceeds model context window", )?; result.failed += 1; } } } } else { warn!(error = %e, "Batch embedding failed"); for chunk in batch { record_embedding_error( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, &e.to_string(), )?; result.failed += 1; } } } } } *processed += page_normal_docs; if let Some(cb) = progress_callback { cb(*processed, total); } Ok(()) } fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> { conn.execute( "DELETE FROM embedding_metadata WHERE document_id = ?1", [document_id], )?; let start_rowid = encode_rowid(document_id, 0); let end_rowid = encode_rowid(document_id + 1, 0); conn.execute( "DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2", rusqlite::params![start_rowid, end_rowid], )?; Ok(()) } #[allow(clippy::too_many_arguments)] fn store_embedding( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, embedding: &[f32], total_chunks: usize, ) -> Result<()> { let rowid = encode_rowid(doc_id, chunk_index as i64); let mut embedding_bytes = Vec::with_capacity(embedding.len() * 4); for f in embedding { embedding_bytes.extend_from_slice(&f.to_le_bytes()); } conn.execute( "INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)", rusqlite::params![rowid, embedding_bytes], )?; let chunk_count: Option = if chunk_index == 0 { Some(total_chunks as i64) } else { None }; let now = chrono::Utc::now().timestamp_millis(); conn.execute( "INSERT OR REPLACE INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, chunk_max_bytes, chunk_count) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL, ?8, ?9)", rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, CHUNK_MAX_BYTES as i64, chunk_count ], )?; Ok(()) } fn record_embedding_error( conn: &Connection, doc_id: i64, chunk_index: usize, doc_hash: &str, chunk_hash: &str, model_name: &str, error: &str, ) -> Result<()> { let now = chrono::Utc::now().timestamp_millis(); conn.execute( "INSERT INTO embedding_metadata (document_id, chunk_index, model, dims, document_hash, chunk_hash, created_at, attempt_count, last_error, last_attempt_at, chunk_max_bytes) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7, ?9) ON CONFLICT(document_id, chunk_index) DO UPDATE SET attempt_count = embedding_metadata.attempt_count + 1, last_error = ?8, last_attempt_at = ?7, chunk_max_bytes = ?9", rusqlite::params![ doc_id, chunk_index as i64, model_name, EXPECTED_DIMS as i64, doc_hash, chunk_hash, now, error, CHUNK_MAX_BYTES as i64 ], )?; Ok(()) } fn sha256_hash(input: &str) -> String { let mut hasher = Sha256::new(); hasher.update(input.as_bytes()); format!("{:x}", hasher.finalize()) }