Files
gitlore/src/embedding/pipeline.rs
Taylor Eernisse c2036c64e9 feat(embed): docs_embedded tracking, buffer reuse, retry hardening
Embedding pipeline improvements building on the concurrent batching
foundation:

- Track docs_embedded vs chunks_embedded separately. A document counts
  as embedded only when ALL its chunks succeed, giving accurate
  progress reporting. The sync command reads docs_embedded for its
  document count.

- Reuse a single Vec<u8> buffer (embed_buf) across all store_embedding
  calls instead of allocating per chunk. Eliminates ~3KB allocation per
  768-dim embedding.

- Detect and record errors when Ollama silently returns fewer
  embeddings than inputs (batch mismatch). Previously these dropped
  chunks were invisible.

- Improve retry error messages: distinguish "retry returned unexpected
  result" (wrong dims/count) from "retry request failed" (network
  error) instead of generic "chunk too large" message.

- Convert all hot-path SQL from conn.execute() to prepare_cached() for
  statement cache reuse (clear_document_embeddings, store_embedding,
  record_embedding_error).

- Record embedding_metadata errors for empty documents so they don't
  appear as perpetually pending on subsequent runs.

- Accept concurrency parameter (configurable via config.embedding.concurrency)
  instead of hardcoded EMBED_CONCURRENCY=2.

- Add schema version pre-flight check in embed command to fail fast
  with actionable error instead of cryptic SQL errors.

- Fix --retry-failed to use DELETE instead of UPDATE. UPDATE clears
  last_error but the row still matches config params in the LEFT JOIN,
  making the doc permanently invisible to find_pending_documents.
  DELETE removes the row entirely so the LEFT JOIN returns NULL.
  Regression test added (old_update_approach_leaves_doc_invisible).

- Add chunking forward-progress guard: after floor_char_boundary()
  rounds backward, ensure start advances by at least one full
  character to prevent infinite loops on multi-byte sequences
  (box-drawing chars, smart quotes). Test cases cover the exact
  patterns that caused production hangs on document 18526.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 22:42:08 -05:00

580 lines
21 KiB
Rust

