Skip to content

Commit

Permalink
feat(llms): finish stream call
Browse files Browse the repository at this point in the history
  • Loading branch information
sd0xdev committed Aug 27, 2023
1 parent d0196fe commit 7fb8de9
Show file tree
Hide file tree
Showing 13 changed files with 187 additions and 80 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { Test, TestingModule } from '@nestjs/testing';
import { LLMAIController } from './llm-ai.controller';
import { AsgardLoggerSupplement } from '@asgard-hub/nest-winston';
import { AuthGuard } from '../../auth/auth.guard';
import { LangChainService } from '../../llm-ai/lang-chain/lang-chain.service';
import { LLMAIService } from '../../llm-ai/service/llmai/llmai.service';

describe('LlmAiController', () => {
let controller: LLMAIController;
Expand All @@ -11,12 +10,10 @@ describe('LlmAiController', () => {
const module: TestingModule = await Test.createTestingModule({
providers: [
{
provide: AsgardLoggerSupplement.LOGGER_HELPER_SERVICE,
useValue: {},
},
{
provide: LangChainService,
useValue: {},
provide: LLMAIService,
useValue: {
chat: jest.fn(),
},
},
],
controllers: [LLMAIController],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { Controller, Inject, UseGuards } from '@nestjs/common';
import { LangChainService } from '../../llm-ai/lang-chain/lang-chain.service';
import { AuthGuard } from '../../auth/auth.guard';
import { FunctionSelection } from '@asgard-hub/utils';
import { TranslateToEnglishPrompt } from '../../llm-ai/prompt/translate-to-english.prompt';
import { GrpcMethod } from '@nestjs/microservices';
import { AsgardLoggerSupplement, AsgardLogger } from '@asgard-hub/nest-winston';
import { GrpcMethod, GrpcStreamMethod } from '@nestjs/microservices';
import {
GrpcResponse,
LLMAIService,
} from '../../llm-ai/service/llmai/llmai.service';
import { Observable, Subject, switchMap } from 'rxjs';

export interface ChatOptions {
userInput: string;
Expand All @@ -14,43 +16,26 @@ export interface ChatOptions {
@UseGuards(AuthGuard)
@Controller('llm-ai')
export class LLMAIController {
@Inject(AsgardLoggerSupplement.LOGGER_HELPER_SERVICE)
private readonly asgardLogger: AsgardLogger;
@Inject()
private readonly langChainService: LangChainService;
private readonly llmAIService: LLMAIService;

@GrpcMethod('LLMAIService')
async chat(data: ChatOptions) {
const type = this.promptSelect(data.key);
let response = '';
if (!type) {
response = await this.langChainService.getGeneralChatResponse(
data.userInput
);
} else {
response = (
await this.langChainService.getFunctionChainResponse(
data.userInput,
type
)
).message;
}
chat(data: ChatOptions) {
return this.llmAIService.chat(data);
}

this.asgardLogger.debug(response);
@GrpcStreamMethod('LLMAIService', 'ChatStream')
chatStream(data: Observable<ChatOptions>) {
const subject = new Subject<GrpcResponse>();

return {
response,
};
}
data
.pipe(
switchMap((data) => {
return this.llmAIService.chat(data, subject);
})
)
.subscribe();

promptSelect(key: FunctionSelection) {
switch (key) {
case FunctionSelection.generalChat:
return undefined;
case FunctionSelection.anyTranslateToEnglish:
return TranslateToEnglishPrompt;
default:
return undefined;
}
return subject.asObservable();
}
}
7 changes: 4 additions & 3 deletions apps/yggdrasil-core-engine/src/app/llm-ai/llm-ai.module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import { IAppConfig, ConfigPath } from '../config/app.config';
import { isDev, isStaging, isProd } from '../constants/common.constant';
import { EventEmitterModule } from '@nestjs/event-emitter';
import { LLMAIController } from '../controllers/llm-ai/llm-ai.controller';
import { LangChainService } from './lang-chain/lang-chain.service';
import { LangChainService } from './service/langchain/langchain.service';
import { NestLangChainModule } from '@sd0x/nest-langchain';
import { OpenAIProvider } from '../provider/open-ai/open-ai';
import { ProviderModule } from '../provider/provider.module';
import { LLMAIService } from './service/llmai/llmai.service';

@Module({
imports: [
Expand Down Expand Up @@ -43,8 +44,8 @@ import { ProviderModule } from '../provider/provider.module';
MongoModule,
DataSourceAdapterModule,
],
providers: [LangChainService],
exports: [LangChainService],
providers: [LangChainService, LLMAIService],
exports: [LLMAIService],
controllers: [LLMAIController],
})
export class LLMAIModule {}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Test, TestingModule } from '@nestjs/testing';
import { LangChainService } from './lang-chain.service';
import { LangChainService } from './langchain.service';
import { NestLangchainOptionsSupplement } from '@sd0x/nest-langchain';
import { AsgardLoggerSupplement } from '@asgard-hub/nest-winston';

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@ export class LangChainService {
userInput: string,
langChainPromptType: { new (): BaseLangChainPrompt<T> }
) {
const result = await this.nestLangchainService.getFunctionChainResponse<T>({
return this.nestLangchainService.getFunctionChainResponse<T>({
input: {
userInput,
},
langChainPromptType,
});

this.asgardLogger.debug(result);

return result;
}

async getGeneralChatResponse(userInput: string) {
const message = await this.nestLangchainService.getGeneralChatResponse(
userInput
);

return message.content;
return {
metaOutput: {
...message,
},
response: message.content,
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { Test, TestingModule } from '@nestjs/testing';
import { LLMAIService } from './llmai.service';
import { AsgardLoggerSupplement } from '@asgard-hub/nest-winston';
import { LangChainService } from '../langchain/langchain.service';

describe('LlmaiService', () => {
  let service: LLMAIService;

  beforeEach(async () => {
    // Stub the logger and the LangChain layer so LLMAIService can be
    // constructed in isolation; chat() itself is not exercised here.
    const loggerStub = {
      provide: AsgardLoggerSupplement.LOGGER_HELPER_SERVICE,
      useValue: {},
    };
    const langChainStub = {
      provide: LangChainService,
      useValue: {},
    };

    const testModule: TestingModule = await Test.createTestingModule({
      providers: [loggerStub, langChainStub, LLMAIService],
    }).compile();

    service = testModule.get<LLMAIService>(LLMAIService);
  });

  it('should be defined', () => {
    expect(service).toBeDefined();
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { Inject, Injectable } from '@nestjs/common';
import { FunctionSelection } from '@asgard-hub/utils';
import { AsgardLogger, AsgardLoggerSupplement } from '@asgard-hub/nest-winston';
import { LangChainService } from '../langchain/langchain.service';
import { ChatOptions } from '../../../controllers/llm-ai/llm-ai.controller';
import { TranslateToEnglishPrompt } from '../../prompt/translate-to-english.prompt';
import { Subject, of, switchMap } from 'rxjs';
import { randomUUID } from 'crypto';

/**
 * Envelope for messages pushed to gRPC chat consumers.
 * Field-for-field it mirrors the `ChatCompletionResponse` message in
 * llmai.proto (id, timestamp, uuid, event, data).
 */
export interface GrpcResponse<T = unknown> {
  // Logical message kind within the exchange (e.g. 'prompt', 'response').
  id: string;
  // Creation time in epoch milliseconds (produced via Date.now()).
  timestamp: number;
  // Random UUID identifying this individual message (crypto.randomUUID()).
  uuid: string;
  // Stream event name (e.g. 'notification', 'message').
  event: string;
  // Payload: a plain string or a JSON-serialized object (JSON.stringify).
  data: string | T;
}

@Injectable()
export class LLMAIService {
@Inject(AsgardLoggerSupplement.LOGGER_HELPER_SERVICE)
private readonly asgardLogger: AsgardLogger;
@Inject()
private readonly langChainService: LangChainService;
chat(data: ChatOptions, subject?: Subject<GrpcResponse>) {
return of(this.promptSelect(data.key)).pipe(
switchMap((prompt) => {
subject?.next({
id: 'prompt',
timestamp: Date.now(),
uuid: randomUUID(),
event: 'notification',
data: prompt?.name ?? FunctionSelection.generalChat,
});
return of(prompt);
}),
switchMap(async (type) => {
if (!type) {
return await this.langChainService.getGeneralChatResponse(
data.userInput
);
} else {
return await this.langChainService.getFunctionChainResponse(
data.userInput,
type
);
}
}),
switchMap((response) => {
subject?.next({
id: 'response',
timestamp: Date.now(),
uuid: randomUUID(),
event: 'message',
data: JSON.stringify(response),
});
this.asgardLogger.debug(response);
return of(subject.complete());
})
);
}

private promptSelect(key: FunctionSelection) {
switch (key) {
case FunctionSelection.generalChat:
return undefined;
case FunctionSelection.anyTranslateToEnglish:
return TranslateToEnglishPrompt;
default:
return undefined;
}
}
}
10 changes: 7 additions & 3 deletions apps/yggdrasil-core-engine/src/assets/llmai/proto/llmai.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package llmai;

service LLMAIService {
rpc Chat (ChatOptions) returns (ChatCompletionResponse);
rpc ChatStream (stream ChatOptions) returns (stream ChatCompletionResponse);
}

message ChatOptions {
Expand All @@ -12,6 +13,9 @@ message ChatOptions {
}

message ChatCompletionResponse {
string response = 1;
}

string id = 1;
int32 timestamp = 2;
string uuid = 3;
string event = 4;
string data = 5;
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const coreEngineLLMAIConfig = registerAs(
);

function getProtoPath(fileName: string) {
return isDev
? resolve(__dirname, 'assets', fileName)
return process.env['IS_DEBUG']
? resolve(__dirname, '..', '..', 'assets', fileName)
: resolve(__dirname, '.', 'assets', fileName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,21 @@ export interface ChatOptions {
}

export interface ChatCompletionResponse {
response: string;
id: string;
timestamp: number;
event: string;
uuid: string;
data: string;
}

export interface LLMAIService {
chat(
options: ChatOptions,
metadata?: Metadata
): Observable<ChatCompletionResponse>;

chatStream(
options: Observable<ChatOptions>,
metadata?: Metadata
): Observable<ChatCompletionResponse>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
ThreadChannel,
} from 'discord.js';
import { List, Stack } from 'immutable';
import { lastValueFrom } from 'rxjs';
import { ReplaySubject, lastValueFrom, switchMap } from 'rxjs';
import {
DISCORD_BOT_MODULE_OPTIONS,
DiSCORD_SPLIT_MESSAGE_TARGET,
Expand All @@ -23,6 +23,7 @@ import { splitString } from '../../utils/split-string';
import { delay } from '../../utils/delay';
import {
ChatGPTService,
ChatOptions,
LLMAIService,
} from '../../interface/chatgpt.service.interface';
import { DiscordClientService } from '../discord-client/discord-client.service';
Expand Down Expand Up @@ -297,21 +298,24 @@ export class MessageService implements OnModuleInit {
try {
// show typing
await (message.channel as TextChannel).sendTyping();

const result = await lastValueFrom(
this.llmAIService.chat(
{
userInput: message.content,
},
this.metadata
const request$ = new ReplaySubject<ChatOptions>();
request$.next({
userInput: message.content,
});
request$.complete();
this.llmAIService
.chatStream(request$, this.metadata)
.pipe(
switchMap(async (result) => {
await this.sendMessageReply(result?.data, message);

this.asgardLogger.log(
`successfully send message: ${result?.data}`
);
this.asgardLogger.log(`successfully send message`);
})
)
);

const response = result.response;

await this.sendMessageReply(response, message);
this.asgardLogger.log(`successfully send message: ${response}`);
this.asgardLogger.log(`successfully send message`);
.subscribe();
} catch (e) {
this.asgardLogger.error(e);
await this.sendMessageReply(
Expand Down
Loading

0 comments on commit 7fb8de9

Please sign in to comment.