Skip to content

Commit

Permalink
add image embedding background worker
Browse files Browse the repository at this point in the history
  • Loading branch information
cbh123 committed Sep 27, 2023
1 parent aee1374 commit 9879e83
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 27 deletions.
30 changes: 25 additions & 5 deletions lib/emoji/embeddings.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,44 @@ defmodule Emoji.Embeddings do
@doc """
Creates an embedding and returns it in binary form.
"""
def create(
text,
embeddings_model \\ "daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304"
) do
def create(text, embeddings_model) do
  # Run the embeddings model on the text, then serialise the resulting
  # tensor to its raw binary representation.
  model_output = Replicate.run(embeddings_model, text_input: text, modality: "text")

  Nx.to_binary(Nx.tensor(model_output))
end

@doc """
Creates an image embedding given an image url and returns it in binary form.

The image is downloaded, wrapped in a base64 `data:` URI and passed to
`embeddings_model` with the `"vision"` modality.

Raises if the download does not produce an HTTP 200 response with a binary
body, so that an error page is never embedded as if it were an image.
"""
def create_image(image_url, embeddings_model) do
  # `Req.get!/1` only raises on transport failures; a 404/403/5xx still comes
  # back as a response, so the status has to be checked explicitly.
  image_binary =
    case Req.get!(image_url) do
      %{status: 200, body: body} when is_binary(body) ->
        body

      %{status: status} ->
        raise "unexpected response (HTTP #{status}) while downloading image #{image_url}"
    end

  image_uri = binary_to_data_uri(image_binary, "image/png")

  embeddings_model
  |> Replicate.run(input: image_uri, modality: "vision")
  |> Nx.tensor()
  |> Nx.to_binary()
end

@doc """
Strips the "A TOK emoji of a" prompt boilerplate from `text` and trims
surrounding whitespace.

The article is matched as a whole word ("a" or "an") and is optional, so the
first letter of the subject is never removed: "A TOK emoji of an apple"
becomes "apple" rather than "n apple".
"""
def clean(text) do
  text
  # `\b` makes sure " a"/" an" is only consumed when it is a standalone word.
  |> String.replace(~r/A TOK emoji of(?: an?)?\b/, "")
  |> String.trim()
end

# Wraps raw bytes in a base64 `data:` URI labelled with the given MIME type.
defp binary_to_data_uri(binary, mime_type) do
  "data:#{mime_type};base64,#{Base.encode64(binary)}"
end

def search_emojis(query, num_results \\ 9) do
embedding = create(query) |> Nx.from_binary(:f32)
embedding =
create(
query,
"daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304"
)
|> Nx.from_binary(:f32)

%{labels: labels, distances: distances} =
Emoji.Embeddings.Index.search(embedding, num_results)
Expand Down
8 changes: 8 additions & 0 deletions lib/emoji/embeddings/index.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@ defmodule Emoji.Embeddings.Index do
|> Enum.reduce(index, fn prediction, index ->
id = prediction.id
embedding = prediction.embedding
# image_embedding = prediction.image_embedding

HNSWLib.Index.add_items(index, Nx.from_binary(embedding, :f32), ids: Nx.tensor([id]))

# if image_embedding != nil do
# HNSWLib.Index.add_items(index, Nx.from_binary(image_embedding, :f32),
# ids: Nx.tensor([id])
# )
# end

index
end)

Expand Down
80 changes: 61 additions & 19 deletions lib/emoji/embeddings/worker.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,73 @@ defmodule Emoji.Embeddings.Worker do
Process.send_after(self(), :work, 5000)
end

defp should_generate?() do
defp should_generate_text_embedding?() do
case Application.get_env(:emoji, :env) do
:prod -> Predictions.count_predictions_with_embeddings() < 100_000
:dev -> Predictions.count_predictions_with_embeddings() < 50
:prod -> Predictions.count_predictions_with_text_embeddings() < 100_000
:dev -> Predictions.count_predictions_with_text_embeddings() < 50
_ -> false
end
end

# Decides whether another image embedding should be generated: a per-environment
# cap is looked up and compared with the number of predictions that already
# have an image embedding. Environments other than :prod and :dev never generate.
defp should_generate_image_embedding?() do
  cap =
    case Application.get_env(:emoji, :env) do
      :prod -> 10_000
      :dev -> 50
      _ -> nil
    end

  # `and` short-circuits, so the count query only runs when a cap applies.
  is_integer(cap) and Predictions.count_predictions_with_image_embeddings() < cap
end

