From 4c5e606f18af578e9a3574a53d3beb3a42715e9a Mon Sep 17 00:00:00 2001 From: Charlie Holtz Date: Wed, 27 Sep 2023 16:12:10 -0400 Subject: [PATCH] add search via image description, and error handling --- config/dev.exs | 1 + lib/emoji/embeddings.ex | 10 +++- lib/emoji/embeddings/index.ex | 52 ++++++++++----------- lib/emoji/embeddings/worker.ex | 5 ++ lib/emoji/predictions.ex | 8 +++- lib/emoji_web/components/core_components.ex | 11 +++++ lib/emoji_web/live/search_live.ex | 31 +++++++++--- lib/emoji_web/live/search_live.html.heex | 21 +++++++-- 8 files changed, 99 insertions(+), 40 deletions(-) diff --git a/config/dev.exs b/config/dev.exs index 45bb862..b40d3e8 100644 --- a/config/dev.exs +++ b/config/dev.exs @@ -1,6 +1,7 @@ import Config config :emoji, env: :dev +config :phoenix_live_view, :debug_heex_annotations, true # Configure your database config :emoji, Emoji.Repo, diff --git a/lib/emoji/embeddings.ex b/lib/emoji/embeddings.ex index 4141d1c..ae70f00 100644 --- a/lib/emoji/embeddings.ex +++ b/lib/emoji/embeddings.ex @@ -37,7 +37,7 @@ defmodule Emoji.Embeddings do "data:#{mime_type};base64,#{base64}" end - def search_emojis(query, num_results \\ 9) do + def search_emojis(query, num_results \\ 9, via_images \\ false) do embedding = create( query, @@ -45,8 +45,14 @@ defmodule Emoji.Embeddings do ) |> Nx.from_binary(:f32) + IO.inspect(via_images, label: "via images") + %{labels: labels, distances: distances} = - Emoji.Embeddings.Index.search(embedding, num_results) + if via_images do + Emoji.Embeddings.Index.search_images(embedding, num_results) + else + Emoji.Embeddings.Index.search(embedding, num_results) + end ids = Nx.to_flat_list(labels) distances = Nx.to_flat_list(distances) diff --git a/lib/emoji/embeddings/index.ex b/lib/emoji/embeddings/index.ex index 5c41f34..ec8884f 100644 --- a/lib/emoji/embeddings/index.ex +++ b/lib/emoji/embeddings/index.ex @@ -11,45 +11,45 @@ defmodule Emoji.Embeddings.Index do end def init(_args) do - {:ok, index} = HNSWLib.Index.new(:l2, 1024, 100_000) - - Emoji.Predictions.list_predictions_with_embeddings() - |> Enum.reduce(index, fn prediction, index -> - id = prediction.id - embedding = prediction.embedding - # image_embedding = prediction.image_embedding - - HNSWLib.Index.add_items(index, Nx.from_binary(embedding, :f32), ids: Nx.tensor([id])) - - # if image_embedding != nil do - # HNSWLib.Index.add_items(index, Nx.from_binary(image_embedding, :f32), - # ids: Nx.tensor([id]) - # ) - # end + {:ok, full_index} = HNSWLib.Index.new(:l2, 1024, 100_000) + {:ok, image_index} = HNSWLib.Index.new(:l2, 1024, 10_000) + + Emoji.Predictions.list_predictions_with_text_embeddings() + |> Enum.each(fn prediction -> + HNSWLib.Index.add_items(full_index, Nx.from_binary(prediction.embedding, :f32), + ids: Nx.tensor([prediction.id]) + ) + end) - index + Emoji.Predictions.list_predictions_with_image_embeddings() + |> Enum.each(fn prediction -> + HNSWLib.Index.add_items(image_index, Nx.from_binary(prediction.image_embedding, :f32), + ids: Nx.tensor([prediction.id]) + ) end) Logger.info("Index successfully created") - {:ok, index} - end - - def add(id, embedding) do - GenServer.cast(@me, {:add, id, embedding}) + {:ok, %{full_index: full_index, image_index: image_index}} end def search(embedding, k) do + Logger.info("Searching text") GenServer.call(@me, {:search, embedding, k}, 15_000) end - def handle_cast({:add, id, embedding}, index) do - HNSWLib.Index.add_items(index, Nx.from_binary(embedding, :f32), ids: Nx.tensor([id])) - {:noreply, index} + def search_images(embedding, k) do + Logger.info("Searching images") + GenServer.call(@me, {:search_images, embedding, k}, 15_000) + end + + def handle_call({:search, embedding, k}, _from, %{full_index: index} = index_dict) do + {:ok, labels, dists} = HNSWLib.Index.knn_query(index, embedding, k: k) + {:reply, %{labels: labels, distances: dists}, index_dict} end - def handle_call({:search, embedding, k}, _from, index) do + def handle_call({:search_images, embedding, k}, _from, %{image_index: index} = index_dict) do {:ok, labels, dists} = HNSWLib.Index.knn_query(index, embedding, k: k) - {:reply, %{labels: labels, distances: dists}, index} + {:reply, %{labels: labels, distances: dists}, index_dict} end def terminate(reason, _state) do diff --git a/lib/emoji/embeddings/worker.ex b/lib/emoji/embeddings/worker.ex index 98057df..c9ec85e 100644 --- a/lib/emoji/embeddings/worker.ex +++ b/lib/emoji/embeddings/worker.ex @@ -55,6 +55,11 @@ defmodule Emoji.Embeddings.Worker do prediction end + defp create_image_embedding(%{id: id, no_bg_output: nil}) do + Logger.info("No url, skipping #{id}") + nil + end + defp create_image_embedding(prediction) do Logger.info("Creating image embeddings for #{prediction.id}") diff --git a/lib/emoji/predictions.ex b/lib/emoji/predictions.ex index 4d24bf9..2388da4 100644 --- a/lib/emoji/predictions.ex +++ b/lib/emoji/predictions.ex @@ -57,10 +57,16 @@ defmodule Emoji.Predictions do Repo.all(Prediction) end - def list_predictions_with_embeddings do + def list_predictions_with_text_embeddings do Repo.all(from p in Prediction, where: not is_nil(p.embedding) and not is_nil(p.emoji_output)) end + def list_predictions_with_image_embeddings do + Repo.all( + from p in Prediction, where: not is_nil(p.image_embedding) and not is_nil(p.emoji_output) + ) + end + def list_firehose_predictions() do Repo.all( from p in Prediction, diff --git a/lib/emoji_web/components/core_components.ex b/lib/emoji_web/components/core_components.ex index 2fd4317..d8a6945 100644 --- a/lib/emoji_web/components/core_components.ex +++ b/lib/emoji_web/components/core_components.ex @@ -13,6 +13,17 @@ defmodule EmojiWeb.CoreComponents do alias Phoenix.LiveView.JS import EmojiWeb.Gettext + @doc """ + Badge + """ + attr :text, :string, required: true + + def badge(assigns) do + ~H""" + <%= @text %> + """ + end + @doc """ Renders a modal. diff --git a/lib/emoji_web/live/search_live.ex b/lib/emoji_web/live/search_live.ex index ce535a8..3df7213 100644 --- a/lib/emoji_web/live/search_live.ex +++ b/lib/emoji_web/live/search_live.ex @@ -3,19 +3,38 @@ defmodule EmojiWeb.SearchLive do @impl true def mount(_params, _session, socket) do - {:ok, socket |> assign(results: [], query: nil, loading: false)} + {:ok, + socket + |> assign( + results: [], + query: nil, + loading: false, + form: to_form(%{"query" => nil, "search_via_images" => false}) + )} end @impl true - def handle_event("search", %{"query" => query}, socket) do - {:noreply, push_patch(socket, to: ~p"/experimental-search?q=#{query}")} + def handle_event( + "search", + %{"query" => query, "search_via_images" => search_via_images}, + socket + ) do + {:noreply, + push_patch(socket, + to: ~p"/experimental-search?q=#{query}&search_via_images=#{search_via_images}" + )} end @impl true - def handle_params(%{"q" => query}, _uri, socket) do - Task.async(fn -> Emoji.Embeddings.search_emojis(query, 21) end) + def handle_params(%{"q" => query, "search_via_images" => search_via_images}, _uri, socket) do + Task.async(fn -> Emoji.Embeddings.search_emojis(query, 3, search_via_images == "true") end) - {:noreply, socket |> assign(loading: true) |> assign(query: query)} + {:noreply, + socket + |> assign( + loading: true, + form: to_form(%{"query" => query, "search_via_images" => search_via_images}) + )} end def handle_params(_params, _uri, socket) do diff --git a/lib/emoji_web/live/search_live.html.heex b/lib/emoji_web/live/search_live.html.heex index b4c550f..63a0a50 100644 --- a/lib/emoji_web/live/search_live.html.heex +++ b/lib/emoji_web/live/search_live.html.heex @@ -1,24 +1,35 @@
<.back navigate={~p"/"}>Back - +
+ <.input + field={@form[:search_via_images]} + type="checkbox" + label="Search via image description (experimental)" + id="exact" + class="ml-2" + /> +
Searching...