Skip to content

Commit

Permalink
test es knn score (#1047)
Browse files Browse the repository at this point in the history
* rescale es knn score

* test es knn score

---------

Co-authored-by: Andrea Corradi <andrea.corradi@xayn.com>
  • Loading branch information
janpetschexain and acrrd authored Aug 2, 2023
1 parent e4f4e77 commit 1c4535f
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 4 deletions.
8 changes: 8 additions & 0 deletions web-api-db-ctrl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,17 @@ impl Silo {
&self.postgres_config
}

/// Returns a shared reference to the Postgres client stored in this silo.
pub fn postgres_client(&self) -> &PgClient {
&self.postgres
}

/// Returns a shared reference to the Elasticsearch configuration stored in this silo.
pub fn elastic_config(&self) -> &EsConfig {
&self.elastic_config
}

/// Returns a shared reference to the Elasticsearch client stored in this silo.
pub fn elastic_client(&self) -> &EsClient {
&self.elastic
}
}

#[derive(Deserialize, Debug)]
Expand Down
6 changes: 2 additions & 4 deletions web-api/src/storage/elastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,13 @@ impl Client {
let response = self
.bulk_request(documents.flat_map(|document| {
[
serde_json::to_value(BulkInstruction::Index { id: &document.id })
.map_err(Into::into),
serde_json::to_value(BulkInstruction::Index { id: &document.id }),
serde_json::to_value(IngestedDocument {
snippet: &document.snippet,
properties: &document.properties,
embedding: &document.embedding,
tags: &document.tags,
})
.map_err(Into::into),
}),
]
}))
.await?;
Expand Down
76 changes: 76 additions & 0 deletions web-api/tests/elastic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2023 Xayn AG
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

use serde_json::{json, Error, Value};
use xayn_integration_tests::{test_app, TEST_EMBEDDING_SIZE, UNCHANGED_CONFIG};
use xayn_test_utils::assert_approx_eq;
use xayn_web_api::Ingestion;
use xayn_web_api_shared::{
elastic::{BulkInstruction, SerdeDiscard},
serde::json_object,
};

/// Serializes the bulk `Index` instruction for the document called `id` into a JSON value.
///
/// # Errors
///
/// Propagates the error reported by `serde_json::to_value`.
fn id(id: &str) -> Result<Value, Error> {
    let instruction = BulkInstruction::Index { id: &id };
    serde_json::to_value(instruction)
}

/// Builds the JSON document body `{ "embedding": [...] }` from the given vector.
///
/// This always returns `Ok`; the `Result` wrapper gives it the same return type as [`id`].
fn emb(emb: &[f32]) -> Result<Value, Error> {
    let document = json!({ "embedding": emb });
    Ok(document)
}

// Regression guard: pins the kNN scores returned for three unit-length documents so that a
// change in the scoring behavior is noticed.
#[test]
fn test_normalized_es_knn_scores() {
    test_app::<Ingestion, _>(UNCHANGED_CONFIG, |_, _, services| async move {
        const LEN: usize = TEST_EMBEDDING_SIZE / 2;

        let es = services
            .silo
            .elastic_client()
            .with_index(&services.tenant.tenant_id);

        // Each vector has LEN non-zero components of magnitude 1/sqrt(LEN), i.e. unit length.
        let unit = (LEN as f32).sqrt().recip();
        // d1: identical to the query vector.
        let embedding = [[unit; LEN], [0.; LEN]].concat();
        // d2: non-zero only where the query is zero, hence orthogonal to it.
        let orthogonal = [[0.; LEN], [unit; LEN]].concat();
        // d3: the negated query vector.
        let opposite = [[-unit; LEN], [0.; LEN]].concat();

        // Bulk payload: every instruction line is followed by its document body.
        let operations = [
            id("d1"),
            emb(&embedding),
            id("d2"),
            emb(&orthogonal),
            id("d3"),
            emb(&opposite),
        ];
        let response = es
            .bulk_request::<SerdeDiscard>(operations)
            .await
            .unwrap();
        assert!(!response.errors);

        let knn_search = json_object!({
            "knn": {
                "field": "embedding",
                "query_vector": embedding,
                "k": 5,
                "num_candidates": 5,
            },
            "size": 5
        });
        let scores = es.search_request::<String>(knn_search).await.unwrap();

        // All three documents are returned, scored from best (identical) to worst (opposite).
        assert_eq!(scores.len(), 3);
        assert_approx_eq!(f32, scores["d1"], 1.);
        assert_approx_eq!(f32, scores["d2"], 0.5);
        assert_approx_eq!(f32, scores["d3"], 0., epsilon = 1e-7);

        Ok(())
    });
}

0 comments on commit 1c4535f

Please sign in to comment.