Skip to content

Commit

Permalink
Merge pull request #81 from astrokin/add_proxy
Browse files Browse the repository at this point in the history
Add custom server proxy support
  • Loading branch information
adamrushy authored Jul 11, 2023
2 parents 456aace + 1822a85 commit 2cac42a
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 51 deletions.
78 changes: 46 additions & 32 deletions Sources/OpenAISwift/OpenAIEndpoint.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,58 @@

import Foundation

enum Endpoint {
case completions
case edits
case chat
case images
case embeddings
case moderations
}

extension Endpoint {
var path: String {
switch self {
case .completions:
return "/v1/completions"
case .edits:
return "/v1/edits"
case .chat:
return "/v1/chat/completions"
case .images:
return "/v1/images/generations"
case .embeddings:
return "/v1/embeddings"
case .moderations:
return "/v1/moderations"
}
public struct OpenAIEndpointProvider {
public enum API {
case completions
case edits
case chat
case images
case embeddings
case moderations
}

var method: String {
switch self {
case .completions, .edits, .chat, .images, .embeddings, .moderations:
return "POST"
public enum Source {
case openAI
case proxy(path: ((API) -> String), method: ((API) -> String))
}

public let source: Source

public init(source: OpenAIEndpointProvider.Source) {
self.source = source
}

func getPath(api: API) -> String {
switch source {
case .openAI:
switch api {
case .completions:
return "/v1/completions"
case .edits:
return "/v1/edits"
case .chat:
return "/v1/chat/completions"
case .images:
return "/v1/images/generations"
case .embeddings:
return "/v1/embeddings"
case .moderations:
return "/v1/moderations"
}
case let .proxy(path: pathClosure, method: _):
return pathClosure(api)
}
}

func baseURL() -> String {
switch self {
func getMethod(api: API) -> String {
switch source {
case .openAI:
switch api {
case .completions, .edits, .chat, .images, .embeddings, .moderations:
return "https://api.openai.com"
return "POST"
}
case let .proxy(path: _, method: methodClosure):
return methodClosure(api)
}
}
}
48 changes: 29 additions & 19 deletions Sources/OpenAISwift/OpenAISwift.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,36 @@ public enum OpenAIError: Error {
}

public class OpenAISwift {
fileprivate(set) var token: String?
fileprivate let config: Config

/// Configuration object for the client
public struct Config {

/// Initialiser
/// - Parameter session: the session to use for network requests.
public init(session: URLSession = URLSession.shared) {
public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) {
self.baseURL = baseURL
self.endpointProvider = endpointPrivider
self.authorizeRequest = authorizeRequest
self.session = session
}

let baseURL: String
let endpointProvider: OpenAIEndpointProvider
let session:URLSession
let authorizeRequest: (inout URLRequest) -> Void

public static func makeDefultOpenAI(api_key: String) -> Self {
.init(baseURL: "https://api.openai.com",
endpointPrivider: OpenAIEndpointProvider(source: .openAI),
session: .shared,
authorizeRequest: { request in
request.setValue("Bearer \(api_key)", forHTTPHeaderField: "Authorization")
})
}
}

public init(authToken: String, config: Config = Config()) {
self.token = authToken
self.config = Config()
public init(config: Config) {
self.config = config
}
}

Expand All @@ -40,7 +52,7 @@ extension OpenAISwift {
/// - maxTokens: The limit character for the returned response, defaults to 16 as per the API
/// - completionHandler: Returns an OpenAI Data Model
public func sendCompletion(with prompt: String, model: OpenAIModelType = .gpt3(.davinci), maxTokens: Int = 16, temperature: Double = 1, completionHandler: @escaping (Result<OpenAI<TextResult>, OpenAIError>) -> Void) {
let endpoint = Endpoint.completions
let endpoint = OpenAIEndpointProvider.API.completions
let body = Command(prompt: prompt, model: model.modelName, maxTokens: maxTokens, temperature: temperature)
let request = prepareRequest(endpoint, body: body)

Expand All @@ -66,7 +78,7 @@ extension OpenAISwift {
/// - input: The Input For Example "My nam is Adam"
/// - completionHandler: Returns an OpenAI Data Model
public func sendEdits(with instruction: String, model: OpenAIModelType = .feature(.davinci), input: String = "", completionHandler: @escaping (Result<OpenAI<TextResult>, OpenAIError>) -> Void) {
let endpoint = Endpoint.edits
let endpoint = OpenAIEndpointProvider.API.edits
let body = Instruction(instruction: instruction, model: model.modelName, input: input)
let request = prepareRequest(endpoint, body: body)

Expand All @@ -91,7 +103,7 @@ extension OpenAISwift {
/// - model: The Model to use
/// - completionHandler: Returns an OpenAI Data Model
public func sendModerations(with input: String, model: OpenAIModelType = .moderation(.latest), completionHandler: @escaping (Result<OpenAI<ModerationResult>, OpenAIError>) -> Void) {
let endpoint = Endpoint.moderations
let endpoint = OpenAIEndpointProvider.API.moderations
let body = Moderation(input: input, model: model.modelName)
let request = prepareRequest(endpoint, body: body)

Expand Down Expand Up @@ -136,7 +148,7 @@ extension OpenAISwift {
frequencyPenalty: Double? = 0,
logitBias: [Int: Double]? = nil,
completionHandler: @escaping (Result<OpenAI<MessageResult>, OpenAIError>) -> Void) {
let endpoint = Endpoint.chat
let endpoint = OpenAIEndpointProvider.API.chat
let body = ChatConversation(user: user,
messages: messages,
model: model.modelName,
Expand Down Expand Up @@ -180,7 +192,7 @@ extension OpenAISwift {
public func sendEmbeddings(with input: String,
model: OpenAIModelType = .embedding(.ada),
completionHandler: @escaping (Result<OpenAI<EmbeddingResult>, OpenAIError>) -> Void) {
let endpoint = Endpoint.embeddings
let endpoint = OpenAIEndpointProvider.API.embeddings
let body = EmbeddingsInput(input: input,
model: model.modelName)

Expand Down Expand Up @@ -209,7 +221,7 @@ extension OpenAISwift {
/// - user: An optional unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
/// - completionHandler: Returns an OpenAI Data Model
public func sendImages(with prompt: String, numImages: Int = 1, size: ImageSize = .size1024, user: String? = nil, completionHandler: @escaping (Result<OpenAI<UrlResult>, OpenAIError>) -> Void) {
let endpoint = Endpoint.images
let endpoint = OpenAIEndpointProvider.API.images
let body = ImageGeneration(prompt: prompt, n: numImages, size: size, user: user)
let request = prepareRequest(endpoint, body: body)

Expand Down Expand Up @@ -241,15 +253,13 @@ extension OpenAISwift {
task.resume()
}

private func prepareRequest<BodyType: Encodable>(_ endpoint: Endpoint, body: BodyType) -> URLRequest {
var urlComponents = URLComponents(url: URL(string: endpoint.baseURL())!, resolvingAgainstBaseURL: true)
urlComponents?.path = endpoint.path
private func prepareRequest<BodyType: Encodable>(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest {
var urlComponents = URLComponents(url: URL(string: config.baseURL)!, resolvingAgainstBaseURL: true)
urlComponents?.path = config.endpointProvider.getPath(api: endpoint)
var request = URLRequest(url: urlComponents!.url!)
request.httpMethod = endpoint.method
request.httpMethod = config.endpointProvider.getMethod(api: endpoint)

if let token = self.token {
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
}
config.authorizeRequest(&request)

request.setValue("application/json", forHTTPHeaderField: "content-type")

Expand Down

0 comments on commit 2cac42a

Please sign in to comment.