Merge pull request #63 from bluepi0j/AddChatStreamSupport
Add support for chat stream
adamrushy authored Jul 11, 2023
2 parents 2cac42a + 97111b1 commit dd98cdf
Showing 4 changed files with 176 additions and 4 deletions.
9 changes: 7 additions & 2 deletions Sources/OpenAISwift/Models/ChatMessage.swift
@@ -20,9 +20,9 @@ public enum ChatRole: String, Codable {
 /// A structure that represents a single message in a chat conversation.
 public struct ChatMessage: Codable {
     /// The role of the sender of the message.
-    public let role: ChatRole
+    public let role: ChatRole?
     /// The content of the message.
-    public let content: String
+    public let content: String?
 
     /// Creates a new chat message with a given role and content.
     /// - Parameters:
@@ -68,6 +68,10 @@ public struct ChatConversation: Encodable {
 
     /// Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID in the OpenAI Tokenizer—not English words) to an associated bias value from -100 to 100. Values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
     let logitBias: [Int: Double]?
 
+    /// If you're generating long completions, waiting for the response can take many seconds. To get responses sooner, you can 'stream' the completion as it's being generated. This allows you to start printing or processing the beginning of the completion before the full completion is finished.
+    /// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb
+    let stream: Bool?
+
     enum CodingKeys: String, CodingKey {
         case user
@@ -81,6 +85,7 @@
         case presencePenalty = "presence_penalty"
         case frequencyPenalty = "frequency_penalty"
         case logitBias = "logit_bias"
+        case stream
     }
 }

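Why both fields become optional: stream deltas are decoded into this same `ChatMessage` type, and a chunk may carry only a `role` (typically the first chunk) or only a `content` fragment. A minimal decoding sketch against the fields above, assuming `ChatRole` has the usual `assistant` case (its cases sit outside this hunk):

import Foundation

// A first stream delta typically carries a role but no content.
let firstDelta = #"{"role":"assistant"}"#.data(using: .utf8)!
let delta = try! JSONDecoder().decode(ChatMessage.self, from: firstDelta)
print(delta.role == .assistant)  // true
print(delta.content ?? "<nil>")  // "<nil>": later chunks carry the content fragments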
4 changes: 4 additions & 0 deletions Sources/OpenAISwift/Models/OpenAI.swift
@@ -23,6 +23,10 @@ public struct MessageResult: Payload {
     public let message: ChatMessage
 }
 
+public struct StreamMessageResult: Payload {
+    public let delta: ChatMessage
+}
+
 public struct UsageResult: Codable {
     public let promptTokens: Int
     public let completionTokens: Int?
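Each streamed chunk wraps one of these partial messages in `delta`; a caller reconstructs the full reply by concatenating the fragments. A small accumulation sketch over the payload type above (the `handle` helper is hypothetical, not part of this PR):

var reply = ""

// Hypothetical helper, called once per decoded chunk.
func handle(_ chunk: StreamMessageResult) {
    // A delta may hold only a role, only a content fragment, or neither.
    if let fragment = chunk.delta.content {
        reply += fragment
    }
}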
102 changes: 100 additions & 2 deletions Sources/OpenAISwift/OpenAISwift.swift
@@ -12,7 +12,8 @@ public enum OpenAIError: Error {
 
 public class OpenAISwift {
     fileprivate let config: Config
-
+    fileprivate let handler = ServerSentEventsHandler()
+
     /// Configuration object for the client
     public struct Config {
 
@@ -159,7 +160,8 @@ extension OpenAISwift {
                                     maxTokens: maxTokens,
                                     presencePenalty: presencePenalty,
                                     frequencyPenalty: frequencyPenalty,
-                                    logitBias: logitBias)
+                                    logitBias: logitBias,
+                                    stream: false)
 
         let request = prepareRequest(endpoint, body: body)
 
@@ -211,6 +213,53 @@
             }
         }
     }
+
+    /// Send a Chat request to the OpenAI API with stream enabled
+    /// - Parameters:
+    ///   - messages: Array of `ChatMessage`s
+    ///   - model: The model to use; the only supported model is `gpt-3.5-turbo`
+    ///   - user: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+    ///   - temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or topProbabilityMass but not both.
+    ///   - topProbabilityMass: The OpenAI API equivalent of the "top_p" parameter. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.
+    ///   - choices: How many chat completion choices to generate for each input message.
+    ///   - stop: Up to 4 sequences where the API will stop generating further tokens.
+    ///   - maxTokens: The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
+    ///   - presencePenalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    ///   - frequencyPenalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+    ///   - logitBias: Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID in the OpenAI Tokenizer—not English words) to an associated bias value from -100 to 100. Values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
+    ///   - onEventReceived: Called multiple times; each call delivers one decoded stream chunk
+    ///   - onComplete: Called when the server finishes sending the message
+    public func sendStreamingChat(with messages: [ChatMessage],
+                                  model: OpenAIModelType = .chat(.chatgpt),
+                                  user: String? = nil,
+                                  temperature: Double? = 1,
+                                  topProbabilityMass: Double? = 0,
+                                  choices: Int? = 1,
+                                  stop: [String]? = nil,
+                                  maxTokens: Int? = nil,
+                                  presencePenalty: Double? = 0,
+                                  frequencyPenalty: Double? = 0,
+                                  logitBias: [Int: Double]? = nil,
+                                  onEventReceived: ((Result<OpenAI<StreamMessageResult>, OpenAIError>) -> Void)? = nil,
+                                  onComplete: (() -> Void)? = nil) {
+        let endpoint = Endpoint.chat
+        let body = ChatConversation(user: user,
+                                    messages: messages,
+                                    model: model.modelName,
+                                    temperature: temperature,
+                                    topProbabilityMass: topProbabilityMass,
+                                    choices: choices,
+                                    stop: stop,
+                                    maxTokens: maxTokens,
+                                    presencePenalty: presencePenalty,
+                                    frequencyPenalty: frequencyPenalty,
+                                    logitBias: logitBias,
+                                    stream: true)
+        let request = prepareRequest(endpoint, body: body)
+        handler.onEventReceived = onEventReceived
+        handler.onComplete = onComplete
+        handler.connect(with: request)
+    }
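A usage sketch for the callback variant above. It assumes a configured client (its initializer is not shown in this diff), the `ChatMessage(role:content:)` initializer and `.user` role documented in ChatMessage.swift, and that the `OpenAI<StreamMessageResult>` wrapper exposes an optional `choices` array (defined elsewhere in the library):

let openAI = OpenAISwift(/* configuration not shown in this diff */)
let messages = [ChatMessage(role: .user, content: "Tell me a story")]

openAI.sendStreamingChat(with: messages, onEventReceived: { result in
    switch result {
    case .success(let chunk):
        // Print each delta fragment as it arrives, without a trailing newline.
        print(chunk.choices?.first?.delta.content ?? "", terminator: "")
    case .failure(let error):
        print("stream error: \(error)")
    }
}, onComplete: {
    print("\n[stream finished]")
})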


     /// Send an Image generation request to the OpenAI API
@@ -352,6 +401,55 @@ extension OpenAISwift {
             }
         }
     }
+
+
+    /// Send a Chat request to the OpenAI API with stream enabled
+    /// - Parameters:
+    ///   - messages: Array of `ChatMessage`s
+    ///   - model: The model to use; the only supported model is `gpt-3.5-turbo`
+    ///   - user: A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
+    ///   - temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or topProbabilityMass but not both.
+    ///   - topProbabilityMass: The OpenAI API equivalent of the "top_p" parameter. An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.
+    ///   - choices: How many chat completion choices to generate for each input message.
+    ///   - stop: Up to 4 sequences where the API will stop generating further tokens.
+    ///   - maxTokens: The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will be (4096 - prompt tokens).
+    ///   - presencePenalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    ///   - frequencyPenalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
+    ///   - logitBias: Modify the likelihood of specified tokens appearing in the completion. Maps tokens (specified by their token ID in the OpenAI Tokenizer—not English words) to an associated bias value from -100 to 100. Values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
+    /// - Returns: An `AsyncStream` of `Result` values, each wrapping one decoded stream chunk
+    @available(swift 5.5)
+    @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *)
+    public func sendStreamingChat(with messages: [ChatMessage],
+                                  model: OpenAIModelType = .chat(.chatgpt),
+                                  user: String? = nil,
+                                  temperature: Double? = 1,
+                                  topProbabilityMass: Double? = 0,
+                                  choices: Int? = 1,
+                                  stop: [String]? = nil,
+                                  maxTokens: Int? = nil,
+                                  presencePenalty: Double? = 0,
+                                  frequencyPenalty: Double? = 0,
+                                  logitBias: [Int: Double]? = nil) -> AsyncStream<Result<OpenAI<StreamMessageResult>, OpenAIError>> {
+        return AsyncStream { continuation in
+            sendStreamingChat(
+                with: messages,
+                model: model,
+                user: user,
+                temperature: temperature,
+                topProbabilityMass: topProbabilityMass,
+                choices: choices,
+                stop: stop,
+                maxTokens: maxTokens,
+                presencePenalty: presencePenalty,
+                frequencyPenalty: frequencyPenalty,
+                logitBias: logitBias,
+                onEventReceived: { result in
+                    continuation.yield(result)
+                }) {
+                    continuation.finish()
+                }
+        }
+    }
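The `AsyncStream` variant consumes the same events with `for await`, reusing the `openAI` client and `messages` from the sketch above and under the same assumptions:

Task {
    for await result in openAI.sendStreamingChat(with: messages) {
        if case .success(let chunk) = result {
            print(chunk.choices?.first?.delta.content ?? "", terminator: "")
        }
    }
    // The loop ends when the underlying handler calls onComplete.
    print("\n[stream finished]")
}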

     /// Send an Embeddings request to the OpenAI API
     /// - Parameters:
65 changes: 65 additions & 0 deletions Sources/OpenAISwift/ServerSentEventsHandler.swift
@@ -0,0 +1,65 @@
+//
+//  ServerSentEventsHandler.swift
+//
+//
+//  Created by Vic on 2023-03-25.
+//
+
+import Foundation
+
+class ServerSentEventsHandler: NSObject {
+
+    var onEventReceived: ((Result<OpenAI<StreamMessageResult>, OpenAIError>) -> Void)?
+    var onComplete: (() -> Void)?
+
+    private lazy var session: URLSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil)
+    private var task: URLSessionDataTask?
+
+    func connect(with request: URLRequest) {
+        task = session.dataTask(with: request)
+        task?.resume()
+    }
+
+    func disconnect() {
+        task?.cancel()
+    }
+
+    func processEvent(_ eventData: Data) {
+        do {
+            let res = try JSONDecoder().decode(OpenAI<StreamMessageResult>.self, from: eventData)
+            onEventReceived?(.success(res))
+        } catch {
+            onEventReceived?(.failure(.decodingError(error: error)))
+        }
+    }
+}
+
+extension ServerSentEventsHandler: URLSessionDataDelegate {
+
+    /// Called several times; each callback may deliver one chunk of data or several chunks at once.
+    /// The JSON looks like this:
+    /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}`
+    /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}`
+    func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
+        if let eventString = String(data: data, encoding: .utf8) {
+            let lines = eventString.split(separator: "\n")
+            for line in lines {
+                if line.hasPrefix("data:") && line != "data: [DONE]" {
+                    if let eventData = String(line.dropFirst(5)).data(using: .utf8) {
+                        processEvent(eventData)
+                    } else {
+                        disconnect()
+                    }
+                }
+            }
+        }
+    }
+
+    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
+        if let error = error {
+            onEventReceived?(.failure(.genericError(error: error)))
+        } else {
+            onComplete?()
+        }
+    }
+}
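To see the wire format end to end, the decoding path above can be exercised on a single SSE `data:` line in isolation. This sketch mirrors the delegate's parsing, trimmed to the fields this PR defines, and assumes, as elsewhere, that `OpenAI` exposes an optional `choices` array (defined outside this diff). Note that the handler also assumes each URLSession callback delivers whole lines; frames split across callbacks would need buffering:

import Foundation

let line = #"data: {"choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}"#

if line.hasPrefix("data:") && line != "data: [DONE]" {
    let eventData = String(line.dropFirst(5)).data(using: .utf8)!
    // The same decode that processEvent performs on each frame.
    let event = try! JSONDecoder().decode(OpenAI<StreamMessageResult>.self, from: eventData)
    print(event.choices?.first?.delta.content ?? "")  // "Once"
}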