# Computes a text embedding from the prediction's cleaned prompt, stores it on
# the prediction together with the model identifier, and returns the updated
# prediction. Crashes (MatchError) if the update does not return `{:ok, _}`.
defp create_text_embedding(prediction) do
  Logger.info("Creating text embeddings for #{prediction.id}")

  cleaned_prompt = Embeddings.clean(prediction.prompt)

  attrs = %{
    embedding: Embeddings.create(cleaned_prompt, @embeddings_model),
    embedding_model: @embeddings_model
  }

  {:ok, updated} = Predictions.update_prediction(prediction, attrs)
  updated
end

# Computes an image embedding from the prediction's `no_bg_output` URL, stores
# it on the prediction, and returns the updated prediction. Crashes (MatchError)
# if the update does not return `{:ok, _}`.
defp create_image_embedding(prediction) do
  Logger.info("Creating image embeddings for #{prediction.id}")

  attrs = %{
    image_embedding: Embeddings.create_image(prediction.no_bg_output, @embeddings_model)
  }

  {:ok, updated} = Predictions.update_prediction(prediction, attrs)
  updated
end

def handle_info(:work, state) do
if should_generate?() do
prediction = Predictions.get_random_prediction_without_embeddings()
Logger.info("Creating embeddings for #{prediction.id}")

embedding =
prediction.prompt
|> Embeddings.clean()
|> Embeddings.create(@embeddings_model)

{:ok, _prediction} =
Predictions.update_prediction(prediction, %{
embedding: embedding,
embedding_model: @embeddings_model
})

schedule_creation()
cond do
should_generate_image_embedding?() and should_generate_text_embedding?() ->
Predictions.get_random_prediction_without_text_embeddings() |> create_text_embedding()
Predictions.get_random_prediction_without_image_embeddings() |> create_image_embedding()

schedule_creation()

should_generate_image_embedding?() ->
Predictions.get_random_prediction_without_image_embeddings() |> create_image_embedding()

schedule_creation()

should_generate_text_embedding?() ->
Predictions.get_random_prediction_without_text_embeddings() |> create_text_embedding()
schedule_creation()

true ->
nil
end

{:noreply, state}
Expand Down
20 changes: 18 additions & 2 deletions lib/emoji/predictions.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@ defmodule Emoji.Predictions do
from(p in Prediction, where: p.id in ^ids and not is_nil(p.emoji_output)) |> Repo.all()
end

def count_predictions_with_embeddings() do
def count_predictions_with_text_embeddings() do
Repo.aggregate(
from(p in Prediction, where: not p.embedding |> is_nil()),
:count
)
end

def get_random_prediction_without_embeddings() do
def count_predictions_with_image_embeddings() do
  # Count only the rows whose image embedding has already been stored.
  with_image_embedding = from(p in Prediction, where: not is_nil(p.image_embedding))

  Repo.aggregate(with_image_embedding, :count)
end

