feat(embed): concurrent batching, UTF-8 safe chunking, right-sized chunks
Three fixes to the embedding pipeline:

1. Concurrent HTTP batching: fire EMBED_CONCURRENCY (2) Ollama requests in parallel via join_all, then write results serially to SQLite. ~2x throughput improvement on GPU-bound workloads.
2. UTF-8 boundary safety: all computed byte offsets in split_into_chunks (paragraph/sentence/word break finders + overlap advance) now use floor_char_boundary() to prevent panics on multi-byte characters like smart quotes and non-breaking spaces.
3. CHUNK_MAX_BYTES reduced from 6000 to 1500 to fit nomic-embed-text's actual 2048-token context window, eliminating context-length retry storms that were causing 10x slowdowns.

Also threads ShutdownSignal through the embed pipeline for graceful Ctrl+C.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use futures::future::join_all;
|
||||
use rusqlite::Connection;
|
||||
use sha2::{Digest, Sha256};
|
||||
use tracing::{info, instrument, warn};
|
||||
|
||||
use crate::core::error::Result;
|
||||
use crate::core::shutdown::ShutdownSignal;
|
||||
use crate::embedding::change_detector::{count_pending_documents, find_pending_documents};
|
||||
use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid};
|
||||
use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks};
|
||||
@@ -12,6 +14,7 @@ use crate::embedding::ollama::OllamaClient;
|
||||
|
||||
const BATCH_SIZE: usize = 32;
|
||||
const DB_PAGE_SIZE: usize = 500;
|
||||
const EMBED_CONCURRENCY: usize = 2;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EmbedResult {
|
||||
@@ -29,12 +32,13 @@ struct ChunkWork {
|
||||
text: String,
|
||||
}
|
||||
|
||||
#[instrument(skip(conn, client, progress_callback), fields(%model_name, items_processed, items_skipped, errors))]
|
||||
#[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))]
|
||||
pub async fn embed_documents(
|
||||
conn: &Connection,
|
||||
client: &OllamaClient,
|
||||
model_name: &str,
|
||||
progress_callback: Option<Box<dyn Fn(usize, usize)>>,
|
||||
signal: &ShutdownSignal,
|
||||
) -> Result<EmbedResult> {
|
||||
let total = count_pending_documents(conn, model_name)? as usize;
|
||||
let mut result = EmbedResult::default();
|
||||
@@ -48,6 +52,11 @@ pub async fn embed_documents(
|
||||
info!(total, "Starting embedding pipeline");
|
||||
|
||||
loop {
|
||||
if signal.is_cancelled() {
|
||||
info!("Shutdown requested, stopping embedding pipeline");
|
||||
break;
|
||||
}
|
||||
|
||||
let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?;
|
||||
if pending.is_empty() {
|
||||
break;
|
||||
@@ -64,6 +73,7 @@ pub async fn embed_documents(
|
||||
&mut processed,
|
||||
total,
|
||||
&progress_callback,
|
||||
signal,
|
||||
)
|
||||
.await;
|
||||
match page_result {
|
||||
@@ -102,6 +112,7 @@ async fn embed_page(
|
||||
processed: &mut usize,
|
||||
total: usize,
|
||||
progress_callback: &Option<Box<dyn Fn(usize, usize)>>,
|
||||
signal: &ShutdownSignal,
|
||||
) -> Result<()> {
|
||||
let mut all_chunks: Vec<ChunkWork> = Vec::with_capacity(pending.len() * 3);
|
||||
let mut page_normal_docs: usize = 0;
|
||||
@@ -161,128 +172,152 @@ async fn embed_page(
|
||||
|
||||
let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len());
|
||||
|
||||
for batch in all_chunks.chunks(BATCH_SIZE) {
|
||||
let texts: Vec<&str> = batch.iter().map(|c| c.text.as_str()).collect();
|
||||
// Split chunks into batches, then process batches in concurrent groups
|
||||
let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect();
|
||||
|
||||
match client.embed_batch(&texts).await {
|
||||
Ok(embeddings) => {
|
||||
for (i, embedding) in embeddings.iter().enumerate() {
|
||||
if i >= batch.len() {
|
||||
break;
|
||||
}
|
||||
let chunk = &batch[i];
|
||||
for concurrent_group in batches.chunks(EMBED_CONCURRENCY) {
|
||||
if signal.is_cancelled() {
|
||||
info!("Shutdown requested during embedding, stopping mid-page");
|
||||
break;
|
||||
}
|
||||
|
||||
if embedding.len() != EXPECTED_DIMS {
|
||||
warn!(
|
||||
doc_id = chunk.doc_id,
|
||||
chunk_index = chunk.chunk_index,
|
||||
got_dims = embedding.len(),
|
||||
expected = EXPECTED_DIMS,
|
||||
"Dimension mismatch, skipping"
|
||||
);
|
||||
record_embedding_error(
|
||||
// Phase 1: Collect texts (must outlive the futures)
|
||||
let batch_texts: Vec<Vec<&str>> = concurrent_group
|
||||
.iter()
|
||||
.map(|batch| batch.iter().map(|c| c.text.as_str()).collect())
|
||||
.collect();
|
||||
|
||||
// Phase 2: Fire concurrent HTTP requests to Ollama
|
||||
let futures: Vec<_> = batch_texts
|
||||
.iter()
|
||||
.map(|texts| client.embed_batch(texts))
|
||||
.collect();
|
||||
let api_results = join_all(futures).await;
|
||||
|
||||
// Phase 3: Serial DB writes for each batch result
|
||||
for (batch, api_result) in concurrent_group.iter().zip(api_results) {
|
||||
match api_result {
|
||||
Ok(embeddings) => {
|
||||
for (i, embedding) in embeddings.iter().enumerate() {
|
||||
if i >= batch.len() {
|
||||
break;
|
||||
}
|
||||
let chunk = &batch[i];
|
||||
|
||||
if embedding.len() != EXPECTED_DIMS {
|
||||
warn!(
|
||||
doc_id = chunk.doc_id,
|
||||
chunk_index = chunk.chunk_index,
|
||||
got_dims = embedding.len(),
|
||||
expected = EXPECTED_DIMS,
|
||||
"Dimension mismatch, skipping"
|
||||
);
|
||||
record_embedding_error(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
&format!(
|
||||
"Dimension mismatch: got {}, expected {}",
|
||||
embedding.len(),
|
||||
EXPECTED_DIMS
|
||||
),
|
||||
)?;
|
||||
result.failed += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
if !cleared_docs.contains(&chunk.doc_id) {
|
||||
clear_document_embeddings(conn, chunk.doc_id)?;
|
||||
cleared_docs.insert(chunk.doc_id);
|
||||
}
|
||||
|
||||
store_embedding(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
&format!(
|
||||
"Dimension mismatch: got {}, expected {}",
|
||||
embedding.len(),
|
||||
EXPECTED_DIMS
|
||||
),
|
||||
embedding,
|
||||
chunk.total_chunks,
|
||||
)?;
|
||||
result.failed += 1;
|
||||
continue;
|
||||
result.embedded += 1;
|
||||
}
|
||||
|
||||
if !cleared_docs.contains(&chunk.doc_id) {
|
||||
clear_document_embeddings(conn, chunk.doc_id)?;
|
||||
cleared_docs.insert(chunk.doc_id);
|
||||
}
|
||||
|
||||
store_embedding(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
embedding,
|
||||
chunk.total_chunks,
|
||||
)?;
|
||||
result.embedded += 1;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
let err_lower = err_str.to_lowercase();
|
||||
let is_context_error = err_lower.contains("context length")
|
||||
|| err_lower.contains("too long")
|
||||
|| err_lower.contains("maximum context")
|
||||
|| err_lower.contains("token limit")
|
||||
|| err_lower.contains("exceeds")
|
||||
|| (err_lower.contains("413") && err_lower.contains("http"));
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
let err_lower = err_str.to_lowercase();
|
||||
let is_context_error = err_lower.contains("context length")
|
||||
|| err_lower.contains("too long")
|
||||
|| err_lower.contains("maximum context")
|
||||
|| err_lower.contains("token limit")
|
||||
|| err_lower.contains("exceeds")
|
||||
|| (err_lower.contains("413") && err_lower.contains("http"));
|
||||
|
||||
if is_context_error && batch.len() > 1 {
|
||||
warn!("Batch failed with context length error, retrying chunks individually");
|
||||
for chunk in batch {
|
||||
match client.embed_batch(&[chunk.text.as_str()]).await {
|
||||
Ok(embeddings)
|
||||
if !embeddings.is_empty()
|
||||
&& embeddings[0].len() == EXPECTED_DIMS =>
|
||||
{
|
||||
if !cleared_docs.contains(&chunk.doc_id) {
|
||||
clear_document_embeddings(conn, chunk.doc_id)?;
|
||||
cleared_docs.insert(chunk.doc_id);
|
||||
if is_context_error && batch.len() > 1 {
|
||||
warn!(
|
||||
"Batch failed with context length error, retrying chunks individually"
|
||||
);
|
||||
for chunk in *batch {
|
||||
match client.embed_batch(&[chunk.text.as_str()]).await {
|
||||
Ok(embeddings)
|
||||
if !embeddings.is_empty()
|
||||
&& embeddings[0].len() == EXPECTED_DIMS =>
|
||||
{
|
||||
if !cleared_docs.contains(&chunk.doc_id) {
|
||||
clear_document_embeddings(conn, chunk.doc_id)?;
|
||||
cleared_docs.insert(chunk.doc_id);
|
||||
}
|
||||
|
||||
store_embedding(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
&embeddings[0],
|
||||
chunk.total_chunks,
|
||||
)?;
|
||||
result.embedded += 1;
|
||||
}
|
||||
_ => {
|
||||
warn!(
|
||||
doc_id = chunk.doc_id,
|
||||
chunk_index = chunk.chunk_index,
|
||||
chunk_bytes = chunk.text.len(),
|
||||
"Chunk too large for model context window"
|
||||
);
|
||||
record_embedding_error(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
"Chunk exceeds model context window",
|
||||
)?;
|
||||
result.failed += 1;
|
||||
}
|
||||
|
||||
store_embedding(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
&embeddings[0],
|
||||
chunk.total_chunks,
|
||||
)?;
|
||||
result.embedded += 1;
|
||||
}
|
||||
_ => {
|
||||
warn!(
|
||||
doc_id = chunk.doc_id,
|
||||
chunk_index = chunk.chunk_index,
|
||||
chunk_bytes = chunk.text.len(),
|
||||
"Chunk too large for model context window"
|
||||
);
|
||||
record_embedding_error(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
"Chunk exceeds model context window",
|
||||
)?;
|
||||
result.failed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(error = %e, "Batch embedding failed");
|
||||
for chunk in batch {
|
||||
record_embedding_error(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
&e.to_string(),
|
||||
)?;
|
||||
result.failed += 1;
|
||||
} else {
|
||||
warn!(error = %e, "Batch embedding failed");
|
||||
for chunk in *batch {
|
||||
record_embedding_error(
|
||||
conn,
|
||||
chunk.doc_id,
|
||||
chunk.chunk_index,
|
||||
&chunk.doc_hash,
|
||||
&chunk.chunk_hash,
|
||||
model_name,
|
||||
&e.to_string(),
|
||||
)?;
|
||||
result.failed += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user