Skip to content

Commit

Permalink
[Keyword search] Core node endpoint (#9439)
Browse files Browse the repository at this point in the history
* [Keyword search] Core node endpoint

Description
---
Fixes dust-tt/tasks#1613

Risks
---
N/A — the endpoint is not yet called by any client.

Deploy
---
core

* fix field

* spolu review
  • Loading branch information
philipperolet authored Dec 17, 2024
1 parent 325d738 commit 3a5fcb8
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 2 deletions.
45 changes: 44 additions & 1 deletion core/bin/core_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ use dust::{
providers::provider::{provider, ProviderID},
run,
search_filter::{Filterable, SearchFilter},
search_stores::search_store::{ElasticsearchSearchStore, SearchStore},
search_stores::search_store::{
DatasourceViewFilter, ElasticsearchSearchStore, NodesSearchOptions, SearchStore,
},
sqlite_workers::client::{self, HEARTBEAT_INTERVAL_MS},
stores::{
postgres,
Expand Down Expand Up @@ -3074,6 +3076,44 @@ async fn folders_delete(
}
}

// Request body for the `POST /nodes/search` endpoint.
#[derive(serde::Deserialize)]
struct NodesSearchPayload {
    // Full-text query matched against node titles (see `search_nodes`).
    query: String,
    // Scope of the search: one entry per (data source, view) to search in.
    filter: Vec<DatasourceViewFilter>,
    // Optional pagination; server-side defaults apply when omitted.
    options: Option<NodesSearchOptions>,
}

/// HTTP handler for `POST /nodes/search`: runs a node search against the
/// configured search store and returns matching nodes as JSON, or a 500 on
/// search-store failure.
async fn nodes_search(
    State(state): State<Arc<APIState>>,
    Json(payload): Json<NodesSearchPayload>,
) -> (StatusCode, Json<APIResponse>) {
    let result = state
        .search_store
        .search_nodes(payload.query, payload.filter, payload.options)
        .await;

    match result {
        Err(e) => error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            "internal_server_error",
            "Failed to search nodes",
            Some(e),
        ),
        Ok(nodes) => (
            StatusCode::OK,
            Json(APIResponse {
                error: None,
                response: Some(json!({
                    "nodes": nodes,
                })),
            }),
        ),
    }
}

#[derive(serde::Deserialize)]
struct DatabaseQueryRunPayload {
query: String,
Expand Down Expand Up @@ -3551,6 +3591,9 @@ fn main() {
delete(folders_delete),
)

// Search
.route("/nodes/search", post(nodes_search))

// Misc
.route("/tokenize", post(tokenize))
.route("/tokenize/batch", post(tokenize_batch))
Expand Down
6 changes: 6 additions & 0 deletions core/src/data_sources/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,9 @@ impl Node {
)
}
}

/// Converts a raw JSON value (e.g. an Elasticsearch `_source` document) into
/// a [`Node`].
///
/// # Panics
/// Panics if the value does not deserialize into a `Node`. The panic message
/// includes the underlying serde error so the offending field is diagnosable
/// (the previous `expect` discarded it).
impl From<serde_json::Value> for Node {
    fn from(value: serde_json::Value) -> Self {
        serde_json::from_value(value)
            .unwrap_or_else(|e| panic!("Failed to deserialize Node from JSON value: {}", e))
    }
}
89 changes: 88 additions & 1 deletion core/src/search_stores/search_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,48 @@ use async_trait::async_trait;
use elasticsearch::{
auth::Credentials,
http::transport::{SingleNodeConnectionPool, TransportBuilder},
Elasticsearch, IndexParts,
Elasticsearch, IndexParts, SearchParts,
};
use serde_json::json;
use url::Url;

use crate::data_sources::node::Node;
use crate::{data_sources::data_source::Document, utils};
use tracing::{error, info};

// Pagination options for node search. Defaults (see `Default` impl) are
// offset 0 / limit 10 when the caller omits the options object.
#[derive(serde::Deserialize)]
pub struct NodesSearchOptions {
    // Maximum number of hits to return.
    limit: Option<usize>,
    // Number of hits to skip (offset-based pagination).
    offset: Option<usize>,
}

// Restricts a node search to one data source and a set of allowed parents
// (the "view"): a node matches when it belongs to `data_source_id` AND one of
// its `parents` is listed in `view_filter`.
#[derive(serde::Deserialize)]
pub struct DatasourceViewFilter {
    data_source_id: String,
    view_filter: Vec<String>,
}

// Abstraction over the search backend (currently Elasticsearch) used to index
// documents and search nodes.
#[async_trait]
pub trait SearchStore {
    // Searches nodes whose title matches `query`, restricted to the
    // datasources/views in `filter`, with optional pagination `options`.
    async fn search_nodes(
        &self,
        query: String,
        filter: Vec<DatasourceViewFilter>,
        options: Option<NodesSearchOptions>,
    ) -> Result<Vec<Node>>;
    // Indexes (upserts) `document` into the search backend.
    async fn index_document(&self, document: &Document) -> Result<()>;
    // Boxed-clone hook enabling `Clone` for `Box<dyn SearchStore + Sync + Send>`.
    fn clone_box(&self) -> Box<dyn SearchStore + Sync + Send>;
}

impl Default for NodesSearchOptions {
fn default() -> Self {
NodesSearchOptions {
limit: Some(10),
offset: Some(0),
}
}
}

impl Clone for Box<dyn SearchStore + Sync + Send> {
fn clone(&self) -> Self {
self.clone_box()
Expand Down Expand Up @@ -80,6 +109,64 @@ impl SearchStore for ElasticsearchSearchStore {
}
}

async fn search_nodes(
&self,
query: String,
filter: Vec<DatasourceViewFilter>,
options: Option<NodesSearchOptions>,
) -> Result<Vec<Node>> {
// First, collect all datasource_ids and their corresponding view_filters
let mut filter_conditions = Vec::new();
for f in filter {
filter_conditions.push(json!({
"bool": {
"must": [
{ "term": { "data_source_id": f.data_source_id } },
{ "terms": { "parents": f.view_filter } }
]
}
}));
}

let options = options.unwrap_or_default();

// then, search
match self
.client
.search(SearchParts::Index(&[NODES_INDEX_NAME]))
.from(options.offset.unwrap_or(0) as i64)
.size(options.limit.unwrap_or(100) as i64)
.body(json!({
"query": {
"bool": {
"must": {
"match": {
"title.edge": query
}
},
"should": filter_conditions,
"minimum_should_match": 1
}
}
}))
.send()
.await
{
Ok(response) => {
// get nodes from elasticsearch response in hits.hits
let response_body = response.json::<serde_json::Value>().await?;
let nodes: Vec<Node> = response_body["hits"]["hits"]
.as_array()
.unwrap()
.iter()
.map(|h| Node::from(h.get("_source").unwrap().clone()))
.collect();
Ok(nodes)
}
Err(e) => Err(e.into()),
}
}

    // Returns a boxed copy of this store; backs the `Clone` impl for
    // `Box<dyn SearchStore + Sync + Send>`.
    fn clone_box(&self) -> Box<dyn SearchStore + Sync + Send> {
        Box::new(self.clone())
    }
Expand Down

0 comments on commit 3a5fcb8

Please sign in to comment.