Skip to content

Commit

Permalink
Add local client to be used with dria-cli (#4)
Browse files Browse the repository at this point in the history
* add local client to be used with dria-cli

* depracate `query`

* bump version
  • Loading branch information
erhant authored Feb 26, 2024
1 parent 32dbfca commit d0adfc0
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 69 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ const dria = new Dria({ apiKey });

contractId = await dria.create(
"My New Contract,
"jinaai/jina-embeddings-v2-base-en",
"jina-embeddings-v2-base-en",
"Science",
);
dria.contractId = contractId;
Expand Down
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "dria",
"version": "0.0.3",
"version": "0.0.4",
"license": "Apache-2.0",
"author": "FirstBatch Team <dev@firstbatch.xyz>",
"contributors": [
Expand All @@ -23,6 +23,7 @@
"lint": "eslint '**/*.ts' && echo 'All good.'",
"test": "bun test --timeout 15000",
"t": "bun run test",
"test:local": "LOCAL_TEST=true bun test local --timeout 15000",
"proto:code": "npx pbjs ./proto/insert.proto -w commonjs -t static-module -o ./proto/insert.js",
"proto:type": "npx pbts ./proto/insert.js -o ./proto/insert.d.ts",
"proto": "bun proto:code && bun proto:type"
Expand Down
39 changes: 39 additions & 0 deletions src/clients/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { AxiosInstance } from "axios";

/**
* A utility class that exposes `post` and `get` requests
* for other clients to use. The constructor takes in an Axios instance.
*/
export class DriaCommon {
constructor(protected readonly client: AxiosInstance) {}

/**
* A POST request wrapper.
* @param url request URL
* @param body request body
* @template T type of response body
* @returns parsed response body
*/
protected async post<T = unknown>(url: string, body: unknown) {
const res = await this.client.post<{ success: boolean; data: T; code: number }>(url, body);
if (res.status !== 200) {
throw `Dria API (POST) failed with ${res.statusText} (${res.status}).\n${res.data}`;
}
return res.data.data;
}

/**
* A GET request wrapper.
* @param url request URL
* @param params query parameters
* @template T type of response body
* @returns parsed response body
*/
protected async get<T = unknown>(url: string, params: Record<string, unknown> = {}) {
const res = await this.client.get<{ success: boolean; data: T; code: number }>(url, { params });
if (res.status !== 200) {
throw `Dria API (GET) failed with ${res.statusText} (${res.status}).\n${res.data}`;
}
return res.data.data;
}
}
89 changes: 29 additions & 60 deletions src/dria.ts → src/clients/dria.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import Axios from "axios";
import type { AxiosInstance } from "axios";
import { encodeBatchTexts, encodeBatchVectors } from "./proto";
import { SearchOptions, QueryOptions, BatchVectors, BatchTexts, MetadataType } from "./schemas";
import { CategoryTypes, DriaParams, ModelTypes } from "./types";
import constants from "./constants";
import { encodeBatchTexts, encodeBatchVectors } from "../proto";
import { SearchOptions, QueryOptions, BatchVectors, BatchTexts, MetadataType } from "../schemas";
import { CategoryTypes, DriaParams, ModelTypes } from "../types";
import constants from "../constants";
import { DriaCommon } from "./common";

/**
* Dria JS Client
* ## Dria Client
*
* @param params optional API key and contract ID.
*
Expand Down Expand Up @@ -40,8 +40,7 @@ import constants from "./constants";
* dria.contractId = contractId;
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export class Dria<T extends MetadataType = any> {
protected client: AxiosInstance;
export class Dria<T extends MetadataType = any> extends DriaCommon {
contractId: string | undefined;
/** Cached contract models. */
private models: Record<string, ModelTypes> = {};
Expand All @@ -50,18 +49,21 @@ export class Dria<T extends MetadataType = any> {
const apiKey = params.apiKey ?? process.env.DRIA_API_KEY;
if (!apiKey) throw new Error("Missing Dria API key.");

super(
Axios.create({
headers: {
"x-api-key": apiKey,
"Content-Type": "application/json",
"Accept-Encoding": "gzip, deflate, br",
Connection: "keep-alive",
Accept: "*/*",
},
// lets us handle the errors
validateStatus: () => true,
}),
);

this.contractId = params.contractId;
this.client = Axios.create({
headers: {
"x-api-key": apiKey,
"Content-Type": "application/json",
"Accept-Encoding": "gzip, deflate, br",
Connection: "keep-alive",
Accept: "*/*",
},
// lets us handle the errors
validateStatus: () => true,
});
}

/** A text-based search.
Expand All @@ -79,7 +81,7 @@ export class Dria<T extends MetadataType = any> {
async search(text: string, options: SearchOptions = {}) {
options = SearchOptions.parse(options);
const contractId = this.getContractId();
return await this.post<{ id: number; metadata: string; score: number }[]>(constants.DRIA_SEARCH_URL + "/search", {
return await this.post<{ id: number; metadata: string; score: number }[]>(constants.DRIA.SEARCH_URL + "/search", {
query: text,
top_n: options.topK,
level: options.level,
Expand All @@ -103,7 +105,7 @@ export class Dria<T extends MetadataType = any> {
async query<M extends MetadataType = T>(vector: number[], options: QueryOptions = {}) {
options = QueryOptions.parse(options);
const data = await this.post<{ id: number; metadata: string; score: number }[]>(
constants.DRIA_SEARCH_URL + "/query",
constants.DRIA.SEARCH_URL + "/query",
{ vector, contract_id: this.getContractId(), top_n: options.topK },
);
return data.map((d) => ({ ...d, metadata: JSON.parse(d.metadata) as M }));
Expand All @@ -119,7 +121,7 @@ export class Dria<T extends MetadataType = any> {
*/
async fetch<M extends MetadataType = T>(ids: number[]) {
if (ids.length === 0) throw "No IDs provided.";
const data = await this.post<{ metadata: string[]; vectors: number[][] }>(constants.DRIA_SEARCH_URL + "/fetch", {
const data = await this.post<{ metadata: string[]; vectors: number[][] }>(constants.DRIA.SEARCH_URL + "/fetch", {
id: ids,
contract_id: this.getContractId(),
});
Expand All @@ -145,7 +147,7 @@ export class Dria<T extends MetadataType = any> {
items = BatchVectors.parse(items) as BatchVectors<M>;
const encodedData = encodeBatchVectors(items);
const contractId = this.getContractId();
const data = await this.post<string>(constants.DRIA_INSERT_URL + "/insert_vector", {
const data = await this.post<string>(constants.DRIA.INSERT_URL + "/insert_vector", {
data: encodedData,
batch_size: items.length,
model: await this.getModel(contractId),
Expand All @@ -170,7 +172,7 @@ export class Dria<T extends MetadataType = any> {
items = BatchTexts.parse(items) as BatchTexts<M>;
const encodedData = encodeBatchTexts(items);
const contractId = this.getContractId();
const data = await this.post<string>(constants.DRIA_INSERT_URL + "/insert_text", {
const data = await this.post<string>(constants.DRIA.INSERT_URL + "/insert_text", {
data: encodedData,
batch_size: items.length,
model: await this.getModel(contractId),
Expand All @@ -196,7 +198,7 @@ export class Dria<T extends MetadataType = any> {
* // you can now make queries, or insert data there
*/
async create(name: string, embedding: ModelTypes, category: CategoryTypes, description: string = "") {
const data = await this.post<{ contract_id: string }>(constants.DRIA_API_URL + "/v1/knowledge/index/create", {
const data = await this.post<{ contract_id: string }>(constants.DRIA.API_URL + "/v1/knowledge/index/create", {
name,
embedding,
category,
Expand All @@ -214,7 +216,7 @@ export class Dria<T extends MetadataType = any> {
*/
async delete(contractId: string) {
// expect message to be `true`
const data = await this.post<{ message: boolean }>(constants.DRIA_API_URL + "/v1/knowledge/remove", {
const data = await this.post<{ message: boolean }>(constants.DRIA.API_URL + "/v1/knowledge/remove", {
contract_id: contractId,
});
return data.message;
Expand All @@ -231,7 +233,7 @@ export class Dria<T extends MetadataType = any> {
if (contractId in this.models) {
return this.models[contractId];
} else {
const data = await this.get<{ model: string }>(constants.DRIA_API_URL + "/v1/knowledge/index/get_model", {
const data = await this.get<{ model: string }>(constants.DRIA.API_URL + "/v1/knowledge/index/get_model", {
contract_id: contractId,
});
// memoize the model for later
Expand All @@ -247,37 +249,4 @@ export class Dria<T extends MetadataType = any> {
if (this.contractId) return this.contractId;
throw Error("ContractID was not set.");
}

/**
* A POST request wrapper.
* @param url request URL
* @param body request body
* @template T type of response body
* @returns parsed response body
*/
private async post<T = unknown>(url: string, body: unknown) {
const res = await this.client.post<{ success: boolean; data: T; code: number }>(url, body);
if (res.status !== 200) {
console.log({ url, body });
// console.log(res);
throw `Dria API (POST) failed with ${res.statusText} (${res.status}).\n${res.data}`;
}
return res.data.data;
}

/**
* A GET request wrapper.
* @param url request URL
* @param params query parameters
* @template T type of response body
* @returns parsed response body
*/
private async get<T = unknown>(url: string, params: Record<string, unknown> = {}) {
const res = await this.client.get<{ success: boolean; data: T; code: number }>(url, { params });
if (res.status !== 200) {
console.log(res.request);
throw `Dria API (GET) failed with ${res.statusText} (${res.status}).\n${res.data}`;
}
return res.data.data;
}
}
2 changes: 2 additions & 0 deletions src/clients/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export * from "./dria";
export * from "./local";
108 changes: 108 additions & 0 deletions src/clients/local.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import Axios from "axios";
import { QueryOptions, BatchVectors, MetadataType } from "../schemas";
import { DriaCommon } from "./common";

/**
* ## Dria Local Client
*
* Dria local client is a convenience tool that allows one to use the served knowledge via [Dria Docker](https://github.com/firstbatchxyz/dria-docker).
* The URL defaults to `http://localhost:8080`, but you can override it.
*
* Unlike the other Dria client, Dria local does not require an API key or a contract ID, since the locally served knowledge serves a single contract.
* Furthermore, text-based input is not allowed as that requires an embedding model to be running on the side.
*
* @template T default type of metadata; a metadata in Dria is a single-level mapping, with string keys and values of type `string`, `number`
*
* @example
* // connects to http://localhost:8080
* const dria = new DriaLocal();
*
* @example
* const dria = new DriaLocal("your-url");
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export class DriaLocal<T extends MetadataType = any> extends DriaCommon {
public url: string;
constructor(url: string = "http://localhost:8080") {
super(
Axios.create({
baseURL: url,
headers: {
"Content-Type": "application/json",
Connection: "keep-alive",
Accept: "*/*",
},
// lets us handle the errors
validateStatus: () => true,
}),
);
this.url = url;
}

/** A simple health-check. */
async health() {
try {
await this.get("/health");
return true;
} catch {
return false;
}
}

/** A vector-based query.
* @param vector query vector.
* @param options query options:
* - `topK`: number of results to return.
* @template M type of the metadata, defaults to type provided to the client.
* @returns an array of `topK` results with id, metadata and the relevancy score.
* @example
* const res = await dria.query<{about: string}>([0.1, 0.92, ..., 0.16]);
* console.log(res[0].metadata.about);
*
* @deprecated local query is disabled right now
*/
private async query<M extends MetadataType = T>(vector: number[], options: QueryOptions = {}) {
options = QueryOptions.parse(options);
const data = await this.post<{ id: number; metadata: M; score: number }[]>("/query", {
vector,
top_n: options.topK,
});
return data;
}

/** Fetch vectors with the given IDs.
* @param ids an array of ids.
* @template M type of the metadata, defaults to type provided to the client.
* @returns an array of metadatas belonging to the given vector IDs.
* @example
* const res = await dria.fetch([0])
* console.log(res[0])
*/
async fetch<M extends MetadataType = T>(ids: number[]) {
if (ids.length === 0) throw "No IDs provided.";
const data = await this.post<M[]>("/fetch", {
id: ids,
});
return data;
}

/**
* Insert a batch of vectors to your existing knowledge.
* @param items batch of vectors with optional metadatas
* @returns a string indicating success
* @example
* const batch = [
* {vector: [...], metadata: {}},
* {vector: [...], metadata: {foo: 'bar'}},
* // ...
* ]
* await dria.insertVectors(batch);
*/
async insertVectors<M extends MetadataType = T>(items: BatchVectors<M>) {
items = BatchVectors.parse(items) as BatchVectors<M>;
const data = await this.post<string>("/insert_vector", {
data: items,
});
return data;
}
}
15 changes: 9 additions & 6 deletions src/constants/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
export default {
/** URL to make fetch / query / search requests */
DRIA_SEARCH_URL: "https://search.dria.co/hnsw",
/** URL to insert texts and vectors */
DRIA_INSERT_URL: "https://search.dria.co/hnswt",
/** URL to get model */
DRIA_API_URL: "https://api.dria.co",
// TODO: naming doesnt really make sense here...
DRIA: {
/** URL to make fetch / query / search requests */
SEARCH_URL: "https://search.dria.co/hnsw",
/** URL to insert texts and vectors */
INSERT_URL: "https://search.dria.co/hnswt",
/** URL to get model */
API_URL: "https://api.dria.co",
},
} as const;
2 changes: 1 addition & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
export { Dria } from "./dria";
export { Dria, DriaLocal } from "./clients";
export type { DriaParams } from "./types";
Loading

0 comments on commit d0adfc0

Please sign in to comment.