Skip to content

Commit

Permalink
Implement Ollama API
Browse files — browse the repository at this point in the history
  • Loading branch information
Taewan-P committed Oct 6, 2024
1 parent 51401ca commit 642b13e
Showing 1 changed file with 21 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ChatRepositoryImpl @Inject constructor(

// API clients. Each is (re)initialized inside the corresponding complete*Chat call
// from the platform settings fetched at request time, so `lateinit` is safe there.
private lateinit var openAI: OpenAI
private lateinit var google: GenerativeModel
// Ollama exposes an OpenAI-compatible endpoint, so it reuses the OpenAI client type.
private lateinit var ollama: OpenAI

override suspend fun completeOpenAIChat(question: Message, history: List<Message>): Flow<ApiState> {
val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OPENAI })
Expand All @@ -63,7 +64,7 @@ class ChatRepositoryImpl @Inject constructor(
)

return openAI.chatCompletions(chatCompletionRequest)
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta.content ?: "") }
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta?.content ?: "") }
.catch { throwable -> emit(ApiState.Error(throwable.message ?: "Unknown error")) }
.onStart { emit(ApiState.Loading) }
.onCompletion { emit(ApiState.Done) }
Expand Down Expand Up @@ -126,7 +127,25 @@ class ChatRepositoryImpl @Inject constructor(
}

/**
 * Streams a chat completion from a local Ollama server.
 *
 * Ollama serves an OpenAI-compatible API, so the OpenAI client is pointed at the
 * user's configured host (`"${platform.apiUrl}v1/"`). The returned cold [Flow] emits
 * [ApiState.Loading], then [ApiState.Success] per streamed token chunk, [ApiState.Error]
 * on failure, and [ApiState.Done] on completion.
 *
 * @param question the newly asked message, appended after [history].
 * @param history prior conversation messages, oldest first.
 * @throws IllegalStateException if no Ollama platform is configured in settings.
 */
override suspend fun completeOllamaChat(question: Message, history: List<Message>): Flow<ApiState> {
    val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OLLAMA })
    // Token may legitimately be absent for a local server; fall back to an empty string.
    ollama = OpenAI(platform.token ?: "", host = OpenAIHost(baseUrl = "${platform.apiUrl}v1/"))

    // Prepend the system prompt (user-configured, or the app default) to the full conversation.
    val generatedMessages = messageToOpenAIMessage(history + listOf(question))
    val generatedMessageWithPrompt = listOf(
        ChatMessage(role = ChatRole.System, content = platform.systemPrompt ?: ModelConstants.DEFAULT_PROMPT)
    ) + generatedMessages
    val chatCompletionRequest = ChatCompletionRequest(
        model = ModelId(platform.model ?: ""),
        messages = generatedMessageWithPrompt,
        temperature = platform.temperature?.toDouble(),
        topP = platform.topP?.toDouble()
    )

    return ollama.chatCompletions(chatCompletionRequest)
        // firstOrNull() guards against chunks with an empty choices list (choices[0] would throw).
        .map<ChatCompletionChunk, ApiState> { chunk ->
            ApiState.Success(chunk.choices.firstOrNull()?.delta?.content ?: "")
        }
        .catch { throwable -> emit(ApiState.Error(throwable.message ?: "Unknown error")) }
        .onStart { emit(ApiState.Loading) }
        .onCompletion { emit(ApiState.Done) }
}

// Delegates to the DAO to load all stored chat rooms.
override suspend fun fetchChatList(): List<ChatRoom> {
    return chatRoomDao.getChatRooms()
}
Expand Down

0 comments on commit 642b13e

Please sign in to comment.