use std::path::Path; use rusqlite::Connection; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; use crate::core::db::{create_connection, run_migrations}; use crate::core::shutdown::ShutdownSignal; use crate::embedding::chunking::EXPECTED_DIMS; use crate::embedding::ollama::{OllamaClient, OllamaConfig}; use crate::embedding::pipeline::embed_documents_by_ids; const MODEL: &str = "nomic-embed-text"; fn setup_db() -> Connection { let conn = create_connection(Path::new(":memory:")).unwrap(); run_migrations(&conn).unwrap(); conn } fn insert_test_project(conn: &Connection) -> i64 { conn.execute( "INSERT INTO projects (gitlab_project_id, path_with_namespace, web_url) VALUES (1, 'group/test', 'https://gitlab.example.com/group/test')", [], ) .unwrap(); conn.last_insert_rowid() } fn insert_test_document( conn: &Connection, project_id: i64, source_id: i64, content: &str, hash: &str, ) -> i64 { conn.execute( "INSERT INTO documents (source_type, source_id, project_id, content_text, content_hash) VALUES ('issue', ?1, ?2, ?3, ?4)", rusqlite::params![source_id, project_id, content, hash], ) .unwrap(); conn.last_insert_rowid() } fn make_fake_embedding() -> Vec { vec![0.1_f32; EXPECTED_DIMS] } fn make_ollama_response(count: usize) -> serde_json::Value { let embedding = make_fake_embedding(); let embeddings: Vec<_> = (0..count).map(|_| embedding.clone()).collect(); serde_json::json!({ "model": MODEL, "embeddings": embeddings }) } fn count_embeddings_for_doc(conn: &Connection, doc_id: i64) -> i64 { conn.query_row( "SELECT COUNT(*) FROM embedding_metadata WHERE document_id = ?1", [doc_id], |row| row.get(0), ) .unwrap() } fn make_client(base_url: &str) -> OllamaClient { OllamaClient::new(OllamaConfig { base_url: base_url.to_string(), model: MODEL.to_string(), timeout_secs: 10, }) } #[tokio::test] async fn test_embed_by_ids_only_embeds_specified_docs() { let mock_server = MockServer::start().await; Mock::given(method("POST")) .and(path("/api/embed")) .respond_with(ResponseTemplate::new(200).set_body_json(make_ollama_response(1))) .mount(&mock_server) .await; let conn = setup_db(); let proj_id = insert_test_project(&conn); let doc1 = insert_test_document(&conn, proj_id, 1, "Hello world content for doc 1", "hash_a"); let doc2 = insert_test_document(&conn, proj_id, 2, "Hello world content for doc 2", "hash_b"); let signal = ShutdownSignal::new(); let client = make_client(&mock_server.uri()); // Only embed doc1 let result = embed_documents_by_ids(&conn, &client, MODEL, 1, &[doc1], &signal) .await .unwrap(); assert_eq!(result.docs_embedded, 1, "Should embed exactly 1 doc"); assert!(result.chunks_embedded > 0, "Should have embedded chunks"); // doc1 should have embeddings assert!( count_embeddings_for_doc(&conn, doc1) > 0, "doc1 should have embeddings" ); // doc2 should have NO embeddings assert_eq!( count_embeddings_for_doc(&conn, doc2), 0, "doc2 should have no embeddings" ); } #[tokio::test] async fn test_embed_by_ids_skips_already_embedded() { let mock_server = MockServer::start().await; Mock::given(method("POST")) .and(path("/api/embed")) .respond_with(ResponseTemplate::new(200).set_body_json(make_ollama_response(1))) .expect(1) // Should only be called once .mount(&mock_server) .await; let conn = setup_db(); let proj_id = insert_test_project(&conn); let doc1 = insert_test_document(&conn, proj_id, 1, "Hello world content for doc 1", "hash_a"); let signal = ShutdownSignal::new(); let client = make_client(&mock_server.uri()); // First embed let result1 = embed_documents_by_ids(&conn, &client, MODEL, 1, &[doc1], &signal) .await .unwrap(); assert_eq!(result1.docs_embedded, 1); // Second embed with same doc — should skip let result2 = embed_documents_by_ids(&conn, &client, MODEL, 1, &[doc1], &signal) .await .unwrap(); assert_eq!(result2.docs_embedded, 0, "Should embed 0 on second call"); assert_eq!(result2.skipped, 1, "Should report 1 skipped"); assert_eq!(result2.chunks_embedded, 0, "No new chunks"); } #[tokio::test] async fn test_embed_by_ids_empty_input() { let conn = setup_db(); let signal = ShutdownSignal::new(); // Client URL doesn't matter — should never be called let client = make_client("http://localhost:99999"); let result = embed_documents_by_ids(&conn, &client, MODEL, 1, &[], &signal) .await .unwrap(); assert_eq!(result.docs_embedded, 0); assert_eq!(result.chunks_embedded, 0); assert_eq!(result.failed, 0); assert_eq!(result.skipped, 0); } #[tokio::test] async fn test_embed_by_ids_respects_cancellation() { let conn = setup_db(); let proj_id = insert_test_project(&conn); let doc1 = insert_test_document(&conn, proj_id, 1, "Hello world content for doc 1", "hash_a"); let signal = ShutdownSignal::new(); signal.cancel(); // Pre-cancel let client = make_client("http://localhost:99999"); let result = embed_documents_by_ids(&conn, &client, MODEL, 1, &[doc1], &signal) .await .unwrap(); assert_eq!(result.docs_embedded, 0, "Should embed 0 when cancelled"); assert_eq!(result.chunks_embedded, 0, "No chunks when cancelled"); }