Skip to content

Commit

Permalink
feat: add mock ai
Browse files Browse the repository at this point in the history
  • Loading branch information
kirklin committed Feb 25, 2024
1 parent 4693157 commit a11b3d0
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 12 deletions.
1 change: 0 additions & 1 deletion apps/admin/src/pages/chat/components/ChatLayout.vue
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ watchEffect(() => {
:width="360"
:native-scrollbar="true"
show-trigger="arrow-circle"
bordered
>
<ChatHistorySidebar />
</NLayoutSider>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ defineModel();
</script>

<template>
<NInput type="textarea" placeholder="请输入聊天内容。。。。" round show-count :resizable="false" class="h-full" :bordered="false" />
<NInput type="textarea" placeholder="请输入聊天内容。。。。" round show-count :resizable="false" class="h-full bg-[--action-color]" />
</template>

<style scoped>
Expand Down
6 changes: 2 additions & 4 deletions packages/ai/core/src/languageModelAgent/azureOpenai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ import {
} from "@azure/openai";
import { OpenAIStream, StreamingTextResponse } from "ai";

import { LanguageModelAgentRuntimeErrorType } from "../utils/error/constants";
import { AgentRuntimeError, DEBUG_CHAT_COMPLETION, LanguageModelAgentRuntimeErrorType, debugStream } from "../utils";
import type { OpenAIChatStreamPayload } from "../../types";
import { ModelBrandProvider } from "../types";
import { AgentRuntimeError } from "../utils/error/createError";
import { debugStream } from "../utils/debugStream";
import { DEBUG_CHAT_COMPLETION } from "../utils/env";

import { AbstractLanguageModel } from "../abstractAI";

export class CelerisAzureOpenAI extends AbstractLanguageModel<any> {
Expand Down
58 changes: 58 additions & 0 deletions packages/ai/core/src/languageModelAgent/celeris/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { StreamingTextResponse } from "ai";
import { AbstractLanguageModel } from "../abstractAI";
import { ModelBrandProvider } from "../types";
import { AgentRuntimeError, DEBUG_CHAT_COMPLETION, LanguageModelAgentRuntimeErrorType, debugStream } from "../utils";
import type { OpenAIChatMessage, OpenAIChatStreamPayload } from "../../types";

/**
 * Mock language model for local development and testing.
 *
 * Emits a canned `Mock completion for …` chunk per input message instead of
 * calling a real provider, while exposing the same `chat()` contract as the
 * other `AbstractLanguageModel` implementations visible in this commit
 * (OpenAI / Azure OpenAI): debug tee + `StreamingTextResponse`, and provider
 * errors wrapped via `AgentRuntimeError.chat`.
 */
export class CelerisMockLanguageModel extends AbstractLanguageModel<any> {
  constructor() {
    super(ModelBrandProvider.Celeris);
  }

  /**
   * Simulate a provider completion stream.
   *
   * @param messages - chat history; one mock chunk is enqueued per message.
   * @returns a ReadableStream of string chunks (empty `messages` yields an
   *          immediately-closed, empty stream).
   */
  async streamChatCompletions(messages: OpenAIChatMessage[]): Promise<ReadableStream<string>> {
    // 模拟将响应转换为流
    const stream = new ReadableStream<string>({
      start(controller) {
        // Plain loop instead of `.map()`: we only want the enqueue side
        // effect; the original built and discarded a result array.
        for (const message of messages) {
          // 模拟处理消息
          controller.enqueue(`Mock completion for ${message.content}`);
        }
        controller.close();
      },
    });

    return stream;
  }

  /**
   * Run a mock chat completion.
   *
   * @param payload - stream payload; only `messages` is consumed here.
   * @returns a StreamingTextResponse over the mock stream.
   * @throws AgentRuntimeError.chat-wrapped error: `MockAIBusinessError` when
   *         the caught error carries a truthy `code`, otherwise
   *         `AgentRuntimeError`.
   */
  async chat(payload: OpenAIChatStreamPayload) {
    // ============ 1. preprocess messages ============ //
    const { messages } = payload;

    // ============ 2. send api ============ //

    try {
      const response = await this.streamChatCompletions(messages);

      // Tee so the debug branch can be drained by debugStream without
      // consuming the branch handed to the client.
      const [debug, prod] = response.tee();

      if (DEBUG_CHAT_COMPLETION) {
        // Fire-and-forget: debug logging must not block or fail the response.
        debugStream(debug).catch(console.error);
      }

      // NOTE(review): chunks are plain strings; if StreamingTextResponse /
      // Response requires Uint8Array body chunks on this runtime, pipe
      // through a TextEncoderStream — confirm against the "ai" package.
      return new StreamingTextResponse(prod);
    } catch (e) {
      const error = e as { [key: string]: any; code: string; message: string };

      // Provider-style errors (carrying a `code`) map to the mock business
      // error; anything else is treated as an agent runtime failure.
      const errorType = error.code
        ? LanguageModelAgentRuntimeErrorType.MockAIBusinessError
        : LanguageModelAgentRuntimeErrorType.AgentRuntimeError;

      throw AgentRuntimeError.chat({
        error,
        errorType,
        provider: this._brandId,
      });
    }
  }
}
8 changes: 3 additions & 5 deletions packages/ai/core/src/languageModelAgent/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ import type { ClientOptions } from "openai";
import urlJoin from "url-join";
import { OpenAIStream, StreamingTextResponse } from "ai";
import { AbstractLanguageModel } from "../abstractAI";
import { AgentRuntimeError } from "../utils/error/createError";
import { LanguageModelAgentRuntimeErrorType } from "../utils/error/constants";
import { AgentRuntimeError, DEBUG_CHAT_COMPLETION, LanguageModelAgentRuntimeErrorType, debugStream, handleOpenAIError } from "../utils";

import type { OpenAIChatStreamPayload } from "../../types";
import { DEBUG_CHAT_COMPLETION } from "../utils/env";
import { debugStream } from "../utils/debugStream";
import { handleOpenAIError } from "../utils/error/handleOpenAIError";

import { desensitizeUrl } from "../utils/desensitizeUrl";
import { ModelBrandProvider } from "../types";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export const LanguageModelAgentRuntimeErrorType = {
AgentRuntimeError: "AgentRuntimeError", // Agent Runtime 模块运行时错误
LocationNotSupportedError: "LocationNotSupportedError", // 不支持的位置错误
OpenAIBusinessError: "OpenAIBusinessError", // OpenAI 业务错误
MockAIBusinessError: "MockAIBusinessError", // Mock 业务错误
NoOpenAIAPIKey: "NoOpenAIAPIKey", // 缺少 OpenAI API 密钥
InvalidAzureAPIKey: "InvalidAzureAPIKey", // 无效的 Azure API 密钥
AzureBusinessError: "AzureBusinessError", // Azure 业务错误
Expand Down
11 changes: 10 additions & 1 deletion packages/ai/services/openai/agent/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type {
AbstractLanguageModel,
OpenAIChatStreamPayload,
} from "@celeris/ai-core";
import { CelerisMockLanguageModel } from "@celeris/ai-core/src/languageModelAgent/celeris";
import { getServerConfig } from "../config/server";
import type { JWTPayload } from "../constants/auth";

Expand Down Expand Up @@ -37,7 +38,6 @@ class LanguageModelAgent {
switch (provider) {
// eslint-disable-next-line default-case-last
default:
case "oneapi":
case ModelBrandProvider.OpenAI: {
runtimeModel = this.initOpenAI(payload, azureOpenAI);
break;
Expand All @@ -47,6 +47,11 @@ class LanguageModelAgent {
runtimeModel = this.initAzureOpenAI(payload);
break;
}

case ModelBrandProvider.Celeris:{
runtimeModel = this.initMockAI();
break;
}
}

return new LanguageModelAgent(runtimeModel);
Expand All @@ -73,6 +78,10 @@ class LanguageModelAgent {
});
}

private static initMockAI() {
return new CelerisMockLanguageModel();
}

private static initAzureOpenAI(payload: JWTPayload) {
const { AZURE_API_KEY, AZURE_API_VERSION, AZURE_ENDPOINT } = getServerConfig();
const apiKey = payload?.apiKey || AZURE_API_KEY;
Expand Down

0 comments on commit a11b3d0

Please sign in to comment.