Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ollama Support #47

Merged
merged 15 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .idea/gradle.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions .idea/kotlinc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions .idea/runConfigurations.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.jetbrains.kotlin.android)
alias(libs.plugins.android.hilt)
alias(libs.plugins.compose.compiler)
alias(libs.plugins.kotlin.ksp)
alias(libs.plugins.kotlin.parcelize)
alias(libs.plugins.auto.license)
Expand Down Expand Up @@ -51,9 +52,9 @@ android {
buildFeatures {
compose = true
}
composeOptions {
kotlinCompilerExtensionVersion = "1.5.13" // Make sure to update this when Kotlin version is updated
}
// composeOptions {
// kotlinCompilerExtensionVersion = "1.5.13" // Make sure to update this when Kotlin version is updated
// }
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ object ModelConstants {
val openaiModels = linkedSetOf("gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo")
val anthropicModels = linkedSetOf("claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307")
val googleModels = linkedSetOf("gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-1.0-pro")
const val OPENAI_API_URL = "https://api.openai.com"
const val ANTHROPIC_API_URL = "https://api.anthropic.com"
val ollamaModels = linkedSetOf<String>()

const val OPENAI_API_URL = "https://api.openai.com/v1/"
const val ANTHROPIC_API_URL = "https://api.anthropic.com/"
const val GOOGLE_API_URL = "https://generativelanguage.googleapis.com"

fun getDefaultAPIUrl(apiType: ApiType) = when (apiType) {
ApiType.OPENAI -> OPENAI_API_URL
ApiType.ANTHROPIC -> ANTHROPIC_API_URL
ApiType.GOOGLE -> GOOGLE_API_URL
ApiType.OLLAMA -> ""
}

const val ANTHROPIC_MAXIMUM_TOKEN = 4096
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,37 +20,44 @@ class SettingDataSourceImpl @Inject constructor(
private val apiStatusMap = mapOf(
ApiType.OPENAI to booleanPreferencesKey("openai_status"),
ApiType.ANTHROPIC to booleanPreferencesKey("anthropic_status"),
ApiType.GOOGLE to booleanPreferencesKey("google_status")
ApiType.GOOGLE to booleanPreferencesKey("google_status"),
ApiType.OLLAMA to booleanPreferencesKey("ollama_status")
)
private val apiUrlMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_url"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_url"),
ApiType.GOOGLE to stringPreferencesKey("google_url")
ApiType.GOOGLE to stringPreferencesKey("google_url"),
ApiType.OLLAMA to stringPreferencesKey("ollama_url")
)
private val apiTokenMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_token"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_token"),
ApiType.GOOGLE to stringPreferencesKey("google_token")
ApiType.GOOGLE to stringPreferencesKey("google_token"),
ApiType.OLLAMA to stringPreferencesKey("ollama_token")
)
private val apiModelMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_model"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_model"),
ApiType.GOOGLE to stringPreferencesKey("google_model")
ApiType.GOOGLE to stringPreferencesKey("google_model"),
ApiType.OLLAMA to stringPreferencesKey("ollama_model")
)
private val apiTemperatureMap = mapOf(
ApiType.OPENAI to floatPreferencesKey("openai_temperature"),
ApiType.ANTHROPIC to floatPreferencesKey("anthropic_temperature"),
ApiType.GOOGLE to floatPreferencesKey("google_temperature")
ApiType.GOOGLE to floatPreferencesKey("google_temperature"),
ApiType.OLLAMA to floatPreferencesKey("ollama_temperature")
)
private val apiTopPMap = mapOf(
ApiType.OPENAI to floatPreferencesKey("openai_top_p"),
ApiType.ANTHROPIC to floatPreferencesKey("anthropic_top_p"),
ApiType.GOOGLE to floatPreferencesKey("google_top_p")
ApiType.GOOGLE to floatPreferencesKey("google_top_p"),
ApiType.OLLAMA to floatPreferencesKey("ollama_top_p")
)
private val apiSystemPromptMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_system_prompt"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_system_prompt"),
ApiType.GOOGLE to stringPreferencesKey("google_system_prompt")
ApiType.GOOGLE to stringPreferencesKey("google_system_prompt"),
ApiType.OLLAMA to stringPreferencesKey("ollama_system_prompt")
)
private val dynamicThemeKey = intPreferencesKey("dynamic_mode")
private val themeModeKey = intPreferencesKey("theme_mode")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ package dev.chungjungsoo.gptmobile.data.model
enum class ApiType {
OPENAI,
ANTHROPIC,
GOOGLE
GOOGLE,
OLLAMA
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class AnthropicAPIImpl @Inject constructor(

val builder = HttpRequestBuilder().apply {
method = HttpMethod.Post
url("$apiUrl/v1/messages")
if (apiUrl.endsWith("/")) url("${apiUrl}v1/messages") else url("$apiUrl/v1/messages")
contentType(ContentType.Application.Json)
setBody(body)
accept(ContentType.Text.EventStream)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ interface ChatRepository {
suspend fun completeOpenAIChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun completeAnthropicChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun completeGoogleChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun completeOllamaChat(question: Message, history: List<Message>): Flow<ApiState>
suspend fun fetchChatList(): List<ChatRoom>
suspend fun fetchMessages(chatId: Int): List<Message>
suspend fun updateChatTitle(chatRoom: ChatRoom, title: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class ChatRepositoryImpl @Inject constructor(

private lateinit var openAI: OpenAI
private lateinit var google: GenerativeModel
private lateinit var ollama: OpenAI

override suspend fun completeOpenAIChat(question: Message, history: List<Message>): Flow<ApiState> {
val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OPENAI })
Expand All @@ -63,7 +64,7 @@ class ChatRepositoryImpl @Inject constructor(
)

return openAI.chatCompletions(chatCompletionRequest)
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta.content ?: "") }
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta?.content ?: "") }
.catch { throwable -> emit(ApiState.Error(throwable.message ?: "Unknown error")) }
.onStart { emit(ApiState.Loading) }
.onCompletion { emit(ApiState.Done) }
Expand Down Expand Up @@ -125,6 +126,28 @@ class ChatRepositoryImpl @Inject constructor(
.onCompletion { emit(ApiState.Done) }
}

override suspend fun completeOllamaChat(question: Message, history: List<Message>): Flow<ApiState> {
val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OLLAMA })
ollama = OpenAI(platform.token ?: "", host = OpenAIHost(baseUrl = "${platform.apiUrl}v1/"))

