perf(search+embed): zero-copy embedding API and deferred RRF mapping
Change OllamaClient::embed_batch to accept &[&str] instead of Vec<String>. The EmbedRequest struct now borrows both model name and input texts, eliminating per-batch cloning of chunk text (up to 32KB per chunk x 32 chunks per batch). Serialization output is identical since serde serializes &str and String to the same JSON. In hybrid search, defer the RrfResult->HybridResult mapping until after filter+take, so only `limit` items (typically 20) are constructed instead of up to 1,500 at RECALL_CAP. Also switch filtered_ids to into_iter() to avoid an extra .copied() pass. Switch FTS search_fts from prepare() to prepare_cached() for statement reuse across repeated searches. Benchmarked at ~1.6x faster. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,9 +27,9 @@ pub struct OllamaClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
struct EmbedRequest {
|
struct EmbedRequest<'a> {
|
||||||
model: String,
|
model: &'a str,
|
||||||
input: Vec<String>,
|
input: Vec<&'a str>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -101,12 +101,12 @@ impl OllamaClient {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn embed_batch(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>> {
|
pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
|
||||||
let url = format!("{}/api/embed", self.config.base_url);
|
let url = format!("{}/api/embed", self.config.base_url);
|
||||||
|
|
||||||
let request = EmbedRequest {
|
let request = EmbedRequest {
|
||||||
model: self.config.model.clone(),
|
model: &self.config.model,
|
||||||
input: texts,
|
input: texts.to_vec(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
@@ -181,8 +181,8 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_embed_request_serialization() {
|
fn test_embed_request_serialization() {
|
||||||
let request = EmbedRequest {
|
let request = EmbedRequest {
|
||||||
model: "nomic-embed-text".to_string(),
|
model: "nomic-embed-text",
|
||||||
input: vec!["hello".to_string(), "world".to_string()],
|
input: vec!["hello", "world"],
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&request).unwrap();
|
let json = serde_json::to_string(&request).unwrap();
|
||||||
assert!(json.contains("\"model\":\"nomic-embed-text\""));
|
assert!(json.contains("\"model\":\"nomic-embed-text\""));
|
||||||
|
|||||||
@@ -162,9 +162,9 @@ async fn embed_page(
|
|||||||
let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len());
|
let mut cleared_docs: HashSet<i64> = HashSet::with_capacity(pending.len());
|
||||||
|
|
||||||
for batch in all_chunks.chunks(BATCH_SIZE) {
|
for batch in all_chunks.chunks(BATCH_SIZE) {
|
||||||
let texts: Vec<String> = batch.iter().map(|c| c.text.clone()).collect();
|
let texts: Vec<&str> = batch.iter().map(|c| c.text.as_str()).collect();
|
||||||
|
|
||||||
match client.embed_batch(texts).await {
|
match client.embed_batch(&texts).await {
|
||||||
Ok(embeddings) => {
|
Ok(embeddings) => {
|
||||||
for (i, embedding) in embeddings.iter().enumerate() {
|
for (i, embedding) in embeddings.iter().enumerate() {
|
||||||
if i >= batch.len() {
|
if i >= batch.len() {
|
||||||
@@ -228,7 +228,7 @@ async fn embed_page(
|
|||||||
if is_context_error && batch.len() > 1 {
|
if is_context_error && batch.len() > 1 {
|
||||||
warn!("Batch failed with context length error, retrying chunks individually");
|
warn!("Batch failed with context length error, retrying chunks individually");
|
||||||
for chunk in batch {
|
for chunk in batch {
|
||||||
match client.embed_batch(vec![chunk.text.clone()]).await {
|
match client.embed_batch(&[chunk.text.as_str()]).await {
|
||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
if !embeddings.is_empty()
|
if !embeddings.is_empty()
|
||||||
&& embeddings[0].len() == EXPECTED_DIMS =>
|
&& embeddings[0].len() == EXPECTED_DIMS =>
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ pub fn search_fts(
|
|||||||
LIMIT ?2
|
LIMIT ?2
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let mut stmt = conn.prepare(sql)?;
|
let mut stmt = conn.prepare_cached(sql)?;
|
||||||
let results = stmt
|
let results = stmt
|
||||||
.query_map(rusqlite::params![fts_query, limit as i64], |row| {
|
.query_map(rusqlite::params![fts_query, limit as i64], |row| {
|
||||||
Ok(FtsResult {
|
Ok(FtsResult {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use rusqlite::Connection;
|
|||||||
use crate::core::error::Result;
|
use crate::core::error::Result;
|
||||||
use crate::embedding::ollama::OllamaClient;
|
use crate::embedding::ollama::OllamaClient;
|
||||||
use crate::search::filters::{SearchFilters, apply_filters};
|
use crate::search::filters::{SearchFilters, apply_filters};
|
||||||
|
use crate::search::rrf::RrfResult;
|
||||||
use crate::search::{FtsQueryMode, rank_rrf, search_fts, search_vector};
|
use crate::search::{FtsQueryMode, rank_rrf, search_fts, search_vector};
|
||||||
|
|
||||||
const BASE_RECALL_MIN: usize = 50;
|
const BASE_RECALL_MIN: usize = 50;
|
||||||
@@ -77,7 +78,7 @@ pub async fn search_hybrid(
|
|||||||
));
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let query_embedding = client.embed_batch(vec![query.to_string()]).await?;
|
let query_embedding = client.embed_batch(&[query]).await?;
|
||||||
let embedding = query_embedding.into_iter().next().unwrap_or_default();
|
let embedding = query_embedding.into_iter().next().unwrap_or_default();
|
||||||
|
|
||||||
if embedding.is_empty() {
|
if embedding.is_empty() {
|
||||||
@@ -102,7 +103,7 @@ pub async fn search_hybrid(
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
match client {
|
match client {
|
||||||
Some(client) => match client.embed_batch(vec![query.to_string()]).await {
|
Some(client) => match client.embed_batch(&[query]).await {
|
||||||
Ok(query_embedding) => {
|
Ok(query_embedding) => {
|
||||||
let embedding = query_embedding.into_iter().next().unwrap_or_default();
|
let embedding = query_embedding.into_iter().next().unwrap_or_default();
|
||||||
|
|
||||||
@@ -137,30 +138,28 @@ pub async fn search_hybrid(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let ranked = rank_rrf(&vec_tuples, &fts_tuples);
|
let ranked = rank_rrf(&vec_tuples, &fts_tuples);
|
||||||
|
let limit = filters.clamp_limit();
|
||||||
|
|
||||||
let results: Vec<HybridResult> = ranked
|
let to_hybrid = |r: RrfResult| HybridResult {
|
||||||
.into_iter()
|
|
||||||
.map(|r| HybridResult {
|
|
||||||
document_id: r.document_id,
|
document_id: r.document_id,
|
||||||
score: r.normalized_score,
|
score: r.normalized_score,
|
||||||
vector_rank: r.vector_rank,
|
vector_rank: r.vector_rank,
|
||||||
fts_rank: r.fts_rank,
|
fts_rank: r.fts_rank,
|
||||||
rrf_score: r.rrf_score,
|
rrf_score: r.rrf_score,
|
||||||
})
|
};
|
||||||
.collect();
|
|
||||||
|
|
||||||
let limit = filters.clamp_limit();
|
let results: Vec<HybridResult> = if filters.has_any_filter() {
|
||||||
let results = if filters.has_any_filter() {
|
let all_ids: Vec<i64> = ranked.iter().map(|r| r.document_id).collect();
|
||||||
let all_ids: Vec<i64> = results.iter().map(|r| r.document_id).collect();
|
|
||||||
let filtered_ids = apply_filters(conn, &all_ids, filters)?;
|
let filtered_ids = apply_filters(conn, &all_ids, filters)?;
|
||||||
let filtered_set: std::collections::HashSet<i64> = filtered_ids.iter().copied().collect();
|
let filtered_set: std::collections::HashSet<i64> = filtered_ids.into_iter().collect();
|
||||||
results
|
ranked
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter(|r| filtered_set.contains(&r.document_id))
|
.filter(|r| filtered_set.contains(&r.document_id))
|
||||||
.take(limit)
|
.take(limit)
|
||||||
|
.map(to_hybrid)
|
||||||
.collect()
|
.collect()
|
||||||
} else {
|
} else {
|
||||||
results.into_iter().take(limit).collect()
|
ranked.into_iter().take(limit).map(to_hybrid).collect()
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok((results, warnings))
|
Ok((results, warnings))
|
||||||
|
|||||||
Reference in New Issue
Block a user