diff --git a/src/cli/commands/embed.rs b/src/cli/commands/embed.rs index 9f043e4..9fc0f3c 100644 --- a/src/cli/commands/embed.rs +++ b/src/cli/commands/embed.rs @@ -5,6 +5,7 @@ use crate::Config; use crate::core::db::create_connection; use crate::core::error::Result; use crate::core::paths::get_db_path; +use crate::core::shutdown::ShutdownSignal; use crate::embedding::ollama::{OllamaClient, OllamaConfig}; use crate::embedding::pipeline::embed_documents; @@ -20,6 +21,7 @@ pub async fn run_embed( full: bool, retry_failed: bool, progress_callback: Option>, + signal: &ShutdownSignal, ) -> Result { let db_path = get_db_path(config.storage.db_path.as_deref()); let conn = create_connection(&db_path)?; @@ -49,7 +51,7 @@ pub async fn run_embed( } let model_name = &config.embedding.model; - let result = embed_documents(&conn, &client, model_name, progress_callback).await?; + let result = embed_documents(&conn, &client, model_name, progress_callback, signal).await?; Ok(EmbedCommandResult { embedded: result.embedded, diff --git a/src/cli/commands/sync.rs b/src/cli/commands/sync.rs index 69eaa64..598ec74 100644 --- a/src/cli/commands/sync.rs +++ b/src/cli/commands/sync.rs @@ -239,7 +239,7 @@ pub async fn run_sync( embed_bar_clone.set_position(processed as u64); } }); - match run_embed(config, options.full, false, Some(embed_cb)).await { + match run_embed(config, options.full, false, Some(embed_cb), signal).await { Ok(embed_result) => { result.documents_embedded = embed_result.embedded; embed_bar.finish_and_clear(); diff --git a/src/embedding/chunking.rs b/src/embedding/chunking.rs index 36be003..b4aecae 100644 --- a/src/embedding/chunking.rs +++ b/src/embedding/chunking.rs @@ -1,4 +1,4 @@ -pub const CHUNK_MAX_BYTES: usize = 6_000; +pub const CHUNK_MAX_BYTES: usize = 1_500; pub const EXPECTED_DIMS: usize = 768; @@ -42,6 +42,8 @@ pub fn split_into_chunks(content: &str) -> Vec<(usize, String)> { } .max(1); start += advance; + // Ensure start lands on a char boundary after overlap 
subtraction + start = floor_char_boundary(content, start); chunk_index += 1; } @@ -49,7 +51,7 @@ pub fn split_into_chunks(content: &str) -> Vec<(usize, String)> { } fn find_paragraph_break(window: &str) -> Option<usize> { - let search_start = window.len() * 2 / 3; + let search_start = floor_char_boundary(window, window.len() * 2 / 3); window[search_start..] .rfind("\n\n") .map(|pos| search_start + pos + 2) @@ -57,7 +59,7 @@ fn find_paragraph_break(window: &str) -> Option<usize> { } fn find_sentence_break(window: &str) -> Option<usize> { - let search_start = window.len() / 2; + let search_start = floor_char_boundary(window, window.len() / 2); for pat in &[". ", "? ", "! "] { if let Some(pos) = window[search_start..].rfind(pat) { return Some(search_start + pos + pat.len()); } @@ -72,7 +74,7 @@ fn find_sentence_break(window: &str) -> Option<usize> { } fn find_word_break(window: &str) -> Option<usize> { - let search_start = window.len() / 2; + let search_start = floor_char_boundary(window, window.len() / 2); window[search_start..] .rfind(' ') .map(|pos| search_start + pos + 1) @@ -180,4 +182,41 @@ mod tests { } } + + #[test] + fn test_multibyte_characters_no_panic() { + // Build content with multi-byte UTF-8 chars (smart quotes, emoji, CJK) + // placed at positions likely to hit len()*2/3 and len()/2 boundaries + let segment = "We\u{2019}ve gradually ar\u{2014}ranged the components. 
"; + let mut content = String::new(); + while content.len() < CHUNK_MAX_BYTES * 3 { + content.push_str(segment); + } + // Should not panic on multi-byte boundary + let chunks = split_into_chunks(&content); + assert!(chunks.len() >= 2); + for (_, chunk) in &chunks { + assert!(!chunk.is_empty()); + } + } + + #[test] + fn test_nbsp_at_overlap_boundary() { + // Reproduce the exact crash: \u{a0} (non-breaking space, 2-byte UTF-8) + // placed so that split_at - CHUNK_OVERLAP_CHARS lands mid-character + let mut content = String::new(); + // Fill with ASCII up to near CHUNK_MAX_BYTES, then place \u{a0} + // near where the overlap subtraction would land + let target = CHUNK_MAX_BYTES - CHUNK_OVERLAP_CHARS; + while content.len() < target - 2 { + content.push('a'); + } + content.push('\u{a0}'); // 2-byte char right at the overlap boundary + while content.len() < CHUNK_MAX_BYTES * 3 { + content.push('b'); + } + // Should not panic + let chunks = split_into_chunks(&content); + assert!(chunks.len() >= 2); + } } diff --git a/src/embedding/pipeline.rs b/src/embedding/pipeline.rs index eb4e440..1281044 100644 --- a/src/embedding/pipeline.rs +++ b/src/embedding/pipeline.rs @@ -1,10 +1,12 @@ use std::collections::HashSet; +use futures::future::join_all; use rusqlite::Connection; use sha2::{Digest, Sha256}; use tracing::{info, instrument, warn}; use crate::core::error::Result; +use crate::core::shutdown::ShutdownSignal; use crate::embedding::change_detector::{count_pending_documents, find_pending_documents}; use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid}; use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks}; @@ -12,6 +14,7 @@ use crate::embedding::ollama::OllamaClient; const BATCH_SIZE: usize = 32; const DB_PAGE_SIZE: usize = 500; +const EMBED_CONCURRENCY: usize = 2; #[derive(Debug, Default)] pub struct EmbedResult { @@ -29,12 +32,13 @@ struct ChunkWork { text: String, } -#[instrument(skip(conn, client, progress_callback), 
fields(%model_name, items_processed, items_skipped, errors))] +#[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))] pub async fn embed_documents( conn: &Connection, client: &OllamaClient, model_name: &str, progress_callback: Option>, + signal: &ShutdownSignal, ) -> Result { let total = count_pending_documents(conn, model_name)? as usize; let mut result = EmbedResult::default(); @@ -48,6 +52,11 @@ pub async fn embed_documents( info!(total, "Starting embedding pipeline"); loop { + if signal.is_cancelled() { + info!("Shutdown requested, stopping embedding pipeline"); + break; + } + let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?; if pending.is_empty() { break; @@ -64,6 +73,7 @@ pub async fn embed_documents( &mut processed, total, &progress_callback, + signal, ) .await; match page_result { @@ -102,6 +112,7 @@ async fn embed_page( processed: &mut usize, total: usize, progress_callback: &Option>, + signal: &ShutdownSignal, ) -> Result<()> { let mut all_chunks: Vec = Vec::with_capacity(pending.len() * 3); let mut page_normal_docs: usize = 0; @@ -161,128 +172,152 @@ async fn embed_page( let mut cleared_docs: HashSet = HashSet::with_capacity(pending.len()); - for batch in all_chunks.chunks(BATCH_SIZE) { - let texts: Vec<&str> = batch.iter().map(|c| c.text.as_str()).collect(); + // Split chunks into batches, then process batches in concurrent groups + let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect(); - match client.embed_batch(&texts).await { - Ok(embeddings) => { - for (i, embedding) in embeddings.iter().enumerate() { - if i >= batch.len() { - break; - } - let chunk = &batch[i]; + for concurrent_group in batches.chunks(EMBED_CONCURRENCY) { + if signal.is_cancelled() { + info!("Shutdown requested during embedding, stopping mid-page"); + break; + } - if embedding.len() != EXPECTED_DIMS { - warn!( - doc_id = chunk.doc_id, - chunk_index = 
chunk.chunk_index, - got_dims = embedding.len(), - expected = EXPECTED_DIMS, - "Dimension mismatch, skipping" - ); - record_embedding_error( + // Phase 1: Collect texts (must outlive the futures) + let batch_texts: Vec> = concurrent_group + .iter() + .map(|batch| batch.iter().map(|c| c.text.as_str()).collect()) + .collect(); + + // Phase 2: Fire concurrent HTTP requests to Ollama + let futures: Vec<_> = batch_texts + .iter() + .map(|texts| client.embed_batch(texts)) + .collect(); + let api_results = join_all(futures).await; + + // Phase 3: Serial DB writes for each batch result + for (batch, api_result) in concurrent_group.iter().zip(api_results) { + match api_result { + Ok(embeddings) => { + for (i, embedding) in embeddings.iter().enumerate() { + if i >= batch.len() { + break; + } + let chunk = &batch[i]; + + if embedding.len() != EXPECTED_DIMS { + warn!( + doc_id = chunk.doc_id, + chunk_index = chunk.chunk_index, + got_dims = embedding.len(), + expected = EXPECTED_DIMS, + "Dimension mismatch, skipping" + ); + record_embedding_error( + conn, + chunk.doc_id, + chunk.chunk_index, + &chunk.doc_hash, + &chunk.chunk_hash, + model_name, + &format!( + "Dimension mismatch: got {}, expected {}", + embedding.len(), + EXPECTED_DIMS + ), + )?; + result.failed += 1; + continue; + } + + if !cleared_docs.contains(&chunk.doc_id) { + clear_document_embeddings(conn, chunk.doc_id)?; + cleared_docs.insert(chunk.doc_id); + } + + store_embedding( conn, chunk.doc_id, chunk.chunk_index, &chunk.doc_hash, &chunk.chunk_hash, model_name, - &format!( - "Dimension mismatch: got {}, expected {}", - embedding.len(), - EXPECTED_DIMS - ), + embedding, + chunk.total_chunks, )?; - result.failed += 1; - continue; + result.embedded += 1; } - - if !cleared_docs.contains(&chunk.doc_id) { - clear_document_embeddings(conn, chunk.doc_id)?; - cleared_docs.insert(chunk.doc_id); - } - - store_embedding( - conn, - chunk.doc_id, - chunk.chunk_index, - &chunk.doc_hash, - &chunk.chunk_hash, - model_name, - 
embedding, - chunk.total_chunks, - )?; - result.embedded += 1; } - } - Err(e) => { - let err_str = e.to_string(); - let err_lower = err_str.to_lowercase(); - let is_context_error = err_lower.contains("context length") - || err_lower.contains("too long") - || err_lower.contains("maximum context") - || err_lower.contains("token limit") - || err_lower.contains("exceeds") - || (err_lower.contains("413") && err_lower.contains("http")); + Err(e) => { + let err_str = e.to_string(); + let err_lower = err_str.to_lowercase(); + let is_context_error = err_lower.contains("context length") + || err_lower.contains("too long") + || err_lower.contains("maximum context") + || err_lower.contains("token limit") + || err_lower.contains("exceeds") + || (err_lower.contains("413") && err_lower.contains("http")); - if is_context_error && batch.len() > 1 { - warn!("Batch failed with context length error, retrying chunks individually"); - for chunk in batch { - match client.embed_batch(&[chunk.text.as_str()]).await { - Ok(embeddings) - if !embeddings.is_empty() - && embeddings[0].len() == EXPECTED_DIMS => - { - if !cleared_docs.contains(&chunk.doc_id) { - clear_document_embeddings(conn, chunk.doc_id)?; - cleared_docs.insert(chunk.doc_id); + if is_context_error && batch.len() > 1 { + warn!( + "Batch failed with context length error, retrying chunks individually" + ); + for chunk in *batch { + match client.embed_batch(&[chunk.text.as_str()]).await { + Ok(embeddings) + if !embeddings.is_empty() + && embeddings[0].len() == EXPECTED_DIMS => + { + if !cleared_docs.contains(&chunk.doc_id) { + clear_document_embeddings(conn, chunk.doc_id)?; + cleared_docs.insert(chunk.doc_id); + } + + store_embedding( + conn, + chunk.doc_id, + chunk.chunk_index, + &chunk.doc_hash, + &chunk.chunk_hash, + model_name, + &embeddings[0], + chunk.total_chunks, + )?; + result.embedded += 1; + } + _ => { + warn!( + doc_id = chunk.doc_id, + chunk_index = chunk.chunk_index, + chunk_bytes = chunk.text.len(), + "Chunk too 
large for model context window" + ); + record_embedding_error( + conn, + chunk.doc_id, + chunk.chunk_index, + &chunk.doc_hash, + &chunk.chunk_hash, + model_name, + "Chunk exceeds model context window", + )?; + result.failed += 1; } - - store_embedding( - conn, - chunk.doc_id, - chunk.chunk_index, - &chunk.doc_hash, - &chunk.chunk_hash, - model_name, - &embeddings[0], - chunk.total_chunks, - )?; - result.embedded += 1; - } - _ => { - warn!( - doc_id = chunk.doc_id, - chunk_index = chunk.chunk_index, - chunk_bytes = chunk.text.len(), - "Chunk too large for model context window" - ); - record_embedding_error( - conn, - chunk.doc_id, - chunk.chunk_index, - &chunk.doc_hash, - &chunk.chunk_hash, - model_name, - "Chunk exceeds model context window", - )?; - result.failed += 1; } } - } - } else { - warn!(error = %e, "Batch embedding failed"); - for chunk in batch { - record_embedding_error( - conn, - chunk.doc_id, - chunk.chunk_index, - &chunk.doc_hash, - &chunk.chunk_hash, - model_name, - &e.to_string(), - )?; - result.failed += 1; + } else { + warn!(error = %e, "Batch embedding failed"); + for chunk in *batch { + record_embedding_error( + conn, + chunk.doc_id, + chunk.chunk_index, + &chunk.doc_hash, + &chunk.chunk_hash, + model_name, + &e.to_string(), + )?; + result.failed += 1; + } } } } diff --git a/src/main.rs b/src/main.rs index 8d7ed9c..57608da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1517,7 +1517,15 @@ async fn handle_embed( let config = Config::load(config_override)?; let full = args.full && !args.no_full; let retry_failed = args.retry_failed && !args.no_retry_failed; - let result = run_embed(&config, full, retry_failed, None).await?; + + let signal = ShutdownSignal::new(); + let signal_for_handler = signal.clone(); + tokio::spawn(async move { + let _ = tokio::signal::ctrl_c().await; + signal_for_handler.cancel(); + }); + + let result = run_embed(&config, full, retry_failed, None, &signal).await?; if robot_mode { print_embed_json(&result); } else {