Skip to content

Commit

Permalink
add search via image description, and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
cbh123 committed Sep 27, 2023
1 parent 5303c71 commit 4c5e606
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 40 deletions.
1 change: 1 addition & 0 deletions config/dev.exs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import Config

config :emoji, env: :dev
config :phoenix_live_view, :debug_heex_annotations, true

# Configure your database
config :emoji, Emoji.Repo,
Expand Down
10 changes: 8 additions & 2 deletions lib/emoji/embeddings.ex
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,22 @@ defmodule Emoji.Embeddings do
"data:#{mime_type};base64,#{base64}"
end

def search_emojis(query, num_results \\ 9) do
def search_emojis(query, num_results \\ 9, via_images \\ false) do
embedding =
create(
query,
"daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304"
)
|> Nx.from_binary(:f32)

IO.inspect(via_images, label: "via images")

%{labels: labels, distances: distances} =
Emoji.Embeddings.Index.search(embedding, num_results)
if via_images do
Emoji.Embeddings.Index.search_images(embedding, num_results)
else
Emoji.Embeddings.Index.search(embedding, num_results)
end

ids = Nx.to_flat_list(labels)
distances = Nx.to_flat_list(distances)
Expand Down
52 changes: 26 additions & 26 deletions lib/emoji/embeddings/index.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,45 @@ defmodule Emoji.Embeddings.Index do
end

def init(_args) do
{:ok, index} = HNSWLib.Index.new(:l2, 1024, 100_000)

