Skip to content

Commit

Permalink
Merge pull request #33 from Taewan-P/feat/custom-api-url
Browse files Browse the repository at this point in the history
Support custom API address
  • Loading branch information
Taewan-P authored Aug 16, 2024
2 parents 83db122 + 0f9bcef commit 76fa830
Show file tree
Hide file tree
Showing 18 changed files with 208 additions and 34 deletions.
1 change: 1 addition & 0 deletions app/src/main/AndroidManifest.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
android:icon="@mipmap/ic_gpt_mobile"
android:label="@string/app_name"
android:supportsRtl="true"
android:usesCleartextTraffic="true"
tools:targetApi="upside_down_cake">
<activity
android:name=".presentation.ui.main.MainActivity"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
package dev.chungjungsoo.gptmobile.data

import dev.chungjungsoo.gptmobile.data.model.ApiType

object ModelConstants {
// LinkedHashSet should be used to guarantee item order
val openaiModels = linkedSetOf("gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo")
val anthropicModels = linkedSetOf("claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307")
val googleModels = linkedSetOf("gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-1.0-pro")
const val OPENAI_API_URL = "https://api.openai.com"
const val ANTHROPIC_API_URL = "https://api.anthropic.com"
const val GOOGLE_API_URL = "https://generativelanguage.googleapis.com"

fun getDefaultAPIUrl(apiType: ApiType) = when (apiType) {
ApiType.OPENAI -> OPENAI_API_URL
ApiType.ANTHROPIC -> ANTHROPIC_API_URL
ApiType.GOOGLE -> GOOGLE_API_URL
}

const val ANTHROPIC_MAXIMUM_TOKEN = 4096

Expand All @@ -13,7 +24,5 @@ object ModelConstants {
"You are familiar with various languages in the world. " +
"You are to answer my questions precisely. "

const val ANTHROPIC_PROMPT = "Your task is to answer my questions precisely."

const val GOOGLE_PROMPT = "Your task is to answer my questions precisely."
const val DEFAULT_PROMPT = "Your task is to answer my questions precisely."
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ interface SettingDataSource {
suspend fun updateDynamicTheme(theme: DynamicTheme)
suspend fun updateThemeMode(themeMode: ThemeMode)
suspend fun updateStatus(apiType: ApiType, status: Boolean)
suspend fun updateAPIUrl(apiType: ApiType, url: String)
suspend fun updateToken(apiType: ApiType, token: String)
suspend fun updateModel(apiType: ApiType, model: String)
suspend fun updateTemperature(apiType: ApiType, temperature: Float)
Expand All @@ -16,6 +17,7 @@ interface SettingDataSource {
suspend fun getDynamicTheme(): DynamicTheme?
suspend fun getThemeMode(): ThemeMode?
suspend fun getStatus(apiType: ApiType): Boolean?
suspend fun getAPIUrl(apiType: ApiType): String?
suspend fun getToken(apiType: ApiType): String?
suspend fun getModel(apiType: ApiType): String?
suspend fun getTemperature(apiType: ApiType): Float?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ class SettingDataSourceImpl @Inject constructor(
ApiType.ANTHROPIC to booleanPreferencesKey("anthropic_status"),
ApiType.GOOGLE to booleanPreferencesKey("google_status")
)
private val apiUrlMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_url"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_url"),
ApiType.GOOGLE to stringPreferencesKey("google_url")
)
private val apiTokenMap = mapOf(
ApiType.OPENAI to stringPreferencesKey("openai_token"),
ApiType.ANTHROPIC to stringPreferencesKey("anthropic_token"),
Expand Down Expand Up @@ -68,6 +73,12 @@ class SettingDataSourceImpl @Inject constructor(
}
}

override suspend fun updateAPIUrl(apiType: ApiType, url: String) {
dataStore.edit { pref ->
pref[apiUrlMap[apiType]!!] = url
}
}

override suspend fun updateToken(apiType: ApiType, token: String) {
dataStore.edit { pref ->
pref[apiTokenMap[apiType]!!] = token
Expand Down Expand Up @@ -118,6 +129,10 @@ class SettingDataSourceImpl @Inject constructor(
pref[apiStatusMap[apiType]!!]
}.first()

override suspend fun getAPIUrl(apiType: ApiType): String? = dataStore.data.map { pref ->
pref[apiUrlMap[apiType]!!]
}.first()

override suspend fun getToken(apiType: ApiType): String? = dataStore.data.map { pref ->
pref[apiTokenMap[apiType]!!]
}.first()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package dev.chungjungsoo.gptmobile.data.dto

import dev.chungjungsoo.gptmobile.data.ModelConstants.getDefaultAPIUrl
import dev.chungjungsoo.gptmobile.data.model.ApiType

data class Platform(
val name: ApiType,
val selected: Boolean = false,
val enabled: Boolean = false,
val apiUrl: String = getDefaultAPIUrl(name),
val token: String? = null,
val model: String? = null,
val temperature: Float? = null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ import kotlinx.coroutines.flow.Flow

interface AnthropicAPI {
fun setToken(token: String?)
fun setAPIUrl(url: String)
fun streamChatMessage(messageRequest: MessageRequest): Flow<MessageResponseChunk>
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dev.chungjungsoo.gptmobile.data.network

import dev.chungjungsoo.gptmobile.data.ModelConstants
import dev.chungjungsoo.gptmobile.data.dto.anthropic.request.MessageRequest
import dev.chungjungsoo.gptmobile.data.dto.anthropic.response.ErrorDetail
import dev.chungjungsoo.gptmobile.data.dto.anthropic.response.ErrorResponseChunk
Expand Down Expand Up @@ -32,17 +33,22 @@ class AnthropicAPIImpl @Inject constructor(
) : AnthropicAPI {

private var token: String? = null
private var apiUrl: String = ModelConstants.ANTHROPIC_API_URL

override fun setToken(token: String?) {
this.token = token
}

override fun setAPIUrl(url: String) {
this.apiUrl = url
}

override fun streamChatMessage(messageRequest: MessageRequest): Flow<MessageResponseChunk> {
val body = Json.encodeToJsonElement(messageRequest)

val builder = HttpRequestBuilder().apply {
method = HttpMethod.Post
url("${ANTHROPIC_CHAT_API}/v1/messages")
url("$apiUrl/v1/messages")
contentType(ContentType.Application.Json)
setBody(body)
accept(ContentType.Text.EventStream)
Expand Down Expand Up @@ -81,7 +87,6 @@ class AnthropicAPIImpl @Inject constructor(
}

companion object {
private const val ANTHROPIC_CHAT_API = "https://api.anthropic.com"
private const val STREAM_PREFIX = "data:"
private const val STREAM_END_TOKEN = "event: message_stop"
private const val API_KEY_HEADER = "x-api-key"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAI
import com.aallam.openai.client.OpenAIHost
import com.google.ai.client.generativeai.GenerativeModel
import com.google.ai.client.generativeai.type.BlockThreshold
import com.google.ai.client.generativeai.type.Content
Expand Down Expand Up @@ -48,7 +49,7 @@ class ChatRepositoryImpl @Inject constructor(

override suspend fun completeOpenAIChat(question: Message, history: List<Message>): Flow<ApiState> {
val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.OPENAI })
openAI = OpenAI(platform.token ?: "")
openAI = OpenAI(platform.token ?: "", host = OpenAIHost(baseUrl = platform.apiUrl))

val generatedMessages = messageToOpenAIMessage(history + listOf(question))
val generatedMessageWithPrompt = listOf(
Expand All @@ -71,13 +72,14 @@ class ChatRepositoryImpl @Inject constructor(
override suspend fun completeAnthropicChat(question: Message, history: List<Message>): Flow<ApiState> {
val platform = checkNotNull(settingRepository.fetchPlatforms().firstOrNull { it.name == ApiType.ANTHROPIC })
anthropic.setToken(platform.token)
anthropic.setAPIUrl(platform.apiUrl)

val generatedMessages = messageToAnthropicMessage(history + listOf(question))
val messageRequest = MessageRequest(
model = platform.model ?: "",
messages = generatedMessages,
maxTokens = ModelConstants.ANTHROPIC_MAXIMUM_TOKEN,
systemPrompt = platform.systemPrompt ?: ModelConstants.ANTHROPIC_PROMPT,
systemPrompt = platform.systemPrompt ?: ModelConstants.DEFAULT_PROMPT,
stream = true,
temperature = platform.temperature,
topP = platform.topP
Expand Down Expand Up @@ -105,7 +107,7 @@ class ChatRepositoryImpl @Inject constructor(
google = GenerativeModel(
modelName = platform.model ?: "",
apiKey = platform.token ?: "",
systemInstruction = content { text(platform.systemPrompt ?: ModelConstants.GOOGLE_PROMPT) },
systemInstruction = content { text(platform.systemPrompt ?: ModelConstants.DEFAULT_PROMPT) },
generationConfig = config,
safetySettings = listOf(
SafetySetting(HarmCategory.DANGEROUS_CONTENT, BlockThreshold.ONLY_HIGH),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,25 @@ class SettingRepositoryImpl @Inject constructor(

override suspend fun fetchPlatforms(): List<Platform> = ApiType.entries.map { apiType ->
val status = settingDataSource.getStatus(apiType)
val apiUrl = when (apiType) {
ApiType.OPENAI -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.OPENAI_API_URL
ApiType.ANTHROPIC -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.ANTHROPIC_API_URL
ApiType.GOOGLE -> settingDataSource.getAPIUrl(apiType) ?: ModelConstants.GOOGLE_API_URL
}
val token = settingDataSource.getToken(apiType)
val model = settingDataSource.getModel(apiType)
val temperature = settingDataSource.getTemperature(apiType)
val topP = settingDataSource.getTopP(apiType)
val systemPrompt = when (apiType) {
ApiType.OPENAI -> settingDataSource.getSystemPrompt(ApiType.OPENAI) ?: ModelConstants.OPENAI_PROMPT
ApiType.ANTHROPIC -> settingDataSource.getSystemPrompt(ApiType.ANTHROPIC) ?: ModelConstants.ANTHROPIC_PROMPT
ApiType.GOOGLE -> settingDataSource.getSystemPrompt(ApiType.GOOGLE) ?: ModelConstants.GOOGLE_PROMPT
ApiType.ANTHROPIC -> settingDataSource.getSystemPrompt(ApiType.ANTHROPIC) ?: ModelConstants.DEFAULT_PROMPT
ApiType.GOOGLE -> settingDataSource.getSystemPrompt(ApiType.GOOGLE) ?: ModelConstants.DEFAULT_PROMPT
}

Platform(
name = apiType,
enabled = status ?: false,
apiUrl = apiUrl,
token = token,
model = model,
temperature = temperature,
Expand All @@ -44,26 +50,13 @@ class SettingRepositoryImpl @Inject constructor(
override suspend fun updatePlatforms(platforms: List<Platform>) {
platforms.forEach { platform ->
settingDataSource.updateStatus(platform.name, platform.enabled)
settingDataSource.updateAPIUrl(platform.name, platform.apiUrl)

if (platform.token != null) {
settingDataSource.updateToken(platform.name, platform.token)
}

if (platform.model != null) {
settingDataSource.updateModel(platform.name, platform.model)
}

if (platform.temperature != null) {
settingDataSource.updateTemperature(platform.name, platform.temperature)
}

if (platform.topP != null) {
settingDataSource.updateTopP(platform.name, platform.topP)
}

if (platform.systemPrompt != null) {
settingDataSource.updateSystemPrompt(platform.name, platform.systemPrompt.trim())
}
platform.token?.let { settingDataSource.updateToken(platform.name, it) }
platform.model?.let { settingDataSource.updateModel(platform.name, it) }
platform.temperature?.let { settingDataSource.updateTemperature(platform.name, it) }
platform.topP?.let { settingDataSource.updateTopP(platform.name, it) }
platform.systemPrompt?.let { settingDataSource.updateSystemPrompt(platform.name, it.trim()) }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ fun ChatScreen(
.horizontalScroll(chatBubbleScrollStates[(key - 1) / 2])
) {
Spacer(modifier = Modifier.width(8.dp))
groupedMessages[key]!!.sortedByDescending { it.platformType }.forEach { m ->
groupedMessages[key]!!.sortedBy { it.platformType }.forEach { m ->
m.platformType?.let { apiType ->
OpponentChatBubble(
modifier = Modifier
Expand Down Expand Up @@ -200,7 +200,7 @@ fun ChatScreen(
.horizontalScroll(chatBubbleScrollStates[(latestMessageIndex + 1) / 2])
) {
Spacer(modifier = Modifier.width(8.dp))
chatViewModel.enabledPlatformsInChat.sortedDescending().forEach { apiType ->
chatViewModel.enabledPlatformsInChat.sorted().forEach { apiType ->
val message = when (apiType) {
ApiType.OPENAI -> openAIMessage
ApiType.ANTHROPIC -> anthropicMessage
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.chungjungsoo.gptmobile.presentation.ui.setting

import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.widthIn
Expand All @@ -27,6 +28,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.window.DialogProperties
import dev.chungjungsoo.gptmobile.R
import dev.chungjungsoo.gptmobile.data.ModelConstants.anthropicModels
import dev.chungjungsoo.gptmobile.data.ModelConstants.getDefaultAPIUrl
import dev.chungjungsoo.gptmobile.data.ModelConstants.googleModels
import dev.chungjungsoo.gptmobile.data.ModelConstants.openaiModels
import dev.chungjungsoo.gptmobile.data.model.ApiType
Expand All @@ -37,8 +39,35 @@ import dev.chungjungsoo.gptmobile.util.generateGoogleModelList
import dev.chungjungsoo.gptmobile.util.generateOpenAIModelList
import dev.chungjungsoo.gptmobile.util.getPlatformAPILabelResources
import dev.chungjungsoo.gptmobile.util.getPlatformHelpLinkResources
import dev.chungjungsoo.gptmobile.util.isValidUrl
import kotlin.math.roundToInt

@Composable
fun APIUrlDialog(
dialogState: SettingViewModel.DialogState,
apiType: ApiType,
initialValue: String,
settingViewModel: SettingViewModel
) {
if (dialogState.isApiUrlDialogOpen) {
APIUrlDialog(
apiType = apiType,
initialValue = initialValue,
onDismissRequest = settingViewModel::closeApiUrlDialog,
onResetRequest = {
settingViewModel.updateURL(apiType, getDefaultAPIUrl(apiType))
settingViewModel.savePlatformSettings()
settingViewModel.closeApiUrlDialog()
},
onConfirmRequest = { apiUrl ->
settingViewModel.updateURL(apiType, apiUrl)
settingViewModel.savePlatformSettings()
settingViewModel.closeApiUrlDialog()
}
)
}
}

@Composable
fun APIKeyDialog(
dialogState: SettingViewModel.DialogState,
Expand Down Expand Up @@ -136,6 +165,64 @@ fun SystemPromptDialog(
}
}

@Composable
private fun APIUrlDialog(
apiType: ApiType,
initialValue: String,
onDismissRequest: () -> Unit,
onResetRequest: () -> Unit,
onConfirmRequest: (url: String) -> Unit
) {
var apiUrl by remember { mutableStateOf(initialValue) }
val configuration = LocalConfiguration.current

AlertDialog(
properties = DialogProperties(usePlatformDefaultWidth = false),
modifier = Modifier.widthIn(max = configuration.screenWidthDp.dp - 40.dp),
title = { Text(text = stringResource(R.string.api_url)) },
text = {
OutlinedTextField(
modifier = Modifier
.fillMaxWidth()
.padding(horizontal = 20.dp, vertical = 16.dp),
value = apiUrl,
isError = apiUrl.isValidUrl().not(),
onValueChange = { apiUrl = it },
label = {
Text(stringResource(R.string.api_url))
},
supportingText = {
if (apiUrl.isValidUrl().not()) {
Text(text = stringResource(R.string.invalid_api_url))
}
}
)
},
onDismissRequest = onDismissRequest,
confirmButton = {
TextButton(
enabled = apiUrl.isNotBlank() && apiUrl.isValidUrl(),
onClick = { onConfirmRequest(apiUrl) }
) {
Text(stringResource(R.string.confirm))
}
},
dismissButton = {
Row {
TextButton(
modifier = Modifier.padding(end = 8.dp),
onClick = onResetRequest
) {
Text(stringResource(R.string.reset))
}
TextButton(onClick = onDismissRequest) {
Text(stringResource(R.string.cancel))
}
}
}
)
}

@Composable
private fun APIKeyDialog(
apiType: ApiType,
Expand Down
Loading

0 comments on commit 76fa830

Please sign in to comment.