feat(embed): concurrent batching, UTF-8 safe chunking, right-sized chunks

Three fixes to the embedding pipeline:

1. Concurrent HTTP batching: fire EMBED_CONCURRENCY (2) Ollama requests
   in parallel via join_all, then write results serially to SQLite.
   ~2x throughput improvement on GPU-bound workloads.

2. UTF-8 boundary safety: all computed byte offsets in split_into_chunks
   (paragraph/sentence/word break finders + overlap advance) now use
   floor_char_boundary() to prevent panics on multi-byte characters
   like smart quotes and non-breaking spaces.

3. CHUNK_MAX_BYTES reduced from 6000 to 1500 to fit nomic-embed-text's
   actual 2048-token context window, eliminating context-length retry
   storms that were causing 10x slowdowns.

Also threads ShutdownSignal through embed pipeline for graceful Ctrl+C.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Taylor Eernisse
2026-02-06 14:48:34 -05:00
parent 1c45725cba
commit 39cb0cb087
5 changed files with 199 additions and 115 deletions

View File

@@ -5,6 +5,7 @@ use crate::Config;
use crate::core::db::create_connection; use crate::core::db::create_connection;
use crate::core::error::Result; use crate::core::error::Result;
use crate::core::paths::get_db_path; use crate::core::paths::get_db_path;
use crate::core::shutdown::ShutdownSignal;
use crate::embedding::ollama::{OllamaClient, OllamaConfig}; use crate::embedding::ollama::{OllamaClient, OllamaConfig};
use crate::embedding::pipeline::embed_documents; use crate::embedding::pipeline::embed_documents;
@@ -20,6 +21,7 @@ pub async fn run_embed(
full: bool, full: bool,
retry_failed: bool, retry_failed: bool,
progress_callback: Option<Box<dyn Fn(usize, usize)>>, progress_callback: Option<Box<dyn Fn(usize, usize)>>,
signal: &ShutdownSignal,
) -> Result<EmbedCommandResult> { ) -> Result<EmbedCommandResult> {
let db_path = get_db_path(config.storage.db_path.as_deref()); let db_path = get_db_path(config.storage.db_path.as_deref());
let conn = create_connection(&db_path)?; let conn = create_connection(&db_path)?;
@@ -49,7 +51,7 @@ pub async fn run_embed(
} }
let model_name = &config.embedding.model; let model_name = &config.embedding.model;
let result = embed_documents(&conn, &client, model_name, progress_callback).await?; let result = embed_documents(&conn, &client, model_name, progress_callback, signal).await?;
Ok(EmbedCommandResult { Ok(EmbedCommandResult {
embedded: result.embedded, embedded: result.embedded,

View File

@@ -239,7 +239,7 @@ pub async fn run_sync(
embed_bar_clone.set_position(processed as u64); embed_bar_clone.set_position(processed as u64);
} }
}); });
match run_embed(config, options.full, false, Some(embed_cb)).await { match run_embed(config, options.full, false, Some(embed_cb), signal).await {
Ok(embed_result) => { Ok(embed_result) => {
result.documents_embedded = embed_result.embedded; result.documents_embedded = embed_result.embedded;
embed_bar.finish_and_clear(); embed_bar.finish_and_clear();

View File

@@ -1,4 +1,4 @@
pub const CHUNK_MAX_BYTES: usize = 6_000; pub const CHUNK_MAX_BYTES: usize = 1_500;
pub const EXPECTED_DIMS: usize = 768; pub const EXPECTED_DIMS: usize = 768;
@@ -42,6 +42,8 @@ pub fn split_into_chunks(content: &str) -> Vec<(usize, String)> {
} }
.max(1); .max(1);
start += advance; start += advance;
// Ensure start lands on a char boundary after overlap subtraction
start = floor_char_boundary(content, start);
chunk_index += 1; chunk_index += 1;
} }
@@ -49,7 +51,7 @@ pub fn split_into_chunks(content: &str) -> Vec<(usize, String)> {
} }
fn find_paragraph_break(window: &str) -> Option<usize> { fn find_paragraph_break(window: &str) -> Option<usize> {
let search_start = window.len() * 2 / 3; let search_start = floor_char_boundary(window, window.len() * 2 / 3);
window[search_start..] window[search_start..]
.rfind("\n\n") .rfind("\n\n")
.map(|pos| search_start + pos + 2) .map(|pos| search_start + pos + 2)
@@ -57,7 +59,7 @@ fn find_paragraph_break(window: &str) -> Option<usize> {
} }
fn find_sentence_break(window: &str) -> Option<usize> { fn find_sentence_break(window: &str) -> Option<usize> {
let search_start = window.len() / 2; let search_start = floor_char_boundary(window, window.len() / 2);
for pat in &[". ", "? ", "! "] { for pat in &[". ", "? ", "! "] {
if let Some(pos) = window[search_start..].rfind(pat) { if let Some(pos) = window[search_start..].rfind(pat) {
return Some(search_start + pos + pat.len()); return Some(search_start + pos + pat.len());
@@ -72,7 +74,7 @@ fn find_sentence_break(window: &str) -> Option<usize> {
} }
fn find_word_break(window: &str) -> Option<usize> { fn find_word_break(window: &str) -> Option<usize> {
let search_start = window.len() / 2; let search_start = floor_char_boundary(window, window.len() / 2);
window[search_start..] window[search_start..]
.rfind(' ') .rfind(' ')
.map(|pos| search_start + pos + 1) .map(|pos| search_start + pos + 1)
@@ -180,4 +182,41 @@ mod tests {
assert_eq!(*idx, i, "Chunk index mismatch at position {}", i); assert_eq!(*idx, i, "Chunk index mismatch at position {}", i);
} }
} }
#[test]
fn test_multibyte_characters_no_panic() {
// Build content with multi-byte UTF-8 chars (smart quotes, emoji, CJK)
// placed at positions likely to hit len()*2/3 and len()/2 boundaries
let segment = "We\u{2019}ve gradually ar\u{2014}ranged the components. ";
let mut content = String::new();
while content.len() < CHUNK_MAX_BYTES * 3 {
content.push_str(segment);
}
// Should not panic on multi-byte boundary
let chunks = split_into_chunks(&content);
assert!(chunks.len() >= 2);
for (_, chunk) in &chunks {
assert!(!chunk.is_empty());
}
}
#[test]
fn test_nbsp_at_overlap_boundary() {
// Reproduce the exact crash: \u{a0} (non-breaking space, 2-byte UTF-8)
// placed so that split_at - CHUNK_OVERLAP_CHARS lands mid-character
let mut content = String::new();
// Fill with ASCII up to near CHUNK_MAX_BYTES, then place \u{a0}
// near where the overlap subtraction would land
let target = CHUNK_MAX_BYTES - CHUNK_OVERLAP_CHARS;
while content.len() < target - 2 {
content.push('a');
}
content.push('\u{a0}'); // 2-byte char right at the overlap boundary
while content.len() < CHUNK_MAX_BYTES * 3 {
content.push('b');
}
// Should not panic
let chunks = split_into_chunks(&content);
assert!(chunks.len() >= 2);
}
} }

View File

@@ -1,10 +1,12 @@
use std::collections::HashSet; use std::collections::HashSet;
use futures::future::join_all;
use rusqlite::Connection; use rusqlite::Connection;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use tracing::{info, instrument, warn}; use tracing::{info, instrument, warn};
use crate::core::error::Result; use crate::core::error::Result;
use crate::core::shutdown::ShutdownSignal;
use crate::embedding::change_detector::{count_pending_documents, find_pending_documents}; use crate::embedding::change_detector::{count_pending_documents, find_pending_documents};
use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid}; use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid};
use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks}; use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks};
@@ -12,6 +14,7 @@ use crate::embedding::ollama::OllamaClient;
const BATCH_SIZE: usize = 32; const BATCH_SIZE: usize = 32;
const DB_PAGE_SIZE: usize = 500; const DB_PAGE_SIZE: usize = 500;
const EMBED_CONCURRENCY: usize = 2;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct EmbedResult { pub struct EmbedResult {
@@ -29,12 +32,13 @@ struct ChunkWork {
text: String, text: String,
} }
#[instrument(skip(conn, client, progress_callback), fields(%model_name, items_processed, items_skipped, errors))] #[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))]
pub async fn embed_documents( pub async fn embed_documents(
conn: &Connection, conn: &Connection,
client: &OllamaClient, client: &OllamaClient,
model_name: &str, model_name: &str,
progress_callback: Option<Box<dyn Fn(usize, usize)>>, progress_callback: Option<Box<dyn Fn(usize, usize)>>,
signal: &ShutdownSignal,
) -> Result<EmbedResult> { ) -> Result<EmbedResult> {
let total = count_pending_documents(conn, model_name)? as usize; let total = count_pending_documents(conn, model_name)? as usize;
let mut result = EmbedResult::default(); let mut result = EmbedResult::default();
@@ -48,6 +52,11 @@ pub async fn embed_documents(
info!(total, "Starting embedding pipeline"); info!(total, "Starting embedding pipeline");
loop { loop {
if signal.is_cancelled() {
info!("Shutdown requested, stopping embedding pipeline");
break;
}
let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?; let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?;
if pending.is_empty() { if pending.is_empty() {
break; break;
@@ -64,6 +73,7 @@ pub async fn embed_documents(
&mut processed, &mut processed,
total, total,
&progress_callback, &progress_callback,
signal,
) )
.await; .await;
match page_result { match page_result {
@@ -102,6 +112,7 @@ async fn embed_page(
processed: &mut usize, processed: &mut usize,
total: usize, total: usize,
progress_callback: &Option<Box<dyn Fn(usize, usize)>>, progress_callback: &Option<Box<dyn Fn(usize, usize)>>,
signal: &ShutdownSignal,
) -> Result<()> { ) -> Result<()> {
let mut all_chunks: Vec<ChunkWork> = Vec::with_capacity(pending.len() * 3); let mut all_chunks: Vec<ChunkWork> = Vec::with_capacity(pending.len() * 3);
let mut page_normal_docs: usize = 0; let mut page_normal_docs: usize = 0;
@@ -161,10 +172,31 @@ async fn embed_page(
let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len()); let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len());
for batch in all_chunks.chunks(BATCH_SIZE) { // Split chunks into batches, then process batches in concurrent groups
let texts: Vec<&str> = batch.iter().map(|c| c.text.as_str()).collect(); let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect();
match client.embed_batch(&texts).await { for concurrent_group in batches.chunks(EMBED_CONCURRENCY) {
if signal.is_cancelled() {
info!("Shutdown requested during embedding, stopping mid-page");
break;
}
// Phase 1: Collect texts (must outlive the futures)
let batch_texts: Vec<Vec<&str>> = concurrent_group
.iter()
.map(|batch| batch.iter().map(|c| c.text.as_str()).collect())
.collect();
// Phase 2: Fire concurrent HTTP requests to Ollama
let futures: Vec<_> = batch_texts
.iter()
.map(|texts| client.embed_batch(texts))
.collect();
let api_results = join_all(futures).await;
// Phase 3: Serial DB writes for each batch result
for (batch, api_result) in concurrent_group.iter().zip(api_results) {
match api_result {
Ok(embeddings) => { Ok(embeddings) => {
for (i, embedding) in embeddings.iter().enumerate() { for (i, embedding) in embeddings.iter().enumerate() {
if i >= batch.len() { if i >= batch.len() {
@@ -226,8 +258,10 @@ async fn embed_page(
|| (err_lower.contains("413") && err_lower.contains("http")); || (err_lower.contains("413") && err_lower.contains("http"));
if is_context_error && batch.len() > 1 { if is_context_error && batch.len() > 1 {
warn!("Batch failed with context length error, retrying chunks individually"); warn!(
for chunk in batch { "Batch failed with context length error, retrying chunks individually"
);
for chunk in *batch {
match client.embed_batch(&[chunk.text.as_str()]).await { match client.embed_batch(&[chunk.text.as_str()]).await {
Ok(embeddings) Ok(embeddings)
if !embeddings.is_empty() if !embeddings.is_empty()
@@ -272,7 +306,7 @@ async fn embed_page(
} }
} else { } else {
warn!(error = %e, "Batch embedding failed"); warn!(error = %e, "Batch embedding failed");
for chunk in batch { for chunk in *batch {
record_embedding_error( record_embedding_error(
conn, conn,
chunk.doc_id, chunk.doc_id,
@@ -288,6 +322,7 @@ async fn embed_page(
} }
} }
} }
}
*processed += page_normal_docs; *processed += page_normal_docs;
if let Some(cb) = progress_callback { if let Some(cb) = progress_callback {

View File

@@ -1517,7 +1517,15 @@ async fn handle_embed(
let config = Config::load(config_override)?; let config = Config::load(config_override)?;
let full = args.full && !args.no_full; let full = args.full && !args.no_full;
let retry_failed = args.retry_failed && !args.no_retry_failed; let retry_failed = args.retry_failed && !args.no_retry_failed;
let result = run_embed(&config, full, retry_failed, None).await?;
let signal = ShutdownSignal::new();
let signal_for_handler = signal.clone();
tokio::spawn(async move {
let _ = tokio::signal::ctrl_c().await;
signal_for_handler.cancel();
});
let result = run_embed(&config, full, retry_failed, None, &signal).await?;
if robot_mode { if robot_mode {
print_embed_json(&result); print_embed_json(&result);
} else { } else {