Skip to content

Commit

Permalink
GCP Pipeline Jobs (#315)
Browse files Browse the repository at this point in the history
* GCP Pipeline jobs API - List

* Finishes the Pipeline Endpoints implementation

* Updates to conversational

* spotless

* add support for ktor's Closeable to AutoClose; extract common code of GCP HttpClient

* extract common config code for accessing GCP web api (token, location, projectId)

---------

Co-authored-by: ron <ron.spannagel@47deg.com>
  • Loading branch information
fedefernandez and Intex32 authored Aug 30, 2023
1 parent fd12b91 commit 5e84a22
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.xebia.functional.xef.conversation

import arrow.atomic.Atomic
import arrow.atomic.update
import io.ktor.utils.io.core.*

/**
* AutoClose offers DSL style API for creating parent-child relationships of AutoCloseable
Expand Down Expand Up @@ -34,5 +35,15 @@ fun autoClose(): AutoClose =
}
}

/** integration to Ktor's [Closeable] */
fun <A : Closeable> AutoClose.autoClose(closeable: A): A {
val wrapper =
object : AutoCloseable {
override fun close() = closeable.close()
}
autoClose(wrapper)
return closeable
}

private fun Throwable?.add(other: Throwable?): Throwable? =
this?.apply { other?.let { addSuppressed(it) } } ?: other
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package com.xebia.functional.xef.conversation.gpc

import arrow.core.nonEmptyListOf
import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpChat
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.GcpEmbeddings
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.prompt.Prompt
Expand All @@ -16,11 +16,9 @@ suspend fun main() {
getenv("GCP_TOKEN") ?: throw AIError.Env.GCP(nonEmptyListOf("missing GCP_TOKEN env var"))

val gcp =
GcpChat("us-central1-aiplatform.googleapis.com", "xefdemo", "codechat-bison@001", token)
.let(::autoClose)
GcpChat("codechat-bison@001", GcpConfig(token, "xefdemo", "us-central1")).let(::autoClose)
val gcpEmbeddingModel =
GcpChat("us-central1-aiplatform.googleapis.com", "xefdemo", "textembedding-gecko", token)
.let(::autoClose)
GcpChat("codechat-bison@001", GcpConfig(token, "xefdemo", "us-central1")).let(::autoClose)

val embeddingResult =
GcpEmbeddings(gcpEmbeddingModel)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.xebia.functional.xef.conversation.gpc

import com.xebia.functional.gpt4all.conversation
import com.xebia.functional.xef.env.getenv
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.pipelines.GcpPipelinesClient

suspend fun main() {
conversation {
val token = getenv("GCP_TOKEN") ?: error("missing gcp token")
val pipelineClient = autoClose(GcpPipelinesClient(GcpConfig(token, "xefdemo", "us-central1")))
val answer = pipelineClient.list()
println("\n🤖 $answer")
}
}
4 changes: 3 additions & 1 deletion gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ arrow = "1.2.0"
arrowGradle = "0.12.0-rc.5"
kotlin = "1.9.10"
kotlinx-json = "1.5.1"
kotlinx-datetime = "0.4.0"
ktor = "2.3.3"
spotless = "6.20.0"
okio = "3.5.0"
Expand Down Expand Up @@ -56,6 +57,7 @@ kotlinx-serialization-hocon = { module = "org.jetbrains.kotlinx:kotlinx-serializ
kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" }
kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines" }
kotlinx-coroutines-jdk8 = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-jdk8", version.ref="kotlinx-coroutines" }
kotlinx-datetime = { module = "org.jetbrains.kotlinx:kotlinx-datetime", version.ref = "kotlinx-datetime" }
ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" }
ktor-http = { module = "io.ktor:ktor-http", version.ref = "ktor" }
ktor-client ={ module = "io.ktor:ktor-client-core", version.ref = "ktor" }
Expand Down Expand Up @@ -142,4 +144,4 @@ semver-gradle = { id="com.javiersc.semver", version.ref="semverGradle" }
suspend-transform-plugin = { id="love.forte.plugin.suspend-transform", version.ref="suspend-transform" }
resources = { id="com.goncalossilva.resources", version.ref="resources-kmp" }
detekt = { id="io.gitlab.arturbosch.detekt", version.ref="detekt"}
node-gradle = { id = "com.github.node-gradle.node", version.ref = "node-gradle" }
node-gradle = { id = "com.github.node-gradle.node", version.ref = "node-gradle" }
1 change: 1 addition & 0 deletions integrations/gcp/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ kotlin {
api(projects.xefCore)
implementation(libs.bundles.ktor.client)
implementation(libs.uuid)
implementation(libs.kotlinx.datetime)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ import kotlinx.uuid.UUID
import kotlinx.uuid.generateUUID

@OptIn(ExperimentalStdlibApi::class)
class GcpChat(apiEndpoint: String, projectId: String, modelId: String, token: String) :
Chat, Completion, AutoCloseable, Embeddings {
private val client: GcpClient = GcpClient(apiEndpoint, projectId, modelId, token)
class GcpChat(modelId: String, config: GcpConfig) : Chat, Completion, AutoCloseable, Embeddings {
private val client: GcpClient = GcpClient(modelId, config)

override val name: String = client.modelId
override val modelType: ModelType =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,30 @@ import com.xebia.functional.xef.AIError
import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import io.ktor.client.*
import io.ktor.client.HttpClient
import io.ktor.client.call.*
import io.ktor.client.call.body
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.request.*
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.*
import io.ktor.client.statement.bodyAsText
import io.ktor.http.*
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.serialization.kotlinx.json.json
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json

@OptIn(ExperimentalStdlibApi::class)
class GcpClient(
private val apiEndpoint: String,
private val projectId: String,
val modelId: String,
private val token: String
) : AutoCloseable, AutoClose by autoClose() {
private val http: HttpClient = HttpClient {
install(HttpTimeout) {
requestTimeoutMillis = 60_000
connectTimeoutMillis = 60_000
}
install(HttpRequestRetry)
install(ContentNegotiation) {
json(
Json {
encodeDefaults = false
isLenient = true
ignoreUnknownKeys = true
}
)
}
}
private val config: GcpConfig,
) : AutoClose by autoClose() {
private val http: HttpClient = jsonHttpClient()

@Serializable
private data class Prompt(val instances: List<Instance>, val parameters: Parameters? = null)
Expand Down Expand Up @@ -101,9 +85,9 @@ class GcpClient(
)
val response =
http.post(
"https://$apiEndpoint/v1/projects/$projectId/locations/us-central1/publishers/google/models/$modelId:predict"
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/us-central1/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer $token")
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
setBody(body)
}
Expand Down Expand Up @@ -153,9 +137,9 @@ class GcpClient(
)
val response =
http.post(
"https://$apiEndpoint/v1/projects/$projectId/locations/us-central1/publishers/google/models/$modelId:predict"
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/publishers/google/models/$modelId:predict"
) {
header("Authorization", "Bearer $token")
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
setBody(body)
}
Expand All @@ -168,8 +152,4 @@ class GcpClient(

class GcpClientException(val httpStatusCode: HttpStatusCode, val error: String) :
IllegalStateException("$httpStatusCode: $error")

override fun close() {
http.close()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.xebia.functional.xef.gcp

data class GcpConfig(
val token: String,
val projectId: String,
/** https://cloud.google.com/vertex-ai/docs/general/locations */
val location: String, // Supported us-central1 or europe-west4
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package com.xebia.functional.xef.gcp

import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.json.Json

/** default [HttpClient] to access GCP models using JSON */
internal fun AutoClose.jsonHttpClient(): HttpClient =
HttpClient {
install(HttpTimeout) {
requestTimeoutMillis = 60_000
connectTimeoutMillis = 60_000
}
install(HttpRequestRetry)
install(ContentNegotiation) {
json(
Json {
encodeDefaults = false
isLenient = true
ignoreUnknownKeys = true
}
)
}
}
.let(::autoClose)
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package com.xebia.functional.xef.gcp.pipelines

import com.xebia.functional.xef.conversation.AutoClose
import com.xebia.functional.xef.conversation.autoClose
import com.xebia.functional.xef.gcp.GcpConfig
import com.xebia.functional.xef.gcp.jsonHttpClient
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import kotlinx.datetime.Instant
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

class GcpPipelinesClient(
private val config: GcpConfig,
) : AutoClose by autoClose() {
private val http: HttpClient = jsonHttpClient()

// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.pipelineJobs/list#google.cloud.aiplatform.v1.PipelineService.ListPipelineJobs
@Serializable
private data class ListPipelineJobs(
val pipelineJobs: List<PipelineJob>? = null,
val nextPageToken: String? = null
)

@Serializable
enum class PipelineState {
PIPELINE_STATE_UNSPECIFIED,
PIPELINE_STATE_QUEUED,
PIPELINE_STATE_PENDING,
PIPELINE_STATE_RUNNING,
PIPELINE_STATE_SUCCEEDED,
PIPELINE_STATE_FAILED,
PIPELINE_STATE_CANCELLING,
PIPELINE_STATE_CANCELLED,
PIPELINE_STATE_PAUSED
}

// https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.pipelineJobs#PipelineJob
@Serializable
data class PipelineJob(
val name: String,
val displayName: String,
val createTime: Instant,
val startTime: Instant?,
val endTime: Instant?,
val updateTime: Instant,
val state: PipelineState
)

@Serializable
data class CreatePipelineJob(
val displayName: String,
val runtimeConfig: RuntimeConfig,
val templateUri: String
)

@Serializable
data class RuntimeConfig(val gcsOutputDirectory: String, val parameterValues: ParameterValues)

@Serializable
data class ParameterValues(
@SerialName("project") val project: String,
@SerialName("model_display_name") val modelDisplayName: String,
@SerialName("dataset_uri") val datasetUri: String,
@SerialName("location") val location: String,
@SerialName("large_model_reference") val largeModelReference: String,
@SerialName("train_steps") val trainSteps: String,
@SerialName("learning_rate_multiplier") val learningRateMultiplier: String
)

@Serializable
data class Operation(val name: String, val done: Boolean, val error: OperationStatus?)

@Serializable data class OperationStatus(val code: Int, val message: String)

suspend fun list(): List<PipelineJob> {
val response =
http.get(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
}

return if (response.status.isSuccess()) response.body<ListPipelineJobs>().pipelineJobs.orEmpty()
else throw GcpClientException(response.status, response.bodyAsText())
}

suspend fun get(pipelineJobName: String): PipelineJob? {
val response =
http.get(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
}

return if (response.status.isSuccess()) response.body<PipelineJob?>()
else throw GcpClientException(response.status, response.bodyAsText())
}

suspend fun create(pipelineJobId: String?, pipelineJob: CreatePipelineJob): PipelineJob? {
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
parameter("pipelineJobId", pipelineJobId)
setBody(pipelineJob)
}

return if (response.status.isSuccess()) response.body<PipelineJob?>()
else throw GcpClientException(response.status, response.bodyAsText())
}

suspend fun cancel(pipelineJobName: String): Unit {
val response =
http.post(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName:cancel"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
}

return if (response.status.isSuccess()) {} else
throw GcpClientException(response.status, response.bodyAsText())
}

suspend fun delete(pipelineJobName: String): Operation {
val response =
http.delete(
"https://${config.location}-aiplatform.googleapis.com/v1/projects/${config.projectId}/locations/${config.location}/pipelineJobs/$pipelineJobName"
) {
header("Authorization", "Bearer ${config.token}")
contentType(ContentType.Application.Json)
}

return if (response.status.isSuccess()) response.body<Operation>()
else throw GcpClientException(response.status, response.bodyAsText())
}

class GcpClientException(val httpStatusCode: HttpStatusCode, val error: String) :
IllegalStateException("$httpStatusCode: $error")
}

0 comments on commit 5e84a22

Please sign in to comment.