Skip to content

Commit

Permalink
Update enum as discussed with @raulraja
Browse files Browse the repository at this point in the history
  • Loading branch information
nomisRev committed Mar 26, 2024
1 parent 8f1e1f0 commit 08657e8
Show file tree
Hide file tree
Showing 104 changed files with 322 additions and 290 deletions.
4 changes: 2 additions & 2 deletions core/src/commonMain/kotlin/com/xebia/functional/xef/AI.kt
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ sealed interface AI {
input: String,
output: String,
context: String,
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel._4_1106_preview,
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
target: KType = typeOf<E>(),
config: Config = Config(),
api: Chat = OpenAI(config).chat,
Expand All @@ -107,7 +107,7 @@ sealed interface AI {
suspend inline operator fun <reified A : Any> invoke(
prompt: String,
target: KType = typeOf<A>(),
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel._4_1106_preview,
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_1106_preview,
config: Config = Config(),
api: Chat = OpenAI(config).chat,
conversation: Conversation = Conversation()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import com.xebia.functional.openai.generated.model.Embedding
suspend fun Embeddings.embedDocuments(
texts: List<String>,
chunkSize: Int = 400,
embeddingRequestModel: CreateEmbeddingRequestModel = CreateEmbeddingRequestModel.ada_002
embeddingRequestModel: CreateEmbeddingRequestModel =
CreateEmbeddingRequestModel.text_embedding_ada_002
): List<Embedding> =
if (texts.isEmpty()) emptyList()
else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ suspend fun Prompt.addMetrics(conversation: Conversation) {

conversation.metric.parameter(
"openai.chat_completion.prompt.messages_roles",
messages.map { it.completionRole().value }
messages.map { it.completionRole().name }
)
conversation.metric.parameter(
"openai.chat_completion.prompt.last-message",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ suspend inline fun <reified A> Chat.visionStructured(
prompt: String,
url: String,
conversation: Conversation = Conversation(),
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel._4_vision_preview
model: CreateChatCompletionRequestModel = CreateChatCompletionRequestModel.gpt_4_vision_preview
): A {
val response = vision(prompt, url, conversation).toList().joinToString("") { it }
return prompt(Prompt(model) { +user(response) }, conversation, serializer())
Expand All @@ -32,7 +32,7 @@ fun Chat.vision(
conversation: Conversation = Conversation()
): Flow<String> =
promptStreaming(
prompt = Prompt(CreateChatCompletionRequestModel._4_vision_preview) { +image(prompt, url) },
prompt = Prompt(CreateChatCompletionRequestModel.gpt_4_vision_preview) { +image(prompt, url) },
scope = conversation
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ import com.xebia.functional.openai.generated.model.CreateEmbeddingRequestModel
import com.xebia.functional.tokenizer.ModelType

fun CreateChatCompletionRequestModel.modelType(forFunctions: Boolean = false): ModelType {
val stringValue = value
val stringValue = name
val forFunctionsModel = ModelType.functionSpecific.find { forFunctions && it.name == stringValue }
return forFunctionsModel
?: (ModelType.all.find { it.name == stringValue } ?: ModelType.TODO(stringValue))
}

fun CreateEmbeddingRequestModel.modelType(): ModelType {
val stringValue = value
val stringValue = name
return ModelType.all.find { it.name == stringValue } ?: ModelType.TODO(stringValue)
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class LogsMetric(private val level: Level = Level.INFO) : Metric {
this.message = "${writeIndent(numberOfBlocks.get())}|-- RunId: ${runObject.id}"
}
logger.at(level) {
this.message = "${writeIndent(numberOfBlocks.get())}|-- Status: ${runObject.status.value}"
this.message = "${writeIndent(numberOfBlocks.get())}|-- Status: ${runObject.status.name}"
}
}

Expand Down Expand Up @@ -93,7 +93,7 @@ class LogsMetric(private val level: Level = Level.INFO) : Metric {
this.message = "${writeIndent(numberOfBlocks.get())}|-- RunId: ${output.runId}"
}
logger.at(level) {
this.message = "${writeIndent(numberOfBlocks.get())}|-- Status: ${output.status.value}"
this.message = "${writeIndent(numberOfBlocks.get())}|-- Status: ${output.status.name}"
}
return output
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ interface PromptBuilder {
): ChatCompletionRequestMessage {
val content = "${contentAsString()}\n${message.contentAsString()}"
return when (completionRole()) {
ChatCompletionRole.system -> system(content)
ChatCompletionRole.user -> user(content)
ChatCompletionRole.assistant -> assistant(content)
ChatCompletionRole.tool -> error("Tool role is not supported")
ChatCompletionRole.function -> error("Function role is not supported")
ChatCompletionRole.Supported.system -> system(content)
ChatCompletionRole.Supported.user -> user(content)
ChatCompletionRole.Supported.assistant -> assistant(content)
ChatCompletionRole.Supported.tool -> error("Tool role is not supported")
ChatCompletionRole.Supported.function -> error("Function role is not supported")
is ChatCompletionRole.Custom -> error("Custom roles are not supported")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ data class PromptConfiguration
@JvmOverloads
constructor(
var maxDeserializationAttempts: Int = 3,
var user: String = ChatCompletionRole.user.value,
var user: String = ChatCompletionRole.user.name,
var temperature: Double = 0.4,
var numberOfPredictions: Int = 1,
var docsInContext: Int = 5,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private constructor(
) : VectorStore {
constructor(
embeddings: Embeddings,
embeddingRequestModel: CreateEmbeddingRequestModel = CreateEmbeddingRequestModel.ada_002
embeddingRequestModel: CreateEmbeddingRequestModel =
CreateEmbeddingRequestModel.text_embedding_ada_002
) : this(embeddings, Atomic(State.empty()), embeddingRequestModel)

override val indexValue: AtomicInt = AtomicInt(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ sealed class MemorizedMessage {
val role: ChatCompletionRole
get() =
when (this) {
is Request -> ChatCompletionRole.valueOf(message.completionRole().value)
is Response -> ChatCompletionRole.valueOf(message.role.value)
is Request -> ChatCompletionRole.valueOf(message.completionRole().name)
is Response -> ChatCompletionRole.valueOf(message.role.name)
}

fun asRequestMessage(): ChatCompletionRequestMessage =
Expand All @@ -30,7 +30,7 @@ sealed class MemorizedMessage {

fun memorizedMessage(role: ChatCompletionRole, content: String): MemorizedMessage =
when (role) {
ChatCompletionRole.system ->
ChatCompletionRole.Supported.system ->
MemorizedMessage.Request(
ChatCompletionRequestMessage.Third(
ChatCompletionRequestSystemMessage(
Expand All @@ -39,7 +39,7 @@ fun memorizedMessage(role: ChatCompletionRole, content: String): MemorizedMessag
)
)
)
ChatCompletionRole.user ->
ChatCompletionRole.Supported.user ->
MemorizedMessage.Request(
ChatCompletionRequestMessage.Fifth(
ChatCompletionRequestUserMessage(
Expand All @@ -48,13 +48,14 @@ fun memorizedMessage(role: ChatCompletionRole, content: String): MemorizedMessag
)
)
)
ChatCompletionRole.assistant ->
ChatCompletionRole.Supported.assistant ->
MemorizedMessage.Response(
ChatCompletionResponseMessage(
content = content,
role = ChatCompletionResponseMessage.Role.assistant
)
)
ChatCompletionRole.tool -> error("Tool messages are not supported")
ChatCompletionRole.function -> error("Function messages are not supported")
ChatCompletionRole.Supported.tool -> error("Tool messages are not supported")
ChatCompletionRole.Supported.function -> error("Function messages are not supported")
is ChatCompletionRole.Custom -> error("Custom messages are not supported")
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._4
val model = CreateChatCompletionRequestModel.gpt_4

val scope =
Conversation(
Expand Down Expand Up @@ -69,7 +69,7 @@ class ConversationSpec :
val vectorStore = scope.store

val chatApi = TestChatApi(responses = messages)
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val totalTokens =
model.modelType().tokensFromMessages(messages.flatMap(::chatCompletionRequestMessages))
Expand Down Expand Up @@ -105,7 +105,7 @@ class ConversationSpec :
val vectorStore = scope.store

val chatApi = TestChatApi(responses = messages)
val model = CreateChatCompletionRequestModel._3_5_turbo_16k
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo_16k

val totalTokens =
model.modelType().tokensFromMessages(messages.flatMap(::chatCompletionRequestMessages))
Expand Down Expand Up @@ -139,7 +139,7 @@ class ConversationSpec :
)

val chatApi = TestChatApi(message)
val model = CreateChatCompletionRequestModel._3_5_turbo_16k
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo_16k

val response: Answer =
chatApi.prompt(
Expand Down Expand Up @@ -172,7 +172,7 @@ class ConversationSpec :
)

val chatApi = TestChatApi(message)
val model = CreateChatCompletionRequestModel._3_5_turbo_0613
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0613

val response: Answer =
chatApi.prompt(
Expand All @@ -192,7 +192,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val scope =
Conversation(
Expand Down Expand Up @@ -255,7 +255,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down Expand Up @@ -288,7 +288,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down Expand Up @@ -317,7 +317,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down Expand Up @@ -348,7 +348,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down Expand Up @@ -377,7 +377,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down Expand Up @@ -409,7 +409,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down Expand Up @@ -445,7 +445,7 @@ class ConversationSpec :
val conversationId = ConversationId(UUID.generateUUID().toString())

val chatApi = TestChatApi()
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo

val vectorStore = LocalVectorStore(TestEmbeddings())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import io.kotest.matchers.shouldBe

class PromptBuilderSpec :
StringSpec({
val model = CreateChatCompletionRequestModel._4
val model = CreateChatCompletionRequestModel.gpt_4
"buildPrompt should return the expected messages" {
val messages =
Prompt(model) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import io.kotest.matchers.shouldBe

class CombinedVectorStoreSpec :
StringSpec({
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo
"memories function should return all of messages combined in the right order" {
val memoryData = MemoryData()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import io.kotest.matchers.shouldBe

class LocalVectorStoreSpec :
StringSpec({
val model = CreateChatCompletionRequestModel._3_5_turbo
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo
"memories function should return all of messages in the right order when the limit is greater than the number of stored messages" {
val localVectorStore = LocalVectorStore(TestEmbeddings())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ data class SuiteSpec(
}
ItemResult(item.input, outputResults)
}
val suiteResults = SuiteResults(description, model.value, E::class.simpleName, items)
val suiteResults = SuiteResults(description, model.name, E::class.simpleName, items)
export(Json.encodeToString(suiteResults))
return suiteResults
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ private fun displayStepsStatus(step: AssistantThread.RunDelta.Step) {
val details = step.runStep.stepDetails
val type =
when (details) {
is RunStepObjectStepDetails.First -> details.value.type.value
is RunStepObjectStepDetails.Second -> details.value.type.value
is RunStepObjectStepDetails.First -> details.value.type.name
is RunStepObjectStepDetails.Second -> details.value.type.name
}
val calls =
when (details) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ suspend fun main() {
when (it) {
is AssistantThread.RunDelta.ReceivedMessage ->
println("received message: ${it.message.content.firstOrNull()}")
is AssistantThread.RunDelta.Run -> println("run: ${it.message.status.value}")
is AssistantThread.RunDelta.Step -> println("step: ${it.runStep.type.value}")
is AssistantThread.RunDelta.Run -> println("run: ${it.message.status.name}")
is AssistantThread.RunDelta.Step -> println("step: ${it.runStep.type.name}")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ suspend fun main() {
val configNoneFromConversation = PromptConfiguration {
messagePolicy { addMessagesFromConversation = MessagesFromHistory.NONE }
}
val model = CreateChatCompletionRequestModel._3_5_turbo_16k_0613
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo_16k_0613
val animal: Animal =
AI(
Prompt(model) { +user("A unique animal species.") }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ suspend fun main() {
val channel =
audio.createSpeech(
CreateSpeechRequest(
model = CreateSpeechRequestModel._1,
model = CreateSpeechRequestModel.tts_1,
input = modelResponse,
voice = CreateSpeechRequest.Voice.nova
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ data class Book(
)

suspend fun books(topic: String): Books {
val model = CreateChatCompletionRequestModel._3_5_turbo_16k_0613
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo_16k_0613

val myCustomPrompt =
Prompt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ suspend fun main() {
input = "Do I love Xef?",
output = "I have three opened PRs",
context = "The answer responds the question",
model = CreateChatCompletionRequestModel._3_5_turbo_0125
model = CreateChatCompletionRequestModel.gpt_3_5_turbo_0125
)
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ object TestExample {

@JvmStatic
fun main(args: Array<String>) = SuspendApp {
val model = CreateChatCompletionRequestModel._3_5_turbo_16k
val model = CreateChatCompletionRequestModel.gpt_3_5_turbo_16k
val chat = OpenAI().chat

val spec =
SuiteSpec(
description = "Check GTP3.5 and fake outputs",
model = CreateChatCompletionRequestModel._4_turbo_preview
model = CreateChatCompletionRequestModel.gpt_4_turbo_preview
) {
val gpt35Description = OutputDescription("Using GPT3.5")
val fakeOutputs = OutputDescription("Fake outputs with errors")
Expand Down
Loading

0 comments on commit 08657e8

Please sign in to comment.