def get_random_prediction_without_text_embeddings() do
from(p in Prediction,
where: is_nil(p.embedding) and not is_nil(p.emoji_output) and p.score != 10,
order_by: fragment("RANDOM()"),
Expand All @@ -28,6 +35,15 @@ defmodule Emoji.Predictions do
|> Repo.one!()
end

@doc """
Returns one random prediction that still needs an image embedding.

A prediction is eligible when it has no `image_embedding`, has an
`emoji_output`, has a `no_bg_output` and its `score` is not 10. The
`no_bg_output` requirement matters because that is the value the embeddings
worker passes on as the image URL; a row without it cannot be embedded.

Raises `Ecto.NoResultsError` when no prediction is eligible.
"""
def get_random_prediction_without_image_embeddings() do
  from(p in Prediction,
    where:
      is_nil(p.image_embedding) and not is_nil(p.emoji_output) and
        not is_nil(p.no_bg_output) and p.score != 10,
    order_by: fragment("RANDOM()"),
    limit: 1
  )
  |> Repo.one!()
end

@doc """
Returns the list of predictions.
Expand Down
2 changes: 2 additions & 0 deletions lib/emoji/predictions/prediction.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ defmodule Emoji.Predictions.Prediction do
field :moderation_score, :integer
field :moderator, :string
field :embedding, :binary
field :image_embedding, :binary
field :embedding_model, :string

timestamps()
Expand All @@ -34,6 +35,7 @@ defmodule Emoji.Predictions.Prediction do
:moderation_score,
:moderator,
:embedding,
:image_embedding,
:embedding_model
])
|> validate_required([:prompt])
Expand Down
4 changes: 3 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ defmodule Emoji.MixProject do
{:ex_aws_s3, "~> 2.0"},
{:nx, "~> 0.4"},
{:hnswlib, "~> 0.1.2"},
{:credo, "~> 1.7", only: [:dev, :test], runtime: false}
{:credo, "~> 1.7", only: [:dev, :test], runtime: false},
{:image, "~> 0.37"},
{:req, "~> 0.4.3"}
]
end

Expand Down
4 changes: 4 additions & 0 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"hpax": {:hex, :hpax, "0.1.2", "09a75600d9d8bbd064cdd741f21fc06fc1f4cf3d0fcc335e5aa19be1a7235c84", [:mix], [], "hexpm", "2c87843d5a23f5f16748ebe77969880e29809580efdaccd615cd3bed628a8c13"},
"httpoison": {:hex, :httpoison, "2.1.0", "655fd9a7b0b95ee3e9a3b535cf7ac8e08ef5229bab187fa86ac4208b122d934b", [:mix], [{:hackney, "~> 1.17", [hex: :hackney, repo: "hexpm", optional: false]}], "hexpm", "fc455cb4306b43827def4f57299b2d5ac8ac331cb23f517e734a4b78210a160c"},
"idna": {:hex, :idna, "6.1.1", "8a63070e9f7d0c62eb9d9fcb360a7de382448200fbbd1b106cc96d3d8099df8d", [:rebar3], [{:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "92376eb7894412ed19ac475e4a86f7b413c1b9fbb5bd16dccd57934157944cea"},
"image": {:hex, :image, "0.38.2", "d444e0e9434558c6e7f16c9fc4f0eb6924547ed0f16ad4953af17543e33c89f6", [:mix], [{:bumblebee, "~> 0.3", [hex: :bumblebee, repo: "hexpm", optional: true]}, {:evision, "~> 0.1.33", [hex: :evision, repo: "hexpm", optional: true]}, {:exla, "~> 0.5", [hex: :exla, repo: "hexpm", optional: true]}, {:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: true]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: true]}, {:phoenix_html, "~> 2.14 or ~> 3.2", [hex: :phoenix_html, repo: "hexpm", optional: false]}, {:plug, "~> 1.13", [hex: :plug, repo: "hexpm", optional: true]}, {:rustler, "> 0.0.0", [hex: :rustler, repo: "hexpm", optional: true]}, {:sweet_xml, "~> 0.7", [hex: :sweet_xml, repo: "hexpm", optional: false]}, {:vix, "~> 0.17", [hex: :vix, repo: "hexpm", optional: false]}], "hexpm", "9f5bb2560ff43a5b0a33a333be444b958b89ebc62a54683ec1612b75ced6e307"},
"jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"},
"metrics": {:hex, :metrics, "1.0.1", "25f094dea2cda98213cecc3aeff09e940299d950904393b2a29d191c346a8486", [:rebar3], [], "hexpm", "69b09adddc4f74a40716ae54d140f93beb0fb8978d8636eaded0c31b6f099f16"},
"mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"},
Expand Down Expand Up @@ -58,9 +59,11 @@
"progress_bar": {:hex, :progress_bar, "3.0.0", "f54ff038c2ac540cfbb4c2bfe97c75e7116ead044f3c2b10c9f212452194b5cd", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "6981c2b25ab24aecc91a2dc46623658e1399c21a2ae24db986b90d678530f2b7"},
"ranch": {:hex, :ranch, "1.8.0", "8c7a100a139fd57f17327b6413e4167ac559fbc04ca7448e9be9057311597a1d", [:make, :rebar3], [], "hexpm", "49fbcfd3682fab1f5d109351b61257676da1a2fdbe295904176d5e521a2ddfe5"},
"replicate": {:hex, :replicate, "1.1.1", "77f017b7bc0af2df6b3abfcf06a1136a3279823e881084a8b57e3d90c28eacff", [:mix], [{:httpoison, "~> 2.0", [hex: :httpoison, repo: "hexpm", optional: false]}, {:jason, "~> 1.2", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "73ce49fc3185ec5988f25265ae8562f64c1e8bd0ac040eed1c62f9527759b1bb"},
"req": {:hex, :req, "0.4.3", "bb4cd1661a234b9c779b984dd137761f7ff705f45d0008ba40c8f420a4307b43", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.9", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "9bc88c84f101cfe884260d3413b72aaad6d94ccedccc1f2bcef8e94bd68c5536"},
"rustler_precompiled": {:hex, :rustler_precompiled, "0.6.3", "f838d94bc35e1844973ee7266127b156fdc962e9e8b7ff666c8fb4fed7964d23", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "e18ecca3669a7454b3a2be75ae6c3ef01d550bc9a8cf5fbddcfff843b881d7c6"},
"safetensors": {:hex, :safetensors, "0.1.1", "b5859a010fb56249ecfba4799d316e96b89152576af2db7657786c55dcf2f5b6", [:mix], [{:jason, "~> 1.4", [hex: :jason, repo: "hexpm", optional: false]}, {:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "dfbb525bf3debb2e2d90f840728af70da5d55f6caa091cac4d0891a4eb4c52d5"},
"ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"},
"sweet_xml": {:hex, :sweet_xml, "0.7.4", "a8b7e1ce7ecd775c7e8a65d501bc2cd933bff3a9c41ab763f5105688ef485d08", [:mix], [], "hexpm", "e7c4b0bdbf460c928234951def54fe87edf1a170f6896675443279e2dbeba167"},
"swoosh": {:hex, :swoosh, "1.11.5", "429dccde78e2f60c6339e96917efecebca9d1f254d2878a150f580d2f782260b", [:mix], [{:cowboy, "~> 1.1 or ~> 2.4", [hex: :cowboy, repo: "hexpm", optional: true]}, {:ex_aws, "~> 2.1", [hex: :ex_aws, repo: "hexpm", optional: true]}, {:finch, "~> 0.6", [hex: :finch, repo: "hexpm", optional: true]}, {:gen_smtp, "~> 0.13 or ~> 1.0", [hex: :gen_smtp, repo: "hexpm", optional: true]}, {:hackney, "~> 1.9", [hex: :hackney, repo: "hexpm", optional: true]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mail, "~> 0.2", [hex: :mail, repo: "hexpm", optional: true]}, {:mime, "~> 1.1 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:plug_cowboy, ">= 1.0.0", [hex: :plug_cowboy, repo: "hexpm", optional: true]}, {:telemetry, "~> 0.4.2 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "21ee57dcd68d2f56d3bbe11e76d56d142b221bb12b6018c551cc68442b800040"},
"tailwind": {:hex, :tailwind, "0.1.10", "21ed80ae1f411f747ee513470578acaaa1d0eb40170005350c5b0b6d07e2d624", [:mix], [{:castore, ">= 0.0.0", [hex: :castore, repo: "hexpm", optional: false]}], "hexpm", "e0fc474dfa8ed7a4573851ac69c5fd3ca70fbb0a5bada574d1d657ebc6f2f1f1"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
Expand All @@ -70,6 +73,7 @@
"unicode_util_compat": {:hex, :unicode_util_compat, "0.7.0", "bc84380c9ab48177092f43ac89e4dfa2c6d62b40b8bd132b1059ecc7232f9a78", [:rebar3], [], "hexpm", "25eee6d67df61960cf6a794239566599b09e17e668d3700247bc498638152521"},
"unpickler": {:hex, :unpickler, "0.1.0", "c2262c0819e6985b761e7107546cef96a485f401816be5304a65fdd200d5bd6a", [:mix], [], "hexpm", "e2b3f61e62406187ac52afead8a63bfb4e49394028993f3c4c42712743cab79e"},
"unzip": {:hex, :unzip, "0.8.0", "ee21d87c21b01567317387dab4228ac570ca15b41cfc221a067354cbf8e68c4d", [:mix], [], "hexpm", "ffa67a483efcedcb5876971a50947222e104d5f8fea2c4a0441e6f7967854827"},
"vix": {:hex, :vix, "0.22.0", "17efba59fa1a5d9cab36dbf066aa5d0a40d6b53f7d66380877392212aa7a39c6", [:make, :mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:cc_precompiler, "~> 0.1.4 or ~> 0.2", [hex: :cc_precompiler, repo: "hexpm", optional: false]}, {:elixir_make, "~> 0.7.3 or ~> 0.8", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}], "hexpm", "04fb539e881539f00eefde22997c3aa05d79aba695f440d2556ee2ad35237bdd"},
"websock": {:hex, :websock, "0.5.3", "2f69a6ebe810328555b6fe5c831a851f485e303a7c8ce6c5f675abeb20ebdadc", [:mix], [], "hexpm", "6105453d7fac22c712ad66fab1d45abdf049868f253cf719b625151460b8b453"},
"websock_adapter": {:hex, :websock_adapter, "0.5.4", "7af8408e7ed9d56578539594d1ee7d8461e2dd5c3f57b0f2a5352d610ddde757", [:mix], [{:bandit, ">= 0.6.0", [hex: :bandit, repo: "hexpm", optional: true]}, {:plug, "~> 1.14", [hex: :plug, repo: "hexpm", optional: false]}, {:plug_cowboy, "~> 2.6", [hex: :plug_cowboy, repo: "hexpm", optional: true]}, {:websock, "~> 0.5", [hex: :websock, repo: "hexpm", optional: false]}], "hexpm", "d2c238c79c52cbe223fcdae22ca0bb5007a735b9e933870e241fce66afb4f4ab"},
"xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"},
Expand Down
9 changes: 9 additions & 0 deletions priv/repo/migrations/20230927180316_add_image_embedding.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defmodule Emoji.Repo.Migrations.AddImageEmbedding do
  # Adds a binary `image_embedding` column to the `predictions` table.
  use Ecto.Migration

  def change do
    alter(table(:predictions), do: add(:image_embedding, :binary))
  end
end

0 comments on commit 9879e83

Please sign in to comment.