Emoji.Predictions.list_predictions_with_embeddings()
|> Enum.reduce(index, fn prediction, index ->
id = prediction.id
embedding = prediction.embedding
# image_embedding = prediction.image_embedding

HNSWLib.Index.add_items(index, Nx.from_binary(embedding, :f32), ids: Nx.tensor([id]))

# if image_embedding != nil do
# HNSWLib.Index.add_items(index, Nx.from_binary(image_embedding, :f32),
# ids: Nx.tensor([id])
# )
# end
{:ok, full_index} = HNSWLib.Index.new(:l2, 1024, 100_000)
{:ok, image_index} = HNSWLib.Index.new(:l2, 1024, 10_000)

Emoji.Predictions.list_predictions_with_text_embeddings()
|> Enum.each(fn prediction ->
HNSWLib.Index.add_items(full_index, Nx.from_binary(prediction.embedding, :f32),
ids: Nx.tensor([prediction.id])
)
end)

index
Emoji.Predictions.list_predictions_with_image_embeddings()
|> Enum.each(fn prediction ->
HNSWLib.Index.add_items(image_index, Nx.from_binary(prediction.image_embedding, :f32),
ids: Nx.tensor([prediction.id])
)
end)

Logger.info("Index successfully created")
{:ok, index}
end

def add(id, embedding) do
GenServer.cast(@me, {:add, id, embedding})
{:ok, %{full_index: full_index, image_index: image_index}}
end

def search(embedding, k) do
Logger.info("Searching text")
GenServer.call(@me, {:search, embedding, k}, 15_000)
end

def handle_cast({:add, id, embedding}, index) do
HNSWLib.Index.add_items(index, Nx.from_binary(embedding, :f32), ids: Nx.tensor([id]))
{:noreply, index}
def search_images(embedding, k) do
Logger.info("Searching images")
GenServer.call(@me, {:search_images, embedding, k}, 15_000)
end

def handle_call({:search, embedding, k}, _from, %{full_index: index} = index_dict) do
{:ok, labels, dists} = HNSWLib.Index.knn_query(index, embedding, k: k)
{:reply, %{labels: labels, distances: dists}, index_dict}
end

def handle_call({:search, embedding, k}, _from, index) do
def handle_call({:search_images, embedding, k}, _from, %{image_index: index} = index_dict) do
{:ok, labels, dists} = HNSWLib.Index.knn_query(index, embedding, k: k)
{:reply, %{labels: labels, distances: dists}, index}
{:reply, %{labels: labels, distances: dists}, index_dict}
end

def terminate(reason, _state) do
Expand Down
5 changes: 5 additions & 0 deletions lib/emoji/embeddings/worker.ex
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ defmodule Emoji.Embeddings.Worker do
prediction
end

defp create_image_embedding(%{id: id, no_bg_output: nil}) do
Logger.info("No url, skipping #{id}")
nil
end

defp create_image_embedding(prediction) do
Logger.info("Creating image embeddings for #{prediction.id}")

Expand Down
8 changes: 7 additions & 1 deletion lib/emoji/predictions.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,16 @@ defmodule Emoji.Predictions do
Repo.all(Prediction)
end

def list_predictions_with_embeddings do
def list_predictions_with_text_embeddings do
Repo.all(from p in Prediction, where: not is_nil(p.embedding) and not is_nil(p.emoji_output))
end

def list_predictions_with_image_embeddings do
Repo.all(
from p in Prediction, where: not is_nil(p.image_embedding) and not is_nil(p.emoji_output)
)
end

def list_firehose_predictions() do
Repo.all(
from p in Prediction,
Expand Down
11 changes: 11 additions & 0 deletions lib/emoji_web/components/core_components.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ defmodule EmojiWeb.CoreComponents do
alias Phoenix.LiveView.JS
import EmojiWeb.Gettext

@doc """
Badge
"""
attr :text, :string, required: true

def badge(assigns) do
~H"""
<span class="inline-flex items-center rounded-md bg-gray-50 px-2 py-1 text-xs font-medium text-gray-600 ring-1 ring-inset ring-gray-500/10"><%= @text %></span>
"""
end

@doc """
Renders a modal.
Expand Down
31 changes: 25 additions & 6 deletions lib/emoji_web/live/search_live.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,38 @@ defmodule EmojiWeb.SearchLive do

@impl true
def mount(_params, _session, socket) do
{:ok, socket |> assign(results: [], query: nil, loading: false)}
{:ok,
socket
|> assign(
results: [],
query: nil,
loading: false,
form: to_form(%{"query" => nil, "search_via_images" => false})
)}
end

@impl true
def handle_event("search", %{"query" => query}, socket) do
{:noreply, push_patch(socket, to: ~p"/experimental-search?q=#{query}")}
def handle_event(
"search",
%{"query" => query, "search_via_images" => search_via_images},
socket
) do
{:noreply,
push_patch(socket,
to: ~p"/experimental-search?q=#{query}&search_via_images=#{search_via_images}"
)}
end

@impl true
def handle_params(%{"q" => query}, _uri, socket) do
Task.async(fn -> Emoji.Embeddings.search_emojis(query, 21) end)
def handle_params(%{"q" => query, "search_via_images" => search_via_images}, _uri, socket) do
Task.async(fn -> Emoji.Embeddings.search_emojis(query, 3, search_via_images == "true") end)

{:noreply, socket |> assign(loading: true) |> assign(query: query)}
{:noreply,
socket
|> assign(
loading: true,
form: to_form(%{"query" => query, "search_via_images" => search_via_images})
)}
end

def handle_params(_params, _uri, socket) do
Expand Down
21 changes: 16 additions & 5 deletions lib/emoji_web/live/search_live.html.heex
Original file line number Diff line number Diff line change
@@ -1,24 +1,35 @@
<div>
<.back navigate={~p"/"}>Back</.back>
<form class="mt-4" name="emoji-search" id="emoji-search" phx-submit="search">
<.form for={@form} class="mt-4" name="emoji-search" id="emoji-search" phx-submit="search">
<label for="search" class="block text-sm font-medium text-gray-700">Emoji Search</label>
<div class="relative mt-1 flex items-center">
<input
<.input
type="text"
name="query"
value={@query}
field={@form[:query]}
id="query"
required="true"
class="block w-full rounded-md border-gray-300 pr-12 shadow-sm focus:border-indigo-500 focus:ring-indigo-500 sm:text-sm"
/>
</div>
</form>
<br />
<.input
field={@form[:search_via_images]}
type="checkbox"
label="Search via image description (experimental)"
id="exact"
class="ml-2"
/>
</.form>

<div :if={@loading} class="mt-2 animate-pulse">Searching...</div>

<ul :if={not @loading} role="list" class="mt-4 gap-6 grid grid-cols-3 divide-y divide-gray-200">
<li :for={{prediction, _distance} <- @results}>
<li :for={{prediction, distance} <- @results}>
<EmojiWeb.Components.emoji id={prediction.id} prediction={prediction} />
<div class="mt-2">
<.badge text={"Search distance: #{distance |> round()}"} />
</div>
</li>
</ul>
</div>

0 comments on commit 4c5e606

Please sign in to comment.