Skip to content

Commit

Permalink
Merge pull request #5 from xavidop/xavier/embedders
Browse files Browse the repository at this point in the history
feat: add embedders models
  • Loading branch information
xavidop authored Sep 12, 2024
2 parents d12787f + 769820d commit 685b895
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 6 deletions.
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,10 @@ For more detailed examples and the explanation of other functionalities, refer t

## Supported models

This plugins supports all currently **Chat/Completition** available models from Github Models.
This plugin supports all currently available **Chat/Completion** and **Embeddings** models from GitHub Models.

Still in progress:
1. Embedding models
2. Support for image input/output models
1. Support for image input/output models

## API Reference

Expand Down
133 changes: 133 additions & 0 deletions src/github_embedders.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
/**
* Copyright 2024 The Fire Company
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* eslint-disable @typescript-eslint/no-explicit-any */

import { defineEmbedder, embedderRef } from "@genkit-ai/ai/embedder";
import ModelClient, {
GetEmbeddings200Response,
GetEmbeddingsParameters,
} from "@azure-rest/ai-inference";
import { z } from "zod";
import { type PluginOptions } from "./index.js";
import { AzureKeyCredential } from "@azure/core-auth";

/**
 * Zod schema for the optional, per-request embedder configuration.
 *
 * - `dimensions`: optional number, forwarded by `githubEmbedder` as the
 *   request's `dimensions` field.
 * - `encodingFormat`: optional, either `"float"` or `"base64"`, forwarded by
 *   `githubEmbedder` as the request's `encoding_format` field.
 */
export const TextEmbeddingConfigSchema = z.object({
  dimensions: z.number().optional(),
  encodingFormat: z
    .union([z.literal("float"), z.literal("base64")])
    .optional(),
});

/** Static TypeScript type inferred from `TextEmbeddingConfigSchema`. */
export type TextEmbeddingGeckoConfig = z.infer<typeof TextEmbeddingConfigSchema>;

/** Zod schema for a single embedding input: a plain string. */
export const TextEmbeddingInputSchema = z.string();

/**
 * Embedder reference registered as `github/text-embedding-3-small`.
 *
 * Declares 1536 output dimensions and text-only input, and accepts the
 * options described by `TextEmbeddingConfigSchema`.
 */
export const openAITextEmbedding3Small = embedderRef({
  name: "github/text-embedding-3-small",
  configSchema: TextEmbeddingConfigSchema,
  info: {
    dimensions: 1536,
    label: "OpenAI - Text-embedding-3-small",
    supports: { input: ["text"] },
  },
});

/**
 * Embedder reference registered as `github/text-embedding-3-large`.
 *
 * Declares 3072 output dimensions and text-only input, and accepts the
 * options described by `TextEmbeddingConfigSchema`.
 */
export const openAITextEmbedding3Large = embedderRef({
  name: "github/text-embedding-3-large",
  configSchema: TextEmbeddingConfigSchema,
  info: {
    dimensions: 3072,
    label: "OpenAI - Text-embedding-3-large",
    supports: { input: ["text"] },
  },
});

/**
 * Embedder reference registered as `github/cohere-embed-v3-english`.
 *
 * Declares 1024 output dimensions and text-only input, and accepts the
 * options described by `TextEmbeddingConfigSchema`.
 */
export const cohereEmbedv3English = embedderRef({
  name: "github/cohere-embed-v3-english",
  configSchema: TextEmbeddingConfigSchema,
  info: {
    dimensions: 1024,
    label: "Cohere - Embed-embed-v3-english",
    supports: { input: ["text"] },
  },
});

/**
 * Embedder reference registered as `github/cohere-embed-v3-multilingual`.
 *
 * Declares 1024 output dimensions and text-only input, and accepts the
 * options described by `TextEmbeddingConfigSchema`.
 */
export const cohereEmbedv3Multilingual = embedderRef({
  name: "github/cohere-embed-v3-multilingual",
  configSchema: TextEmbeddingConfigSchema,
  info: {
    dimensions: 1024,
    label: "Cohere - Embed-embed-v3-multilingual",
    supports: { input: ["text"] },
  },
});

/**
 * Lookup table from the bare model id (the value sent as `model` in the
 * embeddings request) to the embedder reference declared in this file.
 *
 * `githubEmbedder` indexes this map by its `name` argument.
 */
export const SUPPORTED_EMBEDDING_MODELS: Record<string, any> = {
  // OpenAI
  "text-embedding-3-small": openAITextEmbedding3Small,
  "text-embedding-3-large": openAITextEmbedding3Large,
  // Cohere
  "cohere-embed-v3-english": cohereEmbedv3English,
  "cohere-embed-v3-multilingual": cohereEmbedv3Multilingual,
};

