diff --git a/ios/OpenAIKit/Chat/ChatInput.swift b/ios/OpenAIKit/Chat/ChatInput.swift
new file mode 100644
index 0000000..cdeaca7
--- /dev/null
+++ b/ios/OpenAIKit/Chat/ChatInput.swift
@@ -0,0 +1,37 @@
+import Foundation
+
+struct JSMessage: Codable {
+    let role: String
+    let content: String
+}
+
+struct ChatInput: Codable, ModelID {
+    let model: String
+    var messages: [JSMessage]
+    let temperature: Double?
+    let topP: Double?
+    let n: Int?
+    let stops: [String]?
+    var maxTokens: Int?
+    var presencePenalty: Double?
+    var frequencyPenalty: Double?
+    var logitBias: [String : Int]?
+    var user: String?
+
+    var id: String {
+        model
+    }
+
+    var toMessages: [Chat.Message] {
+        messages.map { message in
+            switch message.role {
+            case "system":
+                return .system(content: message.content)
+            case "user":
+                return .user(content: message.content)
+            default:
+                return .assistant(content: message.content)
+            }
+        }
+    }
+}
diff --git a/ios/OpenAIKit/Chat/ChatProvider.swift b/ios/OpenAIKit/Chat/ChatProvider.swift
index 5a6a236..afab050 100644
--- a/ios/OpenAIKit/Chat/ChatProvider.swift
+++ b/ios/OpenAIKit/Chat/ChatProvider.swift
@@ -9,9 +9,9 @@ public struct ChatProvider {
     /**
      Create chat completion
      POST
-     
+
      https://api.openai.com/v1/chat/completions
-     
+
      Creates a chat completion for the provided prompt and parameters
      */
     public func create(
@@ -42,17 +42,17 @@ public struct ChatProvider {
             logitBias: logitBias,
             user: user
         )
-        
+
         return try await requestHandler.perform(request: request)
-        
+
     }
 
     /**
      Create chat completion
      POST
-     
+
      https://api.openai.com/v1/chat/completions
-     
+
      Creates a chat completion for the provided prompt and parameters
      stream
      If set, partial message deltas will be sent, like in ChatGPT.
@@ -88,8 +88,8 @@ public struct ChatProvider {
             logitBias: logitBias,
             user: user
         )
-        
+
         return try await requestHandler.stream(request: request)
-        
+
     }
 }
diff --git a/ios/OpenAIKit/Model/Model.swift b/ios/OpenAIKit/Model/Model.swift
index f7cc45c..97ac2b0 100644
--- a/ios/OpenAIKit/Model/Model.swift
+++ b/ios/OpenAIKit/Model/Model.swift
@@ -30,7 +30,7 @@ extension Model {
     }
 }
 
-public protocol ModelID {
+public protocol ModelID: Codable {
     var id: String { get }
 }
 
diff --git a/ios/ReactNativeOpenai.mm b/ios/ReactNativeOpenai.mm
index c3aca76..adf814d 100644
--- a/ios/ReactNativeOpenai.mm
+++ b/ios/ReactNativeOpenai.mm
@@ -4,7 +4,7 @@ @interface RCT_EXTERN_MODULE(ReactNativeOpenai, RCTEventEmitter)
 
 RCT_EXTERN_METHOD(supportedEvents)
 
-RCT_EXTERN_METHOD(stream:(NSString *)prompt)
+RCT_EXTERN_METHOD(stream:(NSDictionary *)input)
 
 RCT_EXTERN_METHOD(initialize:(NSString *)apiKey organization:(NSString *)organization)
 + (BOOL)requiresMainQueueSetup
diff --git a/ios/ReactNativeOpenai.swift b/ios/ReactNativeOpenai.swift
index 16be056..b605bd7 100644
--- a/ios/ReactNativeOpenai.swift
+++ b/ios/ReactNativeOpenai.swift
@@ -4,11 +4,6 @@ import React
 
 @objc(ReactNativeOpenai)
 final class ReactNativeOpenai: RCTEventEmitter {
-    struct Action {
-        var action: String
-        var payload: Any!
-    }
-
     let urlSession = URLSession(configuration: .default)
     var configuration: Configuration?
     lazy var openAIClient = Client(session: urlSession, configuration: configuration!)
@@ -19,8 +14,6 @@ final class ReactNativeOpenai: RCTEventEmitter {
 
     private static var queue: [Action] = []
 
-    private static let onMessageRecived = "onMessageReceived"
-
     @objc override init() {
         super.init()
         Self.emitter = self
@@ -31,29 +24,30 @@ final class ReactNativeOpenai: RCTEventEmitter {
         self.configuration = Configuration(apiKey: apiKey, organization: organization)
     }
 
-    @objc public override func constantsToExport() -> [AnyHashable : Any]! {
-        return ["ON_STORE_ACTION": Self.onMessageRecived]
-    }
-
     override public static func requiresMainQueueSetup() -> Bool {
         return true
     }
 
     @objc public override func supportedEvents() -> [String] {
-        [Self.onMessageRecived]
+        ["onChatMessageReceived"]
+    }
+
+    struct Action {
+        let type: String
+        let payload: String
     }
 
     private static func sendStoreAction(_ action: Action) {
         if let emitter = self.emitter {
-            emitter.sendEvent(withName: onMessageRecived, body: [
-                "type": action.action,
+            emitter.sendEvent(withName: "onChatMessageReceived", body: [
+                "type": action.type,
                 "payload": action.payload
             ])
         }
     }
 
-    @objc public static func dispatch(action: String, payload: Any!) {
-        let actionObj = Action(action: action, payload: payload)
+    @objc public static func dispatch(type: String, payload: String) {
+        let actionObj = Action(type: type, payload: payload)
         if isInitialized {
             self.sendStoreAction(actionObj)
         } else {
@@ -77,23 +71,30 @@ final class ReactNativeOpenai: RCTEventEmitter {
 @available(iOS 15.0, *)
 extension ReactNativeOpenai {
     @objc(stream:)
-    public func stream(prompt: String) {
+    public func stream(input: NSDictionary) {
         Task {
             do {
+                let decoded = try DictionaryDecoder().decode(ChatInput.self, from: input)
                 let completion = try await openAIClient.chats.stream(
-                    model: Model.GPT3.gpt3_5Turbo,
-                    messages: [.user(content: prompt)]
+                    model: decoded,
+                    messages: decoded.toMessages,
+                    temperature: decoded.temperature ?? 1,
+                    topP: decoded.topP ?? 1,
+                    n: decoded.n ?? 1,
+                    stops: decoded.stops ?? [],
+                    maxTokens: decoded.maxTokens,
+                    presencePenalty: decoded.presencePenalty ?? 0,
+                    frequencyPenalty: decoded.frequencyPenalty ?? 0,
+                    logitBias: decoded.logitBias ?? [:],
+                    user: decoded.user
                 )
                 for try await chat in completion {
-                    if let streamMessage = chat.choices.first?.delta.content {
-                        print("Stream message: \(streamMessage)")
-                        Self.dispatch(action: Self.onMessageRecived, payload: [
-                            "message": streamMessage
-                        ])
-                    }
+                    if let payload = String(data: try JSONEncoder().encode(chat), encoding: .utf8) {
+                        Self.dispatch(type: "onChatMessageReceived", payload: payload)
+                    }
                 }
             } catch {
-                print("j",error)
+                print("error", error)
             }
         }
     }
diff --git a/src/index.tsx b/src/index.tsx
index abcf8a6..39d542e 100644
--- a/src/index.tsx
+++ b/src/index.tsx
@@ -1,44 +1,105 @@
-import { NativeEventEmitter, NativeModules, Platform } from 'react-native';
-
-const LINKING_ERROR =
-  `The package 'react-native-openai' doesn't seem to be linked. Make sure: \n\n` +
-  Platform.select({ ios: "- You have run 'pod install'\n", default: '' }) +
-  '- You rebuilt the app after installing the package\n' +
-  '- You are not using Expo Go\n';
-
-const ReactNativeOpenAI = NativeModules.ReactNativeOpenai
-  ? NativeModules.ReactNativeOpenai
-  : new Proxy(
-      {},
-      {
-        get() {
-          throw new Error(LINKING_ERROR);
-        },
-      }
-    );
-
-export type EventTypes = 'onMessageReceived';
+import { NativeEventEmitter, NativeModules } from 'react-native';
 
 class OpenAI {
+  module = NativeModules.ReactNativeOpenai;
   private bridge: NativeEventEmitter;
+  public chat: Chat;
 
   public constructor(apiKey: string, organization: string) {
-    this.bridge = new NativeEventEmitter(ReactNativeOpenAI);
-    ReactNativeOpenAI.initialize(apiKey, organization);
+    this.bridge = new NativeEventEmitter(this.module);
+    this.module.initialize(apiKey, organization);
+    this.chat = new Chat(this.module, this.bridge);
+  }
+}
+
+namespace ChatModels {
+  type Model =
+    | 'gpt-4'
+    | 'gpt-4-0314'
+    | 'gpt-4-32k'
+    | 'gpt-4-32k-0314'
+    | 'gpt-3.5-turbo'
+    | 'gpt-3.5-turbo-16k'
+    | 'gpt-3.5-turbo-0301'
+    | 'text-davinci-003'
+    | 'text-davinci-002'
+    | 'text-curie-001'
+    | 'text-babbage-001'
+    | 'text-ada-001'
+    | 'text-embedding-ada-002'
+    | 'text-davinci-001'
+    | 'text-davinci-edit-001'
+    | 'davinci-instruct-beta'
+    | 'davinci'
+    | 'curie-instruct-beta'
+    | 'curie'
+    | 'ada'
+    | 'babbage'
+    | 'code-davinci-002'
+    | 'code-cushman-001'
+    | 'code-davinci-001'
+    | 'code-davinci-edit-001'
+    | 'whisper-1';
+
+  type Message = {
+    role: 'user' | 'system' | 'assistant';
+    content: string;
+  };
+
+  export type StreamInput = {
+    model: Model;
+    messages: Message[];
+    temperature?: number;
+    topP?: number;
+    n?: number;
+    stops?: string[];
+    maxTokens?: number;
+    presencePenalty?: number;
+    frequencyPenalty?: number;
+    logitBias?: { [key: string]: number };
+    user?: string;
+  };
+
+  export type StreamOutput = {
+    id: string;
+    object: string;
+    created: number;
+    model: Model;
+    choices: {
+      delta: {
+        content?: string;
+        role?: string;
+      };
+      index: number;
+      finishReason: 'length' | 'stop' | 'content_filter';
+    }[];
+  };
+}
+
+class Chat {
+  private bridge: NativeEventEmitter;
+  private module: any;
+
+  public constructor(module: any, bridge: NativeEventEmitter) {
+    this.module = module;
+    this.bridge = bridge;
   }
 
-  public createCompletion(prompt: string) {
-    return ReactNativeOpenAI.stream(prompt);
+  public stream(input: ChatModels.StreamInput) {
+    return this.module.stream(input);
   }
 
   public addListener(
-    event: EventTypes,
-    callback: (event: { payload: { message: string } }) => void
+    event: 'onChatMessageReceived',
+    callback: (event: ChatModels.StreamOutput) => void
   ) {
-    this.bridge.addListener(event, callback);
+    this.bridge.addListener(event, (value) => {
+      const payload = JSON.parse(value.payload);
+      callback(payload);
+    });
   }
 
-  public removeListener(event: EventTypes) {
+  public removeListener(event: 'onChatMessageReceived') {
     this.bridge.removeAllListeners(event);
   }
 }
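
Usage sketch (illustrative, not part of the patch): the TypeScript snippet below shows how the API introduced in src/index.tsx above is meant to be driven end to end — construct the client, subscribe to 'onChatMessageReceived', then call chat.stream with a StreamInput. It assumes the OpenAI class is exported as the package's default export, which this hunk does not show; the API key, organization, model, and prompt are placeholders.

// Assumed default export from 'react-native-openai'; the export statement is not visible in this diff.
import OpenAI from 'react-native-openai';

const openAI = new OpenAI('sk-...', 'org-...'); // placeholder credentials

// Each event is one stream chunk: the native side JSON-encodes the chat delta,
// and addListener parses value.payload before invoking the callback.
openAI.chat.addListener('onChatMessageReceived', (event) => {
  const token = event.choices[0]?.delta.content;
  if (token) {
    console.log(token);
  }
});

// Start a streaming completion; only model and messages are required,
// the remaining StreamInput fields fall back to defaults on the Swift side.
openAI.chat.stream({
  model: 'gpt-3.5-turbo',
  messages: [{ role: 'user', content: 'Write a haiku about streaming.' }],
  temperature: 0.7,
});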