is_multiple_of(N) returns true for 0, which caused debug/info progress messages to fire at doc_num=0 (the start of every page) rather than only at the intended 50/100 milestones. Add != 0 check to both the debug (every 50) and info (every 100) log sites. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
786 lines
27 KiB
Rust
786 lines
27 KiB
Rust
use std::collections::{HashMap, HashSet};
|
|
use std::sync::Arc;
|
|
|
|
use futures::future::join_all;
|
|
use rusqlite::Connection;
|
|
use sha2::{Digest, Sha256};
|
|
use tracing::{debug, info, instrument, warn};
|
|
|
|
use crate::core::error::Result;
|
|
use crate::core::shutdown::ShutdownSignal;
|
|
use crate::embedding::change_detector::{count_pending_documents, find_pending_documents};
|
|
use crate::embedding::chunks::{
|
|
CHUNK_MAX_BYTES, CHUNK_ROWID_MULTIPLIER, EXPECTED_DIMS, encode_rowid, split_into_chunks,
|
|
};
|
|
use crate::embedding::ollama::OllamaClient;
|
|
|
|
/// Number of chunks sent to Ollama in a single embed_batch request.
const BATCH_SIZE: usize = 32;
/// Number of pending documents pulled from SQLite per pipeline page.
const DB_PAGE_SIZE: usize = 500;
/// Default number of batch requests kept in flight to Ollama concurrently.
pub const DEFAULT_EMBED_CONCURRENCY: usize = 4;
|
|
|
|
/// Aggregate counters for one run of the embedding pipeline.
#[derive(Debug, Default)]
pub struct EmbedResult {
    /// Individual chunks whose vectors were written to the DB.
    pub chunks_embedded: usize,
    /// Documents for which every expected chunk was stored.
    pub docs_embedded: usize,
    /// Chunks that errored (dimension mismatch, API failure, dropped by Ollama).
    pub failed: usize,
    /// Documents skipped before embedding (empty content, chunk-count overflow).
    pub skipped: usize,
}
|
|
|
|
/// One chunk of one document, queued for embedding within a page.
struct ChunkWork {
    doc_id: i64,
    /// Zero-based position of this chunk within its document.
    chunk_index: usize,
    /// Total chunks the document was split into (stored on the chunk 0 row).
    total_chunks: usize,
    /// Hash of the whole document; Arc so all chunks of a doc share one allocation.
    doc_hash: Arc<str>,
    /// SHA-256 of this chunk's text.
    chunk_hash: String,
    text: String,
}
|
|
|
|
#[instrument(skip(conn, client, progress_callback, signal), fields(%model_name, items_processed, items_skipped, errors))]
/// Run the full embedding pipeline: page through pending documents, chunk
/// them, request embeddings from Ollama, and persist vectors + metadata.
///
/// Each page is wrapped in a SAVEPOINT so a shutdown mid-page rolls back
/// cleanly instead of leaving partially-embedded documents behind.
///
/// # Errors
/// Returns the first database or pipeline error; the in-flight page is
/// rolled back before the error propagates.
pub async fn embed_documents(
    conn: &Connection,
    client: &OllamaClient,
    model_name: &str,
    concurrency: usize,
    progress_callback: Option<Box<dyn Fn(usize, usize)>>,
    signal: &ShutdownSignal,
) -> Result<EmbedResult> {
    let total = count_pending_documents(conn, model_name)? as usize;
    let mut result = EmbedResult::default();
    // Keyset-pagination cursor: embed_page advances it past each doc it sees.
    let mut last_id: i64 = 0;
    let mut processed: usize = 0;

    if total == 0 {
        return Ok(result);
    }

    info!(total, "Starting embedding pipeline");

    loop {
        if signal.is_cancelled() {
            info!("Shutdown requested, stopping embedding pipeline");
            break;
        }

        info!(last_id, "Querying pending documents...");
        let pending = find_pending_documents(conn, DB_PAGE_SIZE, last_id, model_name)?;
        if pending.is_empty() {
            break;
        }
        info!(
            count = pending.len(),
            "Found pending documents, starting page"
        );

        // All writes for this page happen inside one savepoint.
        conn.execute_batch("SAVEPOINT embed_page")?;
        let page_result = embed_page(
            conn,
            client,
            model_name,
            concurrency,
            &pending,
            &mut result,
            &mut last_id,
            &mut processed,
            total,
            &progress_callback,
            signal,
        )
        .await;
        match page_result {
            // Cancelled mid-page: discard partial writes. NOTE(review): the
            // counters already accumulated in `result` for this page are NOT
            // unwound, so totals may overstate what was committed — confirm
            // whether callers rely on exact counts after cancellation.
            Ok(()) if signal.is_cancelled() => {
                let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page");
                info!("Rolled back incomplete page to preserve data integrity");
            }
            Ok(()) => {
                conn.execute_batch("RELEASE embed_page")?;
                // Best-effort WAL checkpoint; failure is non-fatal.
                let _ = conn.execute_batch("PRAGMA wal_checkpoint(PASSIVE)");
                info!(
                    chunks_embedded = result.chunks_embedded,
                    failed = result.failed,
                    skipped = result.skipped,
                    total,
                    "Page complete"
                );
            }
            Err(e) => {
                // Roll back the failed page, then surface the error.
                let _ = conn.execute_batch("ROLLBACK TO embed_page; RELEASE embed_page");
                return Err(e);
            }
        }
    }

    info!(
        chunks_embedded = result.chunks_embedded,
        failed = result.failed,
        skipped = result.skipped,
        "Embedding pipeline complete"
    );

    // Fill in the deferred span fields declared in #[instrument].
    tracing::Span::current().record("items_processed", result.chunks_embedded);
    tracing::Span::current().record("items_skipped", result.skipped);
    tracing::Span::current().record("errors", result.failed);

    Ok(result)
}
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
async fn embed_page(
|
|
conn: &Connection,
|
|
client: &OllamaClient,
|
|
model_name: &str,
|
|
concurrency: usize,
|
|
pending: &[crate::embedding::change_detector::PendingDocument],
|
|
result: &mut EmbedResult,
|
|
last_id: &mut i64,
|
|
processed: &mut usize,
|
|
total: usize,
|
|
progress_callback: &Option<Box<dyn Fn(usize, usize)>>,
|
|
signal: &ShutdownSignal,
|
|
) -> Result<()> {
|
|
let mut all_chunks: Vec<ChunkWork> = Vec::with_capacity(pending.len() * 3);
|
|
let mut page_normal_docs: usize = 0;
|
|
let mut chunks_needed: HashMap<i64, usize> = HashMap::with_capacity(pending.len());
|
|
let mut chunks_stored: HashMap<i64, usize> = HashMap::with_capacity(pending.len());
|
|
|
|
debug!(count = pending.len(), "Starting chunking loop");
|
|
for doc in pending {
|
|
*last_id = doc.document_id;
|
|
|
|
if doc.content_text.is_empty() {
|
|
record_embedding_error(
|
|
conn,
|
|
doc.document_id,
|
|
0,
|
|
&doc.content_hash,
|
|
"empty",
|
|
model_name,
|
|
"Document has empty content",
|
|
)?;
|
|
result.skipped += 1;
|
|
*processed += 1;
|
|
continue;
|
|
}
|
|
|
|
if page_normal_docs != 0 && page_normal_docs.is_multiple_of(50) {
|
|
debug!(
|
|
doc_id = doc.document_id,
|
|
doc_num = page_normal_docs,
|
|
content_bytes = doc.content_text.len(),
|
|
"Chunking document"
|
|
);
|
|
}
|
|
if page_normal_docs != 0 && page_normal_docs.is_multiple_of(100) {
|
|
info!(
|
|
doc_id = doc.document_id,
|
|
content_bytes = doc.content_text.len(),
|
|
docs_so_far = page_normal_docs,
|
|
"Chunking document"
|
|
);
|
|
}
|
|
let chunks = split_into_chunks(&doc.content_text);
|
|
debug!(
|
|
doc_id = doc.document_id,
|
|
chunk_count = chunks.len(),
|
|
"Chunked"
|
|
);
|
|
let total_chunks = chunks.len();
|
|
|
|
if total_chunks as i64 > CHUNK_ROWID_MULTIPLIER {
|
|
warn!(
|
|
doc_id = doc.document_id,
|
|
chunk_count = total_chunks,
|
|
max = CHUNK_ROWID_MULTIPLIER,
|
|
"Document produces too many chunks, skipping to prevent rowid collision"
|
|
);
|
|
record_embedding_error(
|
|
conn,
|
|
doc.document_id,
|
|
0,
|
|
&doc.content_hash,
|
|
"overflow-sentinel",
|
|
model_name,
|
|
&format!(
|
|
"Document produces {} chunks, exceeding max {}",
|
|
total_chunks, CHUNK_ROWID_MULTIPLIER
|
|
),
|
|
)?;
|
|
result.skipped += 1;
|
|
*processed += 1;
|
|
if let Some(cb) = progress_callback {
|
|
cb(*processed, total);
|
|
}
|
|
continue;
|
|
}
|
|
|
|
chunks_needed.insert(doc.document_id, total_chunks);
|
|
|
|
let doc_hash: Arc<str> = Arc::from(doc.content_hash.as_str());
|
|
for (chunk_index, text) in chunks {
|
|
all_chunks.push(ChunkWork {
|
|
doc_id: doc.document_id,
|
|
chunk_index,
|
|
total_chunks,
|
|
doc_hash: Arc::clone(&doc_hash),
|
|
chunk_hash: sha256_hash(&text),
|
|
text,
|
|
});
|
|
}
|
|
|
|
page_normal_docs += 1;
|
|
}
|
|
|
|
debug!(total_chunks = all_chunks.len(), "Chunking loop done");
|
|
|
|
let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len());
|
|
let mut embed_buf: Vec<u8> = Vec::with_capacity(EXPECTED_DIMS * 4);
|
|
|
|
// Split chunks into batches, then process batches in concurrent groups
|
|
let batches: Vec<&[ChunkWork]> = all_chunks.chunks(BATCH_SIZE).collect();
|
|
debug!(
|
|
batches = batches.len(),
|
|
concurrency, "Starting Ollama requests"
|
|
);
|
|
info!(
|
|
chunks = all_chunks.len(),
|
|
batches = batches.len(),
|
|
docs = page_normal_docs,
|
|
"Chunking complete, starting Ollama requests"
|
|
);
|
|
|
|
info!(
|
|
batches = batches.len(),
|
|
concurrency, "About to start Ollama request loop"
|
|
);
|
|
for (group_idx, concurrent_group) in batches.chunks(concurrency).enumerate() {
|
|
debug!(group_idx, "Starting concurrent group");
|
|
if signal.is_cancelled() {
|
|
info!("Shutdown requested during embedding, stopping mid-page");
|
|
break;
|
|
}
|
|
|
|
// Phase 1: Collect texts (must outlive the futures)
|
|
let batch_texts: Vec<Vec<&str>> = concurrent_group
|
|
.iter()
|
|
.map(|batch| batch.iter().map(|c| c.text.as_str()).collect())
|
|
.collect();
|
|
|
|
// Phase 2: Fire concurrent HTTP requests to Ollama
|
|
let futures: Vec<_> = batch_texts
|
|
.iter()
|
|
.map(|texts| client.embed_batch(texts))
|
|
.collect();
|
|
let api_results = join_all(futures).await;
|
|
debug!(
|
|
group_idx,
|
|
results = api_results.len(),
|
|
"Ollama group complete"
|
|
);
|
|
|
|
// Phase 3: Serial DB writes for each batch result
|
|
for (batch, api_result) in concurrent_group.iter().zip(api_results) {
|
|
match api_result {
|
|
Ok(embeddings) => {
|
|
for (i, embedding) in embeddings.iter().enumerate() {
|
|
if i >= batch.len() {
|
|
break;
|
|
}
|
|
let chunk = &batch[i];
|
|
|
|
if embedding.len() != EXPECTED_DIMS {
|
|
warn!(
|
|
doc_id = chunk.doc_id,
|
|
chunk_index = chunk.chunk_index,
|
|
got_dims = embedding.len(),
|
|
expected = EXPECTED_DIMS,
|
|
"Dimension mismatch, skipping"
|
|
);
|
|
record_embedding_error(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
&format!(
|
|
"Dimension mismatch: got {}, expected {}",
|
|
embedding.len(),
|
|
EXPECTED_DIMS
|
|
),
|
|
)?;
|
|
result.failed += 1;
|
|
continue;
|
|
}
|
|
|
|
if !cleared_docs.contains(&chunk.doc_id) {
|
|
clear_document_embeddings(conn, chunk.doc_id)?;
|
|
cleared_docs.insert(chunk.doc_id);
|
|
}
|
|
|
|
store_embedding(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
embedding,
|
|
chunk.total_chunks,
|
|
&mut embed_buf,
|
|
)?;
|
|
result.chunks_embedded += 1;
|
|
*chunks_stored.entry(chunk.doc_id).or_insert(0) += 1;
|
|
}
|
|
|
|
// Record errors for chunks that Ollama silently dropped
|
|
if embeddings.len() < batch.len() {
|
|
warn!(
|
|
returned = embeddings.len(),
|
|
expected = batch.len(),
|
|
"Ollama returned fewer embeddings than inputs"
|
|
);
|
|
for chunk in &batch[embeddings.len()..] {
|
|
record_embedding_error(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
&format!(
|
|
"Batch mismatch: got {} of {} embeddings",
|
|
embeddings.len(),
|
|
batch.len()
|
|
),
|
|
)?;
|
|
result.failed += 1;
|
|
}
|
|
}
|
|
}
|
|
Err(e) => {
|
|
let err_str = e.to_string();
|
|
let err_lower = err_str.to_lowercase();
|
|
let is_context_error = err_lower.contains("context length")
|
|
|| err_lower.contains("too long")
|
|
|| err_lower.contains("maximum context")
|
|
|| err_lower.contains("token limit")
|
|
|| err_lower.contains("exceeds")
|
|
|| (err_lower.contains("413") && err_lower.contains("http"));
|
|
|
|
if is_context_error && batch.len() > 1 {
|
|
warn!(
|
|
"Batch failed with context length error, retrying chunks individually"
|
|
);
|
|
for chunk in *batch {
|
|
match client.embed_batch(&[chunk.text.as_str()]).await {
|
|
Ok(embeddings)
|
|
if !embeddings.is_empty()
|
|
&& embeddings[0].len() == EXPECTED_DIMS =>
|
|
{
|
|
if !cleared_docs.contains(&chunk.doc_id) {
|
|
clear_document_embeddings(conn, chunk.doc_id)?;
|
|
cleared_docs.insert(chunk.doc_id);
|
|
}
|
|
|
|
store_embedding(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
&embeddings[0],
|
|
chunk.total_chunks,
|
|
&mut embed_buf,
|
|
)?;
|
|
result.chunks_embedded += 1;
|
|
*chunks_stored.entry(chunk.doc_id).or_insert(0) += 1;
|
|
}
|
|
Ok(embeddings) => {
|
|
let got_dims = embeddings.first().map_or(0, std::vec::Vec::len);
|
|
let reason = format!(
|
|
"Retry failed: got {} embeddings, first has {} dims (expected {})",
|
|
embeddings.len(),
|
|
got_dims,
|
|
EXPECTED_DIMS
|
|
);
|
|
warn!(
|
|
doc_id = chunk.doc_id,
|
|
chunk_index = chunk.chunk_index,
|
|
%reason,
|
|
"Chunk retry returned unexpected result"
|
|
);
|
|
record_embedding_error(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
&reason,
|
|
)?;
|
|
result.failed += 1;
|
|
}
|
|
Err(retry_err) => {
|
|
let reason = format!("Retry failed: {}", retry_err);
|
|
warn!(
|
|
doc_id = chunk.doc_id,
|
|
chunk_index = chunk.chunk_index,
|
|
chunk_bytes = chunk.text.len(),
|
|
error = %retry_err,
|
|
"Chunk retry request failed"
|
|
);
|
|
record_embedding_error(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
&reason,
|
|
)?;
|
|
result.failed += 1;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
warn!(error = %e, "Batch embedding failed");
|
|
for chunk in *batch {
|
|
record_embedding_error(
|
|
conn,
|
|
chunk.doc_id,
|
|
chunk.chunk_index,
|
|
&chunk.doc_hash,
|
|
&chunk.chunk_hash,
|
|
model_name,
|
|
&e.to_string(),
|
|
)?;
|
|
result.failed += 1;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Count docs where all chunks were successfully stored
|
|
for (doc_id, needed) in &chunks_needed {
|
|
if chunks_stored.get(doc_id).copied().unwrap_or(0) == *needed {
|
|
result.docs_embedded += 1;
|
|
}
|
|
}
|
|
|
|
*processed += page_normal_docs;
|
|
if let Some(cb) = progress_callback {
|
|
cb(*processed, total);
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn clear_document_embeddings(conn: &Connection, document_id: i64) -> Result<()> {
|
|
conn.prepare_cached("DELETE FROM embedding_metadata WHERE document_id = ?1")?
|
|
.execute([document_id])?;
|
|
|
|
let start_rowid = encode_rowid(document_id, 0);
|
|
let end_rowid = encode_rowid(document_id + 1, 0);
|
|
conn.prepare_cached("DELETE FROM embeddings WHERE rowid >= ?1 AND rowid < ?2")?
|
|
.execute(rusqlite::params![start_rowid, end_rowid])?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[allow(clippy::too_many_arguments)]
|
|
fn store_embedding(
|
|
conn: &Connection,
|
|
doc_id: i64,
|
|
chunk_index: usize,
|
|
doc_hash: &str,
|
|
chunk_hash: &str,
|
|
model_name: &str,
|
|
embedding: &[f32],
|
|
total_chunks: usize,
|
|
embed_buf: &mut Vec<u8>,
|
|
) -> Result<()> {
|
|
let rowid = encode_rowid(doc_id, chunk_index as i64);
|
|
|
|
embed_buf.clear();
|
|
for f in embedding {
|
|
embed_buf.extend_from_slice(&f.to_le_bytes());
|
|
}
|
|
|
|
conn.prepare_cached("INSERT OR REPLACE INTO embeddings (rowid, embedding) VALUES (?1, ?2)")?
|
|
.execute(rusqlite::params![rowid, &embed_buf[..]])?;
|
|
|
|
let chunk_count: Option<i64> = if chunk_index == 0 {
|
|
Some(total_chunks as i64)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let now = chrono::Utc::now().timestamp_millis();
|
|
conn.prepare_cached(
|
|
"INSERT OR REPLACE INTO embedding_metadata
|
|
(document_id, chunk_index, model, dims, document_hash, chunk_hash,
|
|
created_at, attempt_count, last_error, chunk_max_bytes, chunk_count)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, NULL, ?8, ?9)",
|
|
)?
|
|
.execute(rusqlite::params![
|
|
doc_id,
|
|
chunk_index as i64,
|
|
model_name,
|
|
EXPECTED_DIMS as i64,
|
|
doc_hash,
|
|
chunk_hash,
|
|
now,
|
|
CHUNK_MAX_BYTES as i64,
|
|
chunk_count
|
|
])?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub(crate) fn record_embedding_error(
|
|
conn: &Connection,
|
|
doc_id: i64,
|
|
chunk_index: usize,
|
|
doc_hash: &str,
|
|
chunk_hash: &str,
|
|
model_name: &str,
|
|
error: &str,
|
|
) -> Result<()> {
|
|
let now = chrono::Utc::now().timestamp_millis();
|
|
conn.prepare_cached(
|
|
"INSERT INTO embedding_metadata
|
|
(document_id, chunk_index, model, dims, document_hash, chunk_hash,
|
|
created_at, attempt_count, last_error, last_attempt_at, chunk_max_bytes)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, 1, ?8, ?7, ?9)
|
|
ON CONFLICT(document_id, chunk_index) DO UPDATE SET
|
|
attempt_count = embedding_metadata.attempt_count + 1,
|
|
last_error = ?8,
|
|
last_attempt_at = ?7,
|
|
chunk_max_bytes = ?9",
|
|
)?
|
|
.execute(rusqlite::params![
|
|
doc_id,
|
|
chunk_index as i64,
|
|
model_name,
|
|
EXPECTED_DIMS as i64,
|
|
doc_hash,
|
|
chunk_hash,
|
|
now,
|
|
error,
|
|
CHUNK_MAX_BYTES as i64
|
|
])?;
|
|
Ok(())
|
|
}
|
|
|
|
fn sha256_hash(input: &str) -> String {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(input.as_bytes());
|
|
format!("{:x}", hasher.finalize())
|
|
}
|
|
|
|
/// Counters for a scoped (by-ID) embedding run; see `embed_documents_by_ids`.
#[derive(Debug, Default)]
pub struct EmbedForIdsResult {
    /// Successful chunk rows counted from the DB after the run.
    pub chunks_embedded: usize,
    /// Documents whose stored chunks reach the expected chunk count.
    pub docs_embedded: usize,
    /// Chunks that errored during this run.
    pub failed: usize,
    /// Requested documents not processed (already embedded or filtered out).
    pub skipped: usize,
}
|
|
|
|
/// Embed only the documents with the given IDs, skipping any that are
/// already embedded with matching config (model, dims, chunk size, hash).
///
/// Runs as a single SAVEPOINT-wrapped page; cancellation rolls the whole
/// page back. Unlike `embed_documents`, the chunk/doc counts in the result
/// are read back from the DB after the page commits.
pub async fn embed_documents_by_ids(
    conn: &Connection,
    client: &OllamaClient,
    model_name: &str,
    concurrency: usize,
    document_ids: &[i64],
    signal: &ShutdownSignal,
) -> Result<EmbedForIdsResult> {
    let mut result = EmbedForIdsResult::default();

    if document_ids.is_empty() {
        return Ok(result);
    }

    if signal.is_cancelled() {
        return Ok(result);
    }

    // Load documents for the specified IDs, filtering out already-embedded
    let pending = find_documents_by_ids(conn, document_ids, model_name)?;

    if pending.is_empty() {
        // NOTE(review): IDs that don't exist in `documents` at all are also
        // counted as "skipped" here — confirm that's the intended semantics.
        result.skipped = document_ids.len();
        return Ok(result);
    }

    let skipped_count = document_ids.len() - pending.len();
    result.skipped = skipped_count;

    info!(
        requested = document_ids.len(),
        pending = pending.len(),
        skipped = skipped_count,
        "Scoped embedding: processing documents by ID"
    );

    // Use the same SAVEPOINT + embed_page pattern as the main pipeline
    let mut last_id: i64 = 0;
    let mut processed: usize = 0;
    let total = pending.len();
    let mut page_stats = EmbedResult::default();

    conn.execute_batch("SAVEPOINT embed_by_ids")?;
    let page_result = embed_page(
        conn,
        client,
        model_name,
        concurrency,
        &pending,
        &mut page_stats,
        &mut last_id,
        &mut processed,
        total,
        &None, // no progress callback for scoped runs
        signal,
    )
    .await;

    match page_result {
        Ok(()) if signal.is_cancelled() => {
            let _ = conn.execute_batch("ROLLBACK TO embed_by_ids; RELEASE embed_by_ids");
            info!("Rolled back scoped embed page due to cancellation");
        }
        Ok(()) => {
            conn.execute_batch("RELEASE embed_by_ids")?;

            // Count actual results from DB
            // NOTE(review): these counts reflect all successful rows currently
            // stored for the pending docs, which may include chunks embedded
            // in earlier runs — verify against caller expectations.
            let (chunks, docs) = count_embedded_results(conn, &pending)?;
            result.chunks_embedded = chunks;
            result.docs_embedded = docs;
            result.failed = page_stats.failed;
        }
        Err(e) => {
            let _ = conn.execute_batch("ROLLBACK TO embed_by_ids; RELEASE embed_by_ids");
            return Err(e);
        }
    }

    info!(
        chunks_embedded = result.chunks_embedded,
        docs_embedded = result.docs_embedded,
        failed = result.failed,
        skipped = result.skipped,
        "Scoped embedding complete"
    );

    Ok(result)
}
|
|
|
|
/// Load documents by specific IDs, filtering out those already embedded
|
|
/// with matching config (same logic as `find_pending_documents` but scoped).
|
|
fn find_documents_by_ids(
|
|
conn: &Connection,
|
|
document_ids: &[i64],
|
|
model_name: &str,
|
|
) -> Result<Vec<crate::embedding::change_detector::PendingDocument>> {
|
|
use crate::embedding::chunks::{CHUNK_MAX_BYTES, EXPECTED_DIMS};
|
|
|
|
if document_ids.is_empty() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
// Build IN clause with placeholders
|
|
let placeholders: Vec<String> = (0..document_ids.len())
|
|
.map(|i| format!("?{}", i + 1))
|
|
.collect();
|
|
let in_clause = placeholders.join(", ");
|
|
|
|
let sql = format!(
|
|
r#"
|
|
SELECT d.id, d.content_text, d.content_hash
|
|
FROM documents d
|
|
LEFT JOIN embedding_metadata em
|
|
ON em.document_id = d.id AND em.chunk_index = 0
|
|
WHERE d.id IN ({in_clause})
|
|
AND (
|
|
em.document_id IS NULL
|
|
OR em.document_hash != d.content_hash
|
|
OR em.chunk_max_bytes IS NULL
|
|
OR em.chunk_max_bytes != ?{chunk_bytes_idx}
|
|
OR em.model != ?{model_idx}
|
|
OR em.dims != ?{dims_idx}
|
|
)
|
|
ORDER BY d.id
|
|
"#,
|
|
in_clause = in_clause,
|
|
chunk_bytes_idx = document_ids.len() + 1,
|
|
model_idx = document_ids.len() + 2,
|
|
dims_idx = document_ids.len() + 3,
|
|
);
|
|
|
|
let mut stmt = conn.prepare(&sql)?;
|
|
|
|
// Build params: document_ids... then chunk_max_bytes, model, dims
|
|
let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
|
for id in document_ids {
|
|
params.push(Box::new(*id));
|
|
}
|
|
params.push(Box::new(CHUNK_MAX_BYTES as i64));
|
|
params.push(Box::new(model_name.to_string()));
|
|
params.push(Box::new(EXPECTED_DIMS as i64));
|
|
|
|
let param_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect();
|
|
|
|
let rows = stmt
|
|
.query_map(param_refs.as_slice(), |row| {
|
|
Ok(crate::embedding::change_detector::PendingDocument {
|
|
document_id: row.get(0)?,
|
|
content_text: row.get(1)?,
|
|
content_hash: row.get(2)?,
|
|
})
|
|
})?
|
|
.collect::<std::result::Result<Vec<_>, _>>()?;
|
|
|
|
Ok(rows)
|
|
}
|
|
|
|
/// Count how many chunks and complete docs were embedded for the given pending docs.
|
|
fn count_embedded_results(
|
|
conn: &Connection,
|
|
pending: &[crate::embedding::change_detector::PendingDocument],
|
|
) -> Result<(usize, usize)> {
|
|
let mut total_chunks: usize = 0;
|
|
let mut total_docs: usize = 0;
|
|
|
|
for doc in pending {
|
|
let chunk_count: i64 = conn.query_row(
|
|
"SELECT COUNT(*) FROM embedding_metadata WHERE document_id = ?1 AND last_error IS NULL",
|
|
[doc.document_id],
|
|
|row| row.get(0),
|
|
)?;
|
|
if chunk_count > 0 {
|
|
total_chunks += chunk_count as usize;
|
|
// Check if all expected chunks are present (chunk_count metadata on chunk_index=0)
|
|
let expected: Option<i64> = conn.query_row(
|
|
"SELECT chunk_count FROM embedding_metadata WHERE document_id = ?1 AND chunk_index = 0",
|
|
[doc.document_id],
|
|
|row| row.get(0),
|
|
)?;
|
|
if let Some(expected_count) = expected
|
|
&& chunk_count >= expected_count
|
|
{
|
|
total_docs += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok((total_chunks, total_docs))
|
|
}
|
|
|
|
// Unit tests live in a sibling file to keep this module focused.
#[cfg(test)]
#[path = "pipeline_tests.rs"]
mod tests;
|