Rewire tests container setup lifecycle
Replaces the single shared Ollama container, previously started and committed to a reusable image in @BeforeAll/@AfterAll hooks, with a per-model container registry: each model's image is pulled and committed once on first use, reused when it already exists locally, and every registered container is stopped on teardown. The ollama test helper now takes a single model per call and returns one result instead of a list.
raulraja committed May 28, 2024
1 parent 6cbf754 commit 984e97d
Showing 4 changed files with 82 additions and 103 deletions.
@@ -6,31 +6,28 @@ import kotlinx.coroutines.runBlocking
 import org.junit.jupiter.api.Test
 
 class EnumClassificationTest : OllamaTests() {
 
   @Test
-  fun `enum classification`() {
+  fun `positive sentiment`() {
     runBlocking {
-      val models = setOf(OllamaModels.LLama3_8B)
-      val sentiments =
+      val sentiment =
         ollama<Sentiment>(
-          models = models,
-          prompt = "The sentiment of this text is positive.",
+          model = OllamaModels.Gemma2B,
+          prompt = "The context of the situation is very positive.",
         )
-      expectSentiment(Sentiment.POSITIVE, sentiments, models)
+      assert(sentiment == Sentiment.POSITIVE) { "Expected POSITIVE but got $sentiment" }
     }
   }
 
-  companion object {
-    internal fun expectSentiment(
-      expected: Sentiment,
-      sentiments: List<Sentiment>,
-      models: Set<String>
-    ) {
-      assert(sentiments.size == models.size) {
-        "Expected ${models.size} results but got ${sentiments.size}"
-      }
-      sentiments.forEach { sentiment ->
-        assert(sentiment == expected) { "Expected $expected but got $sentiment" }
-      }
-    }
+  @Test
+  fun `negative sentiment`() {
+    runBlocking {
+      val sentiment =
+        ollama<Sentiment>(
+          model = OllamaModels.LLama3_8B,
+          prompt = "The context of the situation is very negative.",
+        )
+      assert(sentiment == Sentiment.NEGATIVE) { "Expected NEGATIVE but got $sentiment" }
+    }
   }
 }
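
The tests above pass OllamaModels.Gemma2B and OllamaModels.LLama3_8B straight through as the model string. The OllamaModels object itself is not part of this diff; as a rough sketch of the shape it presumably has (the tag values here are assumptions, not taken from the repository):

// Hypothetical sketch: the real OllamaModels object is not shown in this commit.
object OllamaModels {
  const val Gemma2B = "gemma:2b" // assumed Ollama model tag
  const val LLama3_8B = "llama3:8b" // assumed Ollama model tag
}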

This file was deleted.

@@ -1,15 +1,14 @@
 package com.xebia.functional.xef.ollama.tests
 
-import arrow.fx.coroutines.parMap
-import com.xebia.functional.openai.generated.api.Chat
+import com.github.dockerjava.api.model.Image
 import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
 import com.xebia.functional.xef.AI
 import com.xebia.functional.xef.Config
 import com.xebia.functional.xef.OpenAI
 import io.github.oshai.kotlinlogging.KotlinLogging
-import kotlinx.coroutines.Dispatchers
+import java.util.concurrent.ConcurrentHashMap
 import org.junit.jupiter.api.AfterAll
-import org.junit.jupiter.api.BeforeAll
+import org.testcontainers.DockerClientFactory
 import org.testcontainers.ollama.OllamaContainer
 import org.testcontainers.utility.DockerImageName
@@ -19,75 +18,83 @@ abstract class OllamaTests {
 
   companion object {
     private const val OLLAMA_IMAGE = "ollama/ollama:0.1.26"
-    private const val NEW_IMAGE_NAME = "ollama/ollama:test"
 
-    val ollama: OllamaContainer by lazy {
-      // check if the new image is already present otherwise pull the image
-      if (DockerImageName.parse(NEW_IMAGE_NAME).asCompatibleSubstituteFor(OLLAMA_IMAGE) != null) {
-        OllamaContainer(DockerImageName.parse(NEW_IMAGE_NAME))
+    private val registeredContainers: MutableMap<String, OllamaContainer> = ConcurrentHashMap()
+
+    @PublishedApi
+    internal fun useModel(model: String): OllamaContainer =
+      if (registeredContainers.containsKey(model)) {
+        registeredContainers[model]!!
       } else {
-        OllamaContainer(DockerImageName.parse(OLLAMA_IMAGE))
+        ollamaContainer(model)
       }
-    }
 
-    @BeforeAll
-    @JvmStatic
-    fun setup() {
-      ollama.start()
-      ollama.commitToImage(NEW_IMAGE_NAME)
+    private fun ollamaContainer(model: String, imageName: String = model): OllamaContainer {
+      if (registeredContainers.containsKey(model)) {
+        return registeredContainers[model]!!
+      }
+      // Create the model image only if it is not already present as a local Docker image
+      val listImagesCmd: List<Image> =
+        DockerClientFactory.lazyClient().listImagesCmd().withImageNameFilter(imageName).exec()
+
+      val ollama =
+        if (listImagesCmd.isEmpty()) {
+          println("🐳 Creating a new Ollama container with $model image...")
+          val ollama = OllamaContainer(OLLAMA_IMAGE)
+          ollama.start()
+          println("🐳 Pulling $model image...")
+          ollama.execInContainer("ollama", "pull", model)
+          println("🐳 Committing $model image...")
+          ollama.commitToImage(imageName)
+          ollama.withReuse(true)
+        } else {
+          println("🐳 Using existing Ollama container with $model image...")
+          // Substitute the default Ollama image with our model variant
+          val ollama =
+            OllamaContainer(
+              DockerImageName.parse(imageName).asCompatibleSubstituteFor("ollama/ollama")
+            )
+              .withReuse(true)
+          ollama.start()
+          ollama
+        }
+      println("🐳 Starting Ollama container with $model image...")
+      registeredContainers[model] = ollama
+      ollama.execInContainer("ollama", "run", model)
+      return ollama
     }
 
     @AfterAll
     @JvmStatic
     fun teardown() {
-      ollama.commitToImage(NEW_IMAGE_NAME)
-      ollama.stop()
+      registeredContainers.forEach { (model, container) ->
+        println("🐳 Stopping Ollama container for model $model")
+        container.stop()
+      }
     }
   }
 
-  suspend inline fun <reified A> ollama(
-    models: Set<String>,
+  protected suspend inline fun <reified A> ollama(
+    model: String,
     prompt: String,
-    config: Config = Config(baseUrl = ollamaBaseUrl(), supportsLogitBias = false),
-    api: Chat = OpenAI(config = config, logRequests = true).chat,
-  ): List<A> {
-    // pull all models
-    models.parMap(context = Dispatchers.IO) { model ->
-      logger.info { "🚢 Pulling model $model" }
-      val pullResult = ollama.execInContainer("ollama", "pull", model)
-      if (pullResult.exitCode != 0) {
-        logger.error { pullResult.stderr }
-        throw RuntimeException("Failed to pull model $model")
-      }
-      logger.info { pullResult.stdout }
-      logger.info { "🚢 Pulled $model" }
-    }
-    // run all models
-    models.parMap(context = Dispatchers.IO) { model ->
-      logger.info { "🚀 Starting model $model" }
-      val runResult = ollama.execInContainer("ollama", "run", model)
-      if (runResult.exitCode != 0) {
-        logger.error { runResult.stderr }
-        throw RuntimeException("Failed to run model $model")
-      }
-      logger.info { runResult.stdout }
-      println("🚀 Started $model")
-    }
-    // run inference on all models
-    return models.parMap(context = Dispatchers.IO) { model ->
-      logger.info { "🚀 Running inference on model $model" }
-      val result: A =
-        AI(
-          prompt = prompt,
-          config = config,
-          api = api,
-          model = CreateChatCompletionRequestModel.Custom(model),
-        )
-      logger.info { "🚀 Inference on model $model: $result" }
-      result
-    }
+  ): A {
+    useModel(model)
+    val config = Config(supportsLogitBias = false, baseUrl = ollamaBaseUrl(model))
+    val api = OpenAI(config = config, logRequests = true).chat
+    val result: A =
+      AI(
+        prompt = prompt,
+        config = config.copy(),
+        api = api,
+        model = CreateChatCompletionRequestModel.Custom(model),
+      )
+    logger.info { "🚀 Inference on model $model: $result" }
+    return result
   }
 
-  fun ollamaBaseUrl(): String =
-    "http://${ollama.host}:${ollama.getMappedPort(ollama.exposedPorts.first())}/v1/"
+  fun ollamaBaseUrl(model: String): String {
+    val ollama = registeredContainers[model]!!
+    return "http://${ollama.host}:${ollama.getMappedPort(ollama.exposedPorts.first())}/v1/"
+  }
 }
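
To make the reworked lifecycle concrete, here is a minimal sketch of a consuming test (the class and test names are hypothetical; it assumes Sentiment and OllamaModels are visible from the test package). The first ollama call for a model starts and registers its container; later calls for the same model hit the registeredContainers cache; teardown() stops every registered container once the class finishes:

import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Test

// Hypothetical subclass, for illustration only.
class LifecycleSketchTest : OllamaTests() {

  @Test
  fun `containers are registered once per model`() {
    runBlocking {
      // First call: useModel() misses the cache, so ollamaContainer() builds or
      // reuses the committed image, starts the container, and registers it.
      val first =
        ollama<Sentiment>(
          model = OllamaModels.Gemma2B,
          prompt = "The context of the situation is very positive.",
        )
      // Second call with the same model: useModel() returns the cached container,
      // so no new image or container is created.
      val second =
        ollama<Sentiment>(
          model = OllamaModels.Gemma2B,
          prompt = "The context of the situation is very negative.",
        )
      assert(first == Sentiment.POSITIVE) { "Expected POSITIVE but got $first" }
      assert(second == Sentiment.NEGATIVE) { "Expected NEGATIVE but got $second" }
    }
  }
}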
@@ -6,7 +6,4 @@ import kotlinx.serialization.Serializable
 enum class Sentiment {
   POSITIVE,
   NEGATIVE,
-  NEUTRAL,
-  MIXED,
-  UNKNOWN
 }