use std::collections::{HashMap, HashSet};
use futures::future::join_all;
use rusqlite::Connection;
use sha2::{Digest, Sha256};
use tracing::{debug, info, instrument, warn};
use crate::core::error::Result;
use crate::core::shutdown::ShutdownSignal;
use crate::embedding::change_detector::{count_pending_documents, find_pending_documents};
use crate::embedding::chunk_ids::{CHUNK_ROWID_MULTIPLIER, encode_rowid};
use crate::embedding::chunking::{CHUNK_MAX_BYTES, EXPECTED_DIMS, split_into_chunks};
use crate::embedding::ollama::OllamaClient;
/// Number of chunk texts sent to Ollama in a single embed_batch request.
const BATCH_SIZE: usize = 32;
/// Number of pending documents fetched from SQLite per page (one SAVEPOINT per page).
const DB_PAGE_SIZE: usize = 500;
/// Default number of batch requests kept in flight to Ollama concurrently.
pub const DEFAULT_EMBED_CONCURRENCY: usize = 4;
/// Aggregate counters for one run of the embedding pipeline.
#[derive(Debug, Default)]
pub struct EmbedResult {
    /// Chunks whose embeddings were successfully stored.
    pub chunks_embedded: usize,
    /// Documents for which every chunk was stored successfully.
    pub docs_embedded: usize,
    /// Chunks that failed (dimension mismatch, batch mismatch, API/retry error).
    pub failed: usize,
    /// Documents skipped entirely (empty content or chunk-count overflow).
    pub skipped: usize,
}
/// One unit of embedding work: a single chunk of a document, carrying
/// everything needed to store its embedding or record an error for it.
struct ChunkWork {
    // Owning document's rowid.
    doc_id: i64,
    // Zero-based position of this chunk within the document.
    chunk_index: usize,
    // Total chunks produced for the document (stored on chunk 0's metadata).
    total_chunks: usize,
    // Hash of the full document content, for change detection.
    doc_hash: String,
    // SHA-256 of this chunk's text.
    chunk_hash: String,
    // The chunk text sent to Ollama.
    text: String,
}
/// Embed every pending document for `model_name`, paging through the
/// database in SAVEPOINT-wrapped pages of `DB_PAGE_SIZE` documents.
///
/// Each page's writes commit or roll back as a unit: cancellation mid-page
/// rolls the page back so it is retried cleanly on the next run, and an
/// error rolls back before propagating. `progress_callback`, if provided,
/// receives `(processed, total)` document counts. Returns the aggregate
/// counters; a shutdown signal stops the loop without an error.
#[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))]
pub async fn embed_documents(
    conn: &Connection,
    client: &OllamaClient,
    model_name: &str,
    concurrency: usize,
    progress_callback: Option<Box<dyn Fn(usize, usize)>>,
    signal: &ShutdownSignal,
) -> Result<EmbedResult> {
    let total = count_pending_documents(conn, model_name)? as usize;
    let mut result = EmbedResult::default();
    // Keyset-pagination cursor: highest document_id handled so far.
    let mut last_id: i64 = 0;
    let mut processed: usize = 0;
    if total == 0 {
        return Ok(result);
    }
    info!(total, "Starting embedding pipeline");
    loop {
        if signal.is_cancelled() {
            info!("Shutdown requested, stopping embedding pipeline");
            break;
        }
        info!(last_id, "Querying pending documents...");
        let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?;
        if pending.is_empty() {
            break;
        }
        info!(
            count = pending.len(),
            "Found pending documents, starting page"
        );
        // One savepoint per page so all of the page's writes are atomic.
        conn.execute_batch("SAVEPOINT embed_page")?;
        let page_result = embed_page(
            conn,
            client,
            model_name,
            concurrency,
            &pending,
            &mut result,
            &mut last_id,
            &mut processed,
            total,
            &progress_callback,
            signal,
        )
        .await;
        match page_result {
            // Cancelled mid-page: discard partial writes so the affected
            // documents stay pending and are re-processed next run.
            Ok(()) if signal.is_cancelled() => {
                let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page");
                info!("Rolled back incomplete page to preserve data integrity");
            }
            Ok(()) => {
                conn.execute_batch("RELEASE embed_page")?;
                // Best-effort checkpoint to keep WAL growth bounded between pages.
                let _ = conn.execute_batch("PRAGMA wal_checkpoint(PASSIVE)");
                info!(
                    chunks_embedded = result.chunks_embedded,
                    failed = result.failed,
                    skipped = result.skipped,
                    total,
                    "Page complete"
                );
            }
            Err(e) => {
                let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page");
                return Err(e);
            }
        }
    }
    info!(
        chunks_embedded = result.chunks_embedded,
        failed = result.failed,
        skipped = result.skipped,
        "Embedding pipeline complete"
    );
    // Backfill the tracing span fields declared in #[instrument].
    tracing::Span::current().record("items_processed", result.chunks_embedded);
    tracing::Span::current().record("items_skipped", result.skipped);
    tracing::Span::current().record("errors", result.failed);
    Ok(result)
}
/// Process one page of pending documents: chunk, embed via Ollama with
/// concurrent batch requests, and persist embeddings/errors to SQLite.
///
/// Phases:
/// 1. Chunk every document. Empty documents and documents whose chunk count
///    would overflow the rowid encoding are skipped with an error row
///    recorded (so they don't stay perpetually pending).
/// 2. Split chunks into BATCH_SIZE batches, fire `concurrency` batches at
///    Ollama simultaneously, then write each batch's results serially.
///    A batch that fails with a context-length error is retried one chunk
///    at a time.
/// 3. A document counts toward `docs_embedded` only if ALL of its chunks
///    were stored successfully.
///
/// DB writes happen on the caller's open SAVEPOINT; this function never
/// commits or rolls back itself.
#[allow(clippy::too_many_arguments)]
async fn embed_page(
    conn: &Connection,
    client: &OllamaClient,
    model_name: &str,
    concurrency: usize,
    pending: &[crate::embedding::change_detector::PendingDocument],
    result: &mut EmbedResult,
    last_id: &mut i64,
    processed: &mut usize,
    total: usize,
    progress_callback: &Option<Box<dyn Fn(usize, usize)>>,
    signal: &ShutdownSignal,
) -> Result<()> {
    let mut all_chunks: Vec<ChunkWork> = Vec::with_capacity(pending.len() * 3);
    let mut page_normal_docs: usize = 0;
    // Per-document bookkeeping used for the docs_embedded tally at the end.
    let mut chunks_needed: HashMap<i64, usize> = HashMap::with_capacity(pending.len());
    let mut chunks_stored: HashMap<i64, usize> = HashMap::with_capacity(pending.len());
    debug!(count = pending.len(), "Starting chunking loop");
    for doc in pending {
        // Advance the pagination cursor even for skipped documents so the
        // caller's next page starts after this doc.
        *last_id = doc.document_id;
        if doc.content_text.is_empty() {
            // Record an error row so empty docs stop showing up as pending.
            record_embedding_error(
                conn,
                doc.document_id,
                0,
                &doc.content_hash,
                "empty",
                model_name,
                "Document has empty content",
            )?;
            result.skipped += 1;
            *processed += 1;
            continue;
        }
        // Periodic progress logging; both fire at multiples of 100
        // (debug + info) which appears intentional for log-level filtering.
        if page_normal_docs.is_multiple_of(50) {
            debug!(
                doc_id = doc.document_id,
                doc_num = page_normal_docs,
                content_bytes = doc.content_text.len(),
                "Chunking document"
            );
        }
        if page_normal_docs.is_multiple_of(100) {
            info!(
                doc_id = doc.document_id,
                content_bytes = doc.content_text.len(),
                docs_so_far = page_normal_docs,
                "Chunking document"
            );
        }
        let chunks = split_into_chunks(&doc.content_text);
        debug!(
            doc_id = doc.document_id,
            chunk_count = chunks.len(),
            "Chunked"
        );
        let total_chunks = chunks.len();
        // Rowids are encoded as doc_id * CHUNK_ROWID_MULTIPLIER + chunk_index;
        // more chunks than the multiplier would collide with the next doc.
        if total_chunks as i64 > CHUNK_ROWID_MULTIPLIER {
            warn!(
                doc_id = doc.document_id,
                chunk_count = total_chunks,
                max = CHUNK_ROWID_MULTIPLIER,
                "Document produces too many chunks, skipping to prevent rowid collision"
            );
            record_embedding_error(
                conn,
                doc.document_id,
                0,
                &doc.content_hash,
                "overflow-sentinel",
                model_name,
                &format!(
                    "Document produces {} chunks, exceeding max {}",
                    total_chunks, CHUNK_ROWID_MULTIPLIER
                ),
            )?;
            result.skipped += 1;
            *processed += 1;
            if let Some(cb) = progress_callback {
                cb(*processed, total);
            }
            continue;
        }
        chunks_needed.insert(doc.document_id, total_chunks);
        for (chunk_index, text) in chunks {
            all_chunks.push(ChunkWork {
                doc_id: doc.document_id,
                chunk_index,
                total_chunks,
                doc_hash: doc.content_hash.clone(),
                chunk_hash: sha256_hash(&text),
                text,
            });
        }
        page_normal_docs += 1;
    }
    debug!(total_chunks = all_chunks.len(), "Chunking loop done");
    // Tracks docs whose old embeddings were already deleted this page, so
    // the first stored chunk clears stale rows exactly once per document.
    let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len());
    // Scratch buffer reused by store_embedding across every chunk (avoids
    // a fresh ~EXPECTED_DIMS*4-byte allocation per embedding).
    let mut embed_buf: Vec<u8> = Vec::with_capacity(EXPECTED_DIMS * 4);
    // Split chunks into batches, then process batches in concurrent groups
    let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect();
    debug!(
        batches = batches.len(),
        concurrency, "Starting Ollama requests"
    );
    info!(
        chunks = all_chunks.len(),
        batches = batches.len(),
        docs = page_normal_docs,
        "Chunking complete, starting Ollama requests"
    );
    info!(
        batches = batches.len(),
        concurrency, "About to start Ollama request loop"
    );
    for (group_idx, concurrent_group) in batches.chunks(concurrency).enumerate() {
        debug!(group_idx, "Starting concurrent group");
        if signal.is_cancelled() {
            info!("Shutdown requested during embedding, stopping mid-page");
            break;
        }
        // Phase 1: Collect texts (must outlive the futures)
        let batch_texts: Vec<Vec<&str>> = concurrent_group
            .iter()
            .map(|batch| batch.iter().map(|c| c.text.as_str()).collect())
            .collect();
        // Phase 2: Fire concurrent HTTP requests to Ollama
        let futures: Vec<_> = batch_texts
            .iter()
            .map(|texts| client.embed_batch(texts))
            .collect();
        let api_results = join_all(futures).await;
        debug!(
            group_idx,
            results = api_results.len(),
            "Ollama group complete"
        );
        // Phase 3: Serial DB writes for each batch result
        for (batch, api_result) in concurrent_group.iter().zip(api_results) {
            match api_result {
                Ok(embeddings) => {
                    for (i, embedding) in embeddings.iter().enumerate() {
                        // Guard against Ollama returning MORE embeddings
                        // than inputs; extras are ignored.
                        if i >= batch.len() {
                            break;
                        }
                        let chunk = &batch[i];
                        if embedding.len() != EXPECTED_DIMS {
                            warn!(
                                doc_id = chunk.doc_id,
                                chunk_index = chunk.chunk_index,
                                got_dims = embedding.len(),
                                expected = EXPECTED_DIMS,
                                "Dimension mismatch, skipping"
                            );
                            record_embedding_error(
                                conn,
                                chunk.doc_id,
                                chunk.chunk_index,
                                &chunk.doc_hash,
                                &chunk.chunk_hash,
                                model_name,
                                &format!(
                                    "Dimension mismatch: got {}, expected {}",
                                    embedding.len(),
                                    EXPECTED_DIMS
                                ),
                            )?;
                            result.failed += 1;
                            continue;
                        }
                        // Clear stale embeddings lazily, on first successful
                        // chunk, so a doc that fails entirely keeps its old rows.
                        if !cleared_docs.contains(&chunk.doc_id) {
                            clear_document_embeddings(conn, chunk.doc_id)?;
                            cleared_docs.insert(chunk.doc_id);
                        }
                        store_embedding(
                            conn,
                            chunk.doc_id,
                            chunk.chunk_index,
                            &chunk.doc_hash,
                            &chunk.chunk_hash,
                            model_name,
                            embedding,
                            chunk.total_chunks,
                            &mut embed_buf,
                        )?;
                        result.chunks_embedded += 1;
                        *chunks_stored.entry(chunk.doc_id).or_insert(0) += 1;
                    }
                    // Record errors for chunks that Ollama silently dropped
                    if embeddings.len() < batch.len() {
                        warn!(
                            returned = embeddings.len(),
                            expected = batch.len(),
                            "Ollama returned fewer embeddings than inputs"
                        );
                        for chunk in &batch[embeddings.len()..] {
                            record_embedding_error(
                                conn,
                                chunk.doc_id,
                                chunk.chunk_index,
                                &chunk.doc_hash,
                                &chunk.chunk_hash,
                                model_name,
                                &format!(
                                    "Batch mismatch: got {} of {} embeddings",
                                    embeddings.len(),
                                    batch.len()
                                ),
                            )?;
                            result.failed += 1;
                        }
                    }
                }
                Err(e) => {
                    // Heuristic string match on the error to detect
                    // context-length failures worth retrying chunk-by-chunk.
                    let err_str = e.to_string();
                    let err_lower = err_str.to_lowercase();
                    let is_context_error = err_lower.contains("context length")
                        || err_lower.contains("too long")
                        || err_lower.contains("maximum context")
                        || err_lower.contains("token limit")
                        || err_lower.contains("exceeds")
                        || (err_lower.contains("413") && err_lower.contains("http"));
                    if is_context_error && batch.len() > 1 {
                        warn!(
                            "Batch failed with context length error, retrying chunks individually"
                        );
                        for chunk in *batch {
                            match client.embed_batch(&[chunk.text.as_str()]).await {
                                // Retry succeeded with the expected dims: store it.
                                Ok(embeddings)
                                    if !embeddings.is_empty()
                                        && embeddings[0].len() == EXPECTED_DIMS =>
                                {
                                    if !cleared_docs.contains(&chunk.doc_id) {
                                        clear_document_embeddings(conn, chunk.doc_id)?;
                                        cleared_docs.insert(chunk.doc_id);
                                    }
                                    store_embedding(
                                        conn,
                                        chunk.doc_id,
                                        chunk.chunk_index,
                                        &chunk.doc_hash,
                                        &chunk.chunk_hash,
                                        model_name,
                                        &embeddings[0],
                                        chunk.total_chunks,
                                        &mut embed_buf,
                                    )?;
                                    result.chunks_embedded += 1;
                                    *chunks_stored.entry(chunk.doc_id).or_insert(0) += 1;
                                }
                                // Retry returned, but with a wrong count/dims.
                                Ok(embeddings) => {
                                    let got_dims = embeddings.first().map_or(0, std::vec::Vec::len);
                                    let reason = format!(
                                        "Retry failed: got {} embeddings, first has {} dims (expected {})",
                                        embeddings.len(),
                                        got_dims,
                                        EXPECTED_DIMS
                                    );
                                    warn!(
                                        doc_id = chunk.doc_id,
                                        chunk_index = chunk.chunk_index,
                                        %reason,
                                        "Chunk retry returned unexpected result"
                                    );
                                    record_embedding_error(
                                        conn,
                                        chunk.doc_id,
                                        chunk.chunk_index,
                                        &chunk.doc_hash,
                                        &chunk.chunk_hash,
                                        model_name,
                                        &reason,
                                    )?;
                                    result.failed += 1;
                                }
                                // Retry request itself failed (e.g. network error).
                                Err(retry_err) => {
                                    let reason = format!("Retry failed: {}", retry_err);
                                    warn!(
                                        doc_id = chunk.doc_id,
                                        chunk_index = chunk.chunk_index,
                                        chunk_bytes = chunk.text.len(),
                                        error = %retry_err,
                                        "Chunk retry request failed"
                                    );
                                    record_embedding_error(
                                        conn,
                                        chunk.doc_id,
                                        chunk.chunk_index,
                                        &chunk.doc_hash,
                                        &chunk.chunk_hash,
                                        model_name,
                                        &reason,
                                    )?;
                                    result.failed += 1;
                                }
                            }
                        }
                    } else {
                        // Non-context error (or single-chunk batch): fail the
                        // whole batch and record an error per chunk.
                        warn!(error = %e, "Batch embedding failed");
                        for chunk in *batch {
                            record_embedding_error(
                                conn,
                                chunk.doc_id,
                                chunk.chunk_index,
                                &chunk.doc_hash,
                                &chunk.chunk_hash,
                                model_name,
                                &e.to_string(),
                            )?;
                            result.failed += 1;
                        }
                    }
                }
            }
        }
    }
    // Count docs where all chunks were successfully stored
    for (doc_id, needed) in &chunks_needed {
        if chunks_stored.get(doc_id).copied().unwrap_or(0) == *needed {
            result.docs_embedded += 1;
        }
    }
    *processed += page_normal_docs;
    if let Some(cb) = progress_callback {
        cb(*processed, total);
    }
    Ok(())
}
/// Delete all stored embeddings and metadata rows for a document, prior to
/// re-embedding it with fresh chunks.
fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> {
    conn.prepare_cached("DELETE FROM embedding_metadata WHERE document_id = ?1")?
        .execute([document_id])?;
    // Chunk rowids for a document occupy the half-open range
    // [encode_rowid(id, 0), encode_rowid(id + 1, 0)); delete the whole range.
    let start_rowid = encode_rowid(document_id, 0);
    let end_rowid = encode_rowid(document_id + 1, 0);
    conn.prepare_cached("DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2")?
        .execute(rusqlite::params![start_rowid, end_rowid])?;
    Ok(())
}
/// Persist one chunk's embedding vector plus its metadata row.
///
/// `embed_buf` is a caller-owned scratch buffer reused across calls: it is
/// cleared and refilled here so no per-chunk allocation occurs once the
/// buffer has grown to embedding size. Both writes use cached prepared
/// statements for hot-path reuse.
#[allow(clippy::too_many_arguments)]
fn store_embedding(
    conn: &Connection,
    doc_id: i64,
    chunk_index: usize,
    doc_hash: &str,
    chunk_hash: &str,
    model_name: &str,
    embedding: &[f32],
    total_chunks: usize,
    embed_buf: &mut Vec<u8>,
) -> Result<()> {
    let rowid = encode_rowid(doc_id, chunk_index as i64);
    // Serialize the f32 vector as little-endian bytes into the scratch buffer.
    embed_buf.clear();
    embed_buf.reserve(embedding.len() * 4);
    embed_buf.extend(embedding.iter().flat_map(|value| value.to_le_bytes()));
    conn.prepare_cached("INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)")?
        .execute(rusqlite::params![rowid, &embed_buf[..]])?;
    // Only the document's first chunk records the total chunk count.
    let chunk_count: Option<i64> = (chunk_index == 0).then_some(total_chunks as i64);
    let now = chrono::Utc::now().timestamp_millis();
    conn.prepare_cached(
        "INSERT OR REPLACE INTO embedding_metadata
(document_id, chunk_index, model, dims, document_hash, chunk_hash,
created_at, attempt_count, last_error, chunk_max_bytes, chunk_count)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL, ?8, ?9)",
    )?
    .execute(rusqlite::params![
        doc_id,
        chunk_index as i64,
        model_name,
        EXPECTED_DIMS as i64,
        doc_hash,
        chunk_hash,
        now,
        CHUNK_MAX_BYTES as i64,
        chunk_count
    ])?;
    Ok(())
}
/// Insert or update an embedding_metadata error row for a chunk.
///
/// On first failure inserts a row with attempt_count = 1; on conflict
/// (same document_id + chunk_index) increments attempt_count and refreshes
/// last_error / last_attempt_at, preserving retry history across runs.
pub(crate) fn record_embedding_error(
    conn: &Connection,
    doc_id: i64,
    chunk_index: usize,
    doc_hash: &str,
    chunk_hash: &str,
    model_name: &str,
    error: &str,
) -> Result<()> {
    let now = chrono::Utc::now().timestamp_millis();
    conn.prepare_cached(
        "INSERT INTO embedding_metadata
(document_id, chunk_index, model, dims, document_hash, chunk_hash,
created_at, attempt_count, last_error, last_attempt_at, chunk_max_bytes)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7, ?9)
ON CONFLICT(document_id, chunk_index) DO UPDATE SET
attempt_count = embedding_metadata.attempt_count + 1,
last_error = ?8,
last_attempt_at = ?7,
chunk_max_bytes = ?9",
    )?
    .execute(rusqlite::params![
        doc_id,
        chunk_index as i64,
        model_name,
        EXPECTED_DIMS as i64,
        doc_hash,
        chunk_hash,
        now,
        error,
        CHUNK_MAX_BYTES as i64
    ])?;
    Ok(())
}
/// Lowercase hex-encoded SHA-256 digest of `input` (used as the chunk hash).
fn sha256_hash(input: &str) -> String {
    // One-shot digest via the Digest trait; identical to update + finalize.
    format!("{:x}", Sha256::digest(input.as_bytes()))
}