Skip to content

Commit

Permalink
OpenAIEmbeddings to core module (#380)
Browse files Browse the repository at this point in the history
* OpenAIEmbeddings to core module

* HuggingFaceLocalEmbeddings renamed

* Comments addressed

* Removing EmbeddingService

* Simplify embedDocuments
  • Loading branch information
javipacheco authored Sep 5, 2023
1 parent e618b9b commit de40a7b
Show file tree
Hide file tree
Showing 23 changed files with 93 additions and 136 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
package com.xebia.functional.xef.llm

import arrow.fx.coroutines.parMap
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig

/**
 * An [LLM] capable of producing [Embedding] vectors for text.
 *
 * Implementations only have to provide [createEmbeddings]; [embedDocuments] and [embedQuery] are
 * built on top of it and may be overridden when a model has a cheaper local path.
 */
interface Embeddings : LLM {
  /** Performs a single embedding [request] and returns the model's [EmbeddingResult]. */
  suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult

  /**
   * Embeds [texts] in batches and returns the concatenated embeddings of all batches.
   *
   * @param texts the documents to embed; an empty list yields an empty result without any request.
   * @param requestConfig supplies the user id attached to every [EmbeddingRequest].
   * @param chunkSize maximum number of texts per request; `null` falls back to 400. It is passed
   *   to `chunked`, which rejects non-positive sizes.
   */
  suspend fun embedDocuments(
    texts: List<String>,
    requestConfig: RequestConfig,
    chunkSize: Int?
  ): List<Embedding> =
    if (texts.isEmpty()) emptyList()
    else
      texts
        .chunked(chunkSize ?: 400)
        // Each request must carry only its own chunk. Sending the full `texts` list per chunk
        // would embed every document once per chunk and break the 1:1 pairing with `texts`.
        .parMap { chunk ->
          createEmbeddings(EmbeddingRequest(name, chunk, requestConfig.user.id)).data
        }
        .flatten()

  /**
   * Embeds a single [text] by delegating to [embedDocuments] with the default chunk size.
   *
   * An empty [text] returns an empty list without issuing a request.
   */
  suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
    if (text.isNotEmpty()) embedDocuments(listOf(text), requestConfig, null) else emptyList()

  companion object
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package com.xebia.functional.xef.llm.models.embeddings

class Embedding(val `object`: String, val embedding: List<Float>, val index: Int)
class Embedding(val embedding: List<Float>)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.xebia.functional.xef.store

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.Embedding

/**
* A way of composing two [VectorStore] instances together, this class will **first search** [top],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package com.xebia.functional.xef.store
import arrow.atomic.Atomic
import arrow.atomic.getAndUpdate
import arrow.atomic.update
import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import kotlin.math.sqrt

Expand Down Expand Up @@ -54,8 +54,7 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
}

override suspend fun addTexts(texts: List<String>) {
val embeddingsList =
embeddings.embedDocuments(texts, chunkSize = null, requestConfig = requestConfig)
val embeddingsList = embeddings.embedDocuments(texts, requestConfig = requestConfig, null)
state.getAndUpdate { prevState ->
val newEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList)
State(prevState.orderedMemories, prevState.documents + texts, newEmbeddings)
Expand All @@ -80,9 +79,9 @@ private constructor(private val embeddings: Embeddings, private val state: Atomi
}

/**
 * Cosine similarity between this embedding and [other]: the dot product of both vectors divided by
 * the product of their magnitudes.
 *
 * `zip` only pairs components up to the length of the shorter vector, so extra trailing components
 * of a longer vector are left out of the dot product while still counting towards its magnitude.
 *
 * NOTE(review): when either vector is empty or all zeros the divisor is 0.0 and the result is NaN;
 * confirm callers never pass such a vector.
 */
private fun Embedding.cosineSimilarity(other: Embedding): Double {
val dotProduct = this.data.zip(other.data).sumOf { (a, b) -> (a * b).toDouble() }
val magnitudeA = sqrt(this.data.sumOf { (it * it).toDouble() })
val magnitudeB = sqrt(other.data.sumOf { (it * it).toDouble() })
val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b).toDouble() }
val magnitudeA = sqrt(this.embedding.sumOf { (it * it).toDouble() })
val magnitudeB = sqrt(other.embedding.sumOf { (it * it).toDouble() })
return dotProduct / (magnitudeA * magnitudeB)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.xebia.functional.xef.store

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import kotlin.jvm.JvmStatic

interface VectorStore {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
package com.xebia.functional.xef.data

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.llm.models.usage.Usage

/**
 * [Embeddings] stub for tests, named "test-embeddings": every embedding operation returns a fixed
 * empty result regardless of its arguments.
 */
class TestEmbeddings : Embeddings {

override val name: String
get() = "test-embeddings"

// Always empty, whatever texts are passed in.
override suspend fun embedDocuments(
texts: List<String>,
chunkSize: Int?,
requestConfig: RequestConfig
requestConfig: RequestConfig,
chunkSize: Int?
): List<Embedding> = emptyList()

// Always empty, whatever text is passed in.
override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
emptyList()

// No embeddings and zero usage.
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
EmbeddingResult(emptyList(), Usage.ZERO)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@ package com.xebia.functional.xef.conversation.streaming

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
import com.xebia.functional.xef.llm.Chat
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.store.LocalVectorStore

suspend fun main() {
val chat: Chat = OpenAI().DEFAULT_CHAT
val embeddings = OpenAIEmbeddings(OpenAI().DEFAULT_EMBEDDING)
val embeddings = OpenAI().DEFAULT_EMBEDDING
val scope = Conversation(LocalVectorStore(embeddings))
chat.promptStreaming(prompt = Prompt("What is the meaning of life?"), scope = scope).collect {
print(it)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package com.xebia.functional.xef.conversation.streaming

import com.xebia.functional.xef.conversation.Conversation
import com.xebia.functional.xef.conversation.llm.openai.OpenAI
import com.xebia.functional.xef.conversation.llm.openai.OpenAIEmbeddings
import com.xebia.functional.xef.llm.StreamedFunction
import com.xebia.functional.xef.prompt.Prompt
import com.xebia.functional.xef.store.LocalVectorStore
Expand All @@ -17,8 +16,7 @@ suspend fun main() {

val model = OpenAI(host = "http://localhost:8081/").DEFAULT_SERIALIZATION

val scope =
Conversation(LocalVectorStore(OpenAIEmbeddings(OpenAI.FromEnvironment.DEFAULT_EMBEDDING)))
val scope = Conversation(LocalVectorStore(OpenAI.FromEnvironment.DEFAULT_EMBEDDING))

model
.promptStreaming(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ import com.xebia.functional.xef.store.LocalVectorStore
import com.xebia.functional.xef.store.VectorStore

/**
 * Runs [block] with a new [Conversation] backed by [store] and returns the block's result.
 *
 * When no [store] is given, a [LocalVectorStore] over [HuggingFaceLocalEmbeddings.DEFAULT] is used.
 */
suspend inline fun <A> conversation(
store: VectorStore = LocalVectorStore(HuggingFaceLocalEmbeddings.DEFAULT),
noinline block: suspend Conversation.() -> A
store: VectorStore = LocalVectorStore(HuggingFaceLocalEmbeddings.DEFAULT),
noinline block: suspend Conversation.() -> A
): A = block(Conversation(store))
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
package com.xebia.functional.gpt4all

import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import com.xebia.functional.xef.embeddings.Embedding as XefEmbedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.llm.models.usage.Usage

class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.functional.xef.llm.Embeddings, Embeddings {
class HuggingFaceLocalEmbeddings(name: String, artifact: String) : Embeddings {

private val tokenizer = HuggingFaceTokenizer.newInstance("$name/$artifact")

Expand All @@ -18,20 +17,17 @@ class HuggingFaceLocalEmbeddings(name: String, artifact: String) : com.xebia.fun
/**
 * Builds the [EmbeddingResult] for [request] locally: `request.input` is batch-encoded with the
 * tokenizer and the ids of each encoding are converted to floats to form one [Embedding] per input.
 *
 * Usage is always reported as [Usage.ZERO].
 */
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
val embedings = tokenizer.batchEncode(request.input)
return EmbeddingResult(
data = embedings.mapIndexed { n, em -> Embedding("embedding", em.ids.map { it.toFloat() }, n) },
data = embedings.map { Embedding(it.ids.map { it.toFloat() }) },
usage = Usage.ZERO
)
}

/**
 * Embeds [texts] by batch-encoding them with the tokenizer and converting the ids of each encoding
 * to floats.
 *
 * The body does not go through `createEmbeddings`, and neither [requestConfig] nor [chunkSize] is
 * read: all texts are encoded in a single `batchEncode` call.
 */
override suspend fun embedDocuments(
texts: List<String>,
chunkSize: Int?,
requestConfig: RequestConfig
): List<XefEmbedding> =
tokenizer.batchEncode(texts).map { em -> XefEmbedding(em.ids.map { it.toFloat() }) }

override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<XefEmbedding> =
embedDocuments(listOf(text), null, requestConfig)
requestConfig: RequestConfig,
chunkSize: Int?
): List<Embedding> =
tokenizer.batchEncode(texts).map { em -> Embedding(em.ids.map { it.toFloat() }) } // TODO we need to remove the index

companion object {
@JvmField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ class GCP(projectId: String? = null, location: VertexAIRegion? = null, token: St

@JvmSynthetic
suspend fun <A> conversation(block: suspend Conversation.() -> A): A =
block(conversation(LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))))
block(conversation(LocalVectorStore(FromEnvironment.DEFAULT_EMBEDDING)))

@JvmStatic
@JvmOverloads
fun conversation(
store: VectorStore = LocalVectorStore(GcpEmbeddings(FromEnvironment.DEFAULT_EMBEDDING))
store: VectorStore = LocalVectorStore(FromEnvironment.DEFAULT_EMBEDDING)
): PlatformConversation = Conversation(store)
}
}

suspend inline fun <A> GCP.conversation(noinline block: suspend Conversation.() -> A): A =
block(Conversation(LocalVectorStore(GcpEmbeddings(DEFAULT_EMBEDDING))))
block(Conversation(LocalVectorStore(DEFAULT_EMBEDDING)))

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ class GcpModel(modelId: String, config: GcpConfig) : Chat, Completion, AutoClose
}

/**
 * Requests embeddings for [request] through `client.embeddings` and converts every returned
 * prediction into an [Embedding], narrowing its values from Double to Float.
 *
 * The usage of the result is whatever `usage(response)` computes for the response.
 */
override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult {
// Local adapter from one GCP prediction to the provider-independent Embedding model.
fun requestToEmbedding(index: Int, it: GcpClient.EmbeddingPredictions): Embedding =
Embedding("embedding", it.embeddings.values.map(Double::toFloat), index = index)
fun requestToEmbedding(it: GcpClient.EmbeddingPredictions): Embedding =
Embedding(it.embeddings.values.map(Double::toFloat))

val response = client.embeddings(request)
return EmbeddingResult(
data = response.predictions.mapIndexed(::requestToEmbedding),
data = response.predictions.map(::requestToEmbedding),
usage = usage(response),
)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.xebia.functional.xef.store

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import org.apache.lucene.analysis.standard.StandardAnalyzer
import org.apache.lucene.document.Document
Expand Down Expand Up @@ -89,7 +89,7 @@ open class Lucene(

/**
 * Runs a [KnnFloatVectorQuery] against the "embedding" field using the vector of [embedding] and
 * returns what `extract` produces from the top [limit] hits.
 *
 * Throws IllegalArgumentException (via `requireNotNull`) when `embeddings` is null.
 *
 * NOTE(review): the DirectoryReader opened from `writer` is not closed inside this function —
 * confirm it is released elsewhere, otherwise every search leaks a reader.
 */
override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
requireNotNull(embeddings) { "no embeddings were computed for this model" }
val luceneQuery = KnnFloatVectorQuery("embedding", embedding.data.toFloatArray(), limit)
val luceneQuery = KnnFloatVectorQuery("embedding", embedding.embedding.toFloatArray(), limit)
val searcher = IndexSearcher(DirectoryReader.open(writer))
return searcher.search(luceneQuery, limit).extract(searcher)
}
Expand Down Expand Up @@ -150,4 +150,4 @@ fun InMemoryLuceneBuilder(
InMemoryLucene(path, writerConfig, embeddings.takeIf { useAIEmbeddings }, similarity)
}

/** Concatenates the vectors of all embeddings, in list order, into a single [FloatArray]. */
fun List<Embedding>.toFloatArray(): FloatArray = flatMap { it.data }.toFloatArray()
fun List<Embedding>.toFloatArray(): FloatArray = flatMap { it.embedding }.toFloatArray()
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.xebia.functional.xef.store

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.store.postgresql.*
import kotlinx.uuid.UUID
Expand Down Expand Up @@ -95,14 +95,14 @@ class PGVectorStore(

override suspend fun addTexts(texts: List<String>): Unit =
dataSource.connection {
val embeddings = embeddings.embedDocuments(texts, chunkSize, requestConfig)
val embeddings = embeddings.embedDocuments(texts, requestConfig, chunkSize)
val collection = getCollection(collectionName)
texts.zip(embeddings) { text, embedding ->
val uuid = UUID.generateUUID()
update(addNewText) {
bind(uuid.toString())
bind(collection.uuid.toString())
bind(embedding.data.toString())
bind(embedding.embedding.toString())
bind(text)
}
}
Expand All @@ -121,7 +121,7 @@ class PGVectorStore(
searchSimilarDocument(distanceStrategy),
{
bind(collection.uuid.toString())
bind(embeddings[0].data.toString())
bind(embeddings[0].embedding.toString())
bind(limit)
}
) {
Expand All @@ -136,7 +136,7 @@ class PGVectorStore(
searchSimilarDocument(distanceStrategy),
{
bind(collection.uuid.toString())
bind(embedding.data.toString())
bind(embedding.embedding.toString())
bind(limit)
}
) {
Expand Down
19 changes: 14 additions & 5 deletions integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package xef

import com.xebia.functional.xef.embeddings.Embedding
import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.llm.models.embeddings.Embedding
import com.xebia.functional.xef.llm.models.chat.Message
import com.xebia.functional.xef.llm.models.chat.Role
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingRequest
import com.xebia.functional.xef.llm.models.embeddings.EmbeddingResult
import com.xebia.functional.xef.llm.models.embeddings.RequestConfig
import com.xebia.functional.xef.store.ConversationId
import com.xebia.functional.xef.store.Memory
Expand Down Expand Up @@ -123,10 +125,17 @@ private fun Embeddings.Companion.mock(
object : Embeddings {
override suspend fun embedDocuments(
texts: List<String>,
chunkSize: Int?,
requestConfig: RequestConfig
): List<Embedding> = embedDocuments(texts, chunkSize, requestConfig)
requestConfig: RequestConfig,
chunkSize: Int?
): List<Embedding> = embedDocuments(texts, requestConfig, chunkSize)

override suspend fun embedQuery(text: String, requestConfig: RequestConfig): List<Embedding> =
embedQuery(text, requestConfig)

override suspend fun createEmbeddings(request: EmbeddingRequest): EmbeddingResult =
createEmbeddings(request)


override val name: String
get() = "embeddings"
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.xebia.functional.xef.conversation

import com.xebia.functional.xef.embeddings.Embeddings
import com.xebia.functional.xef.llm.Embeddings
import com.xebia.functional.xef.store.LocalVectorStore
import com.xebia.functional.xef.store.VectorStore

Expand Down
Loading

0 comments on commit de40a7b

Please sign in to comment.