val generatedMessages = messageToOpenAIMessage(history + listOf(question))
val generatedMessageWithPrompt = listOf(
ChatMessage(role = ChatRole.System, content = platform.systemPrompt ?: ModelConstants.DEFAULT_PROMPT)
) + generatedMessages
val chatCompletionRequest = ChatCompletionRequest(
model = ModelId(platform.model ?: ""),
messages = generatedMessageWithPrompt,
temperature = platform.temperature?.toDouble(),
topP = platform.topP?.toDouble()
)

return ollama.chatCompletions(chatCompletionRequest)
.map<ChatCompletionChunk, ApiState> { chunk -> ApiState.Success(chunk.choices[0].delta?.content ?: "") }
.catch { throwable -> emit(ApiState.Error(throwable.message ?: "Unknown error")) }
.onStart { emit(ApiState.Loading) }
.onCompletion { emit(ApiState.Done) }
}

override suspend fun fetchChatList(): List<ChatRoom> = chatRoomDao.getChatRooms()

override suspend fun fetchMessages(chatId: Int): List<Message> = messageDao.loadMessages(chatId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SettingRepositoryImpl @Inject constructor(
ApiType.OPENAI -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.OPENAI_API_URL
ApiType.ANTHROPIC -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.ANTHROPIC_API_URL
ApiType.GOOGLE -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.GOOGLE_API_URL
ApiType.OLLAMA -> settingDataSource.getAPIUrl(apiType) ?: ""
}
val token = settingDataSource.getToken(apiType)
val model = settingDataSource.getModel(apiType)
Expand All @@ -28,11 +29,12 @@ class SettingRepositoryImpl @Inject constructor(
ApiType.OPENAI -> settingDataSource.getSystemPrompt(ApiType.OPENAI) ?: ModelConstants.OPENAI_PROMPT
ApiType.ANTHROPIC -> settingDataSource.getSystemPrompt(ApiType.ANTHROPIC) ?: ModelConstants.DEFAULT_PROMPT
ApiType.GOOGLE -> settingDataSource.getSystemPrompt(ApiType.GOOGLE) ?: ModelConstants.DEFAULT_PROMPT
ApiType.OLLAMA -> settingDataSource.getSystemPrompt(ApiType.OLLAMA) ?: ModelConstants.DEFAULT_PROMPT
}

Platform(
name = apiType,
enabled = status ?: false,
enabled = status == true,
apiUrl = apiUrl,
token = token,
model = model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import dev.chungjungsoo.gptmobile.presentation.ui.setting.SettingScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setting.SettingViewModel
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SelectModelScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SelectPlatformScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SetupAPIUrlScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SetupCompleteScreen
import dev.chungjungsoo.gptmobile.presentation.ui.setup.SetupViewModel
import dev.chungjungsoo.gptmobile.presentation.ui.setup.TokenInputScreen
Expand Down Expand Up @@ -117,6 +118,31 @@ fun NavGraphBuilder.setupNavigation(
onBackAction = { navController.navigateUp() }
)
}
composable(route = Route.OLLAMA_MODEL_SELECT) {
val parentEntry = remember(it) {
navController.getBackStackEntry(Route.SETUP_ROUTE)
}
val setupViewModel: SetupViewModel = hiltViewModel(parentEntry)
SelectModelScreen(
setupViewModel = setupViewModel,
currentRoute = Route.OLLAMA_MODEL_SELECT,
platformType = ApiType.OLLAMA,
onNavigate = { route -> navController.navigate(route) },
onBackAction = { navController.navigateUp() }
)
}
composable(route = Route.OLLAMA_API_ADDRESS) {
val parentEntry = remember(it) {
navController.getBackStackEntry(Route.SETUP_ROUTE)
}
val setupViewModel: SetupViewModel = hiltViewModel(parentEntry)
SetupAPIUrlScreen(
setupViewModel = setupViewModel,
currentRoute = Route.OLLAMA_API_ADDRESS,
onNavigate = { route -> navController.navigate(route) },
onBackAction = { navController.navigateUp() }
)
}
composable(route = Route.SETUP_COMPLETE) {
val parentEntry = remember(it) {
navController.getBackStackEntry(Route.SETUP_ROUTE)
Expand Down Expand Up @@ -188,6 +214,7 @@ fun NavGraphBuilder.settingNavigation(navController: NavHostController) {
ApiType.OPENAI -> navController.navigate(Route.OPENAI_SETTINGS)
ApiType.ANTHROPIC -> navController.navigate(Route.ANTHROPIC_SETTINGS)
ApiType.GOOGLE -> navController.navigate(Route.GOOGLE_SETTINGS)
ApiType.OLLAMA -> navController.navigate(Route.OLLAMA_SETTINGS)
}
},
onNavigateToAboutPage = { navController.navigate(Route.ABOUT_PAGE) }
Expand Down Expand Up @@ -223,6 +250,16 @@ fun NavGraphBuilder.settingNavigation(navController: NavHostController) {
apiType = ApiType.GOOGLE
) { navController.navigateUp() }
}
composable(Route.OLLAMA_SETTINGS) {
val parentEntry = remember(it) {
navController.getBackStackEntry(Route.SETTING_ROUTE)
}
val settingViewModel: SettingViewModel = hiltViewModel(parentEntry)
PlatformSettingScreen(
settingViewModel = settingViewModel,
apiType = ApiType.OLLAMA
) { navController.navigateUp() }
}
composable(Route.ABOUT_PAGE) {
AboutScreen(
onNavigationClick = { navController.navigateUp() },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ object Route {
const val OPENAI_MODEL_SELECT = "openai_model_select"
const val ANTHROPIC_MODEL_SELECT = "anthropic_model_select"
const val GOOGLE_MODEL_SELECT = "google_model_select"
const val OLLAMA_MODEL_SELECT = "ollama_model_select"
const val OLLAMA_API_ADDRESS = "ollama_api_address"
const val SETUP_COMPLETE = "setup_complete"

const val CHAT_LIST = "chat_list"
Expand All @@ -20,6 +22,7 @@ object Route {
const val OPENAI_SETTINGS = "openai_settings"
const val ANTHROPIC_SETTINGS = "anthropic_settings"
const val GOOGLE_SETTINGS = "google_settings"
const val OLLAMA_SETTINGS = "ollama_settings"
const val ABOUT_PAGE = "about"
const val LICENSE = "license"
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,18 @@ fun ChatScreen(
val messages by chatViewModel.messages.collectManagedState()
val question by chatViewModel.question.collectManagedState()
val appEnabledPlatforms by chatViewModel.enabledPlatformsInApp.collectManagedState()

val openaiLoadingState by chatViewModel.openaiLoadingState.collectManagedState()
val anthropicLoadingState by chatViewModel.anthropicLoadingState.collectManagedState()
val googleLoadingState by chatViewModel.googleLoadingState.collectManagedState()
val ollamaLoadingState by chatViewModel.ollamaLoadingState.collectManagedState()

val userMessage by chatViewModel.userMessage.collectManagedState()

val openAIMessage by chatViewModel.openAIMessage.collectManagedState()
val anthropicMessage by chatViewModel.anthropicMessage.collectManagedState()
val googleMessage by chatViewModel.googleMessage.collectManagedState()
val ollamaMessage by chatViewModel.ollamaMessage.collectManagedState()

val canUseChat = (chatViewModel.enabledPlatformsInChat.toSet() - appEnabledPlatforms.toSet()).isEmpty()
val groupedMessages = remember(messages) { groupMessages(messages) }
Expand Down Expand Up @@ -205,12 +210,14 @@ fun ChatScreen(
ApiType.OPENAI -> openAIMessage
ApiType.ANTHROPIC -> anthropicMessage
ApiType.GOOGLE -> googleMessage
ApiType.OLLAMA -> ollamaMessage
}

val loadingState = when (apiType) {
ApiType.OPENAI -> openaiLoadingState
ApiType.ANTHROPIC -> anthropicLoadingState
ApiType.GOOGLE -> googleLoadingState
ApiType.OLLAMA -> ollamaLoadingState
}

OpponentChatBubble(
Expand Down
Loading