/**
 * Defines an embedder that sends its input documents to the `/embeddings`
 * route of the configured inference endpoint.
 *
 * @param name - Key into `SUPPORTED_EMBEDDING_MODELS` (for example
 *   `"text-embedding-3-small"`); also sent as the request's `model` field.
 * @param options - Optional plugin options. `githubToken` falls back to the
 *   `GITHUB_TOKEN` environment variable; `endpoint` falls back to
 *   `GITHUB_ENDPOINT` and then to `https://models.inference.ai.azure.com`.
 * @returns The value returned by `defineEmbedder` for this model.
 * @throws Error when no token is available, when `name` is not a key of
 *   `SUPPORTED_EMBEDDING_MODELS`, or (at call time) when the service answers
 *   with a status other than `"200"`.
 */
export function githubEmbedder(name: string, options?: PluginOptions) {
  // `||` is kept (rather than `??`) so an empty string also falls back to the
  // environment variable, exactly as before.
  const token = options?.githubToken || process.env.GITHUB_TOKEN;
  if (!token) {
    throw new Error(
      "Please pass in the TOKEN key or set the GITHUB_TOKEN environment variable",
    );
  }
  const endpoint =
    options?.endpoint ||
    process.env.GITHUB_ENDPOINT ||
    "https://models.inference.ai.azure.com";

  // Fail with an actionable message instead of a TypeError on `model.info`
  // when the requested model is not registered. The own-property check keeps
  // inherited keys such as "constructor" from being treated as models.
  if (!Object.prototype.hasOwnProperty.call(SUPPORTED_EMBEDDING_MODELS, name)) {
    throw new Error(
      `Unsupported embedding model "${name}". Supported models: ${Object.keys(
        SUPPORTED_EMBEDDING_MODELS,
      ).join(", ")}`,
    );
  }
  const model = SUPPORTED_EMBEDDING_MODELS[name];

  const client = ModelClient(endpoint, new AzureKeyCredential(token));

  return defineEmbedder(
    {
      info: model.info!,
      configSchema: TextEmbeddingConfigSchema,
      name: model.name,
    },
    async (input, config) => {
      const request = {
        body: {
          model: name,
          input: input.map((doc) => doc.text()),
          dimensions: config?.dimensions,
          encoding_format: config?.encodingFormat,
        },
      } as GetEmbeddingsParameters;

      const response = await client.path("/embeddings").post(request);

      // A non-200 reply carries an error payload instead of `data`; surface it
      // rather than crashing later on `body.data.map`.
      if (response.status !== "200") {
        throw new Error(
          `Embeddings request for model "${name}" failed with status ${
            response.status
          }: ${JSON.stringify(response.body)}`,
        );
      }

      const { data } = (response as GetEmbeddings200Response).body;
      return {
        embeddings: data.map((item) => ({
          embedding: item.embedding,
        })),
      };
    },
  );
}
6 changes: 4 additions & 2 deletions src/github_llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/
/* eslint-disable @typescript-eslint/no-explicit-any */


import { Message } from "@genkit-ai/ai";
import {
CandidateData,
Expand Down Expand Up @@ -653,7 +652,10 @@ export function toGithubRequestBody(
} as any;

for (const key in body.body) {
if (!body.body[key] || (Array.isArray(body.body[key]) && !body.body[key].length))
if (
!body.body[key] ||
(Array.isArray(body.body[key]) && !body.body[key].length)
)
delete body.body[key];
}
return body;
Expand Down
19 changes: 18 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ import {
microsoftPhi35Mini128kInstruct,
SUPPORTED_GITHUB_MODELS,
} from "./github_llms.js";
import {
cohereEmbedv3English,
cohereEmbedv3Multilingual,
githubEmbedder,
openAITextEmbedding3Large,
openAITextEmbedding3Small,
SUPPORTED_EMBEDDING_MODELS,
} from "./github_embedders.js";

export {
openAIGpt4o,
Expand All @@ -51,6 +59,13 @@ export {
microsoftPhi35Mini128kInstruct,
};

export {
openAITextEmbedding3Small,
openAITextEmbedding3Large,
cohereEmbedv3English,
cohereEmbedv3Multilingual,
};

export interface PluginOptions {
githubToken?: string;
endpoint?: string;
Expand Down Expand Up @@ -78,7 +93,9 @@ export const github: Plugin<[PluginOptions]> = genkitPlugin(
githubModel(name, client),
),
],
embedders: [],
embedders: Object.keys(SUPPORTED_EMBEDDING_MODELS).map((name) =>
githubEmbedder(name, options),
),
};
},
);
Expand Down

0 comments on commit 685b895

Please sign in to comment.