Skip to content

Commit

Permalink
switch between embeddings collections in the demo UI
Browse files Browse the repository at this point in the history
  • Loading branch information
metazool committed Oct 2, 2024
1 parent 03ff2bf commit 2867615
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
42 changes: 30 additions & 12 deletions src/cyto_ml/visualisation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

import os
import random
from io import BytesIO
from typing import Optional
Expand All @@ -21,42 +20,47 @@
from dotenv import load_dotenv
from PIL import Image

from cyto_ml.data.vectorstore import embeddings, vector_store
from cyto_ml.data.vectorstore import client, embeddings, vector_store

load_dotenv()


def collections() -> list:
return [c.name for c in client.list_collections()]


@st.cache_resource
def store() -> None:
def store(coll: str) -> None:
"""
Load the vector store with image embeddings.
TODO switch between different collections, not set in .env
Set as "EMBEDDINGS" in .env or defaults to "plankton"
"""
return vector_store(os.environ.get("EMBEDDINGS", "plankton"))
return vector_store(coll)


@st.cache_data
def image_ids() -> list:
def image_ids(coll: str) -> list:
"""
Retrieve image embeddings from chroma database.
TODO Revisit our available metadata
"""
result = store().get()
result = store(coll).get()
return result["ids"]


@st.cache_data
def image_embeddings() -> list:
return embeddings(store())
return embeddings(store(st.session_state["collection"]))


def closest_n(url: str, n: Optional[int] = 26) -> list:
"""
Given an image URL return the N closest ones by cosine distance
"""
embed = store().get([url], include=["embeddings"])["embeddings"]
results = store().query(query_embeddings=embed, n_results=n)
s = store(st.session_state["collection"])
embed = s.get([url], include=["embeddings"])["embeddings"]
results = s.query(query_embeddings=embed, n_results=n)
return results["ids"][0] # by index because API assumes query always multiple


Expand All @@ -70,7 +74,11 @@ def cached_image(url: str) -> Image:
response = requests.get(url)
image = Image.open(BytesIO(response.content))
if image.mode == "I;16":
image.point(lambda p: p * 0.0039063096, mode="RGB")
# 16 bit greyscale - divide by 255, convert RGB for display
(_, max_val) = image.getextrema()
image.point(lambda p: p * 1 / max_val)
# image.point(lambda p: p * (1/255))#.convert('RGB')
# image.mode = 'I'#, mode="RGB")
image = image.convert("RGB")
return image

Expand Down Expand Up @@ -124,7 +132,7 @@ def create_figure(df: pd.DataFrame) -> go.Figure:


def random_image() -> str:
ids = image_ids()
ids = image_ids(st.session_state["collection"])
# starting image
test_image_url = random.choice(ids)
return test_image_url
Expand All @@ -148,12 +156,22 @@ def main() -> None:
if "random_img" not in st.session_state:
st.session_state["random_img"] = None

colls = collections()
if "collection" not in st.session_state:
st.session_state["collection"] = colls[0]

st.set_page_config(layout="wide", page_title="Plankton image embeddings")

st.title("Image embeddings")
st.write(f"{len(image_ids())} images in this collection")
st.write(f"{len(image_ids(st.session_state['collection']))} images in {st.session_state["collection"]}")
# the generated HTML is not lovely at all

st.selectbox(
"image collection",
colls,
key="collection",
)

st.session_state["random_img"] = random_image()
show_random_image()

Expand Down
20 changes: 14 additions & 6 deletions src/cyto_ml/visualisation/pages/02_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@
import streamlit as st
from sklearn.cluster import KMeans

from cyto_ml.visualisation.app import (
cached_image,
image_embeddings,
image_ids,
)
from cyto_ml.visualisation.app import cached_image, collections, image_embeddings, image_ids

logging.basicConfig(level=logging.INFO)

Expand All @@ -21,6 +17,7 @@ def kmeans_cluster() -> KMeans:
"""
X = image_embeddings()
logging.info(st.session_state["n_clusters"])
n_clusters = st.session_state["n_clusters"]
# Initialize and fit KMeans
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
Expand All @@ -36,7 +33,7 @@ def image_labels() -> dict:
km = kmeans_cluster()
clusters = dict(zip(set(km.labels_), [[] for _ in range(len(set(km.labels_)))]))

for index, _id in enumerate(image_ids()):
for index, _id in enumerate(image_ids(st.session_state["collection"])):
label = km.labels_[index]
clusters[label].append(_id)
return clusters
Expand Down Expand Up @@ -73,6 +70,17 @@ def show_cluster() -> None:

# TODO some visualisation, actual content, etc
def main() -> None:
# duplicate logic from main page, how should this state be shared?

colls = collections()
if "collection" not in st.session_state:
st.session_state["collection"] = colls[0]
st.selectbox(
"image collection",
colls,
key="collection",
)

# start with this cluster label
if "cluster" not in st.session_state:
st.session_state["cluster"] = 1
Expand Down

0 comments on commit 2867615

Please sign in to comment.