Skip to content

Commit

Permalink
add accept and reject button and interaction
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed May 28, 2024
1 parent 02447b5 commit 3815686
Show file tree
Hide file tree
Showing 2 changed files with 381 additions and 3,872 deletions.
134 changes: 115 additions & 19 deletions bdikit/visualization/scope_reducing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging

import altair as alt
import numpy as np
Expand All @@ -11,6 +12,8 @@
pn.extension("mathjax")
pn.extension("vega")

logger = logging.getLogger(__name__)


class SRHeatMapManager:
def __init__(self) -> None:
Expand All @@ -23,6 +26,9 @@ def __init__(self) -> None:
self.rec_cols_gdc = None
self.clusters = None

# Selected column
self.selected_row = None

def _load_json(self):
with open(self.json_path) as f:
data = json.load(f)
Expand All @@ -32,7 +38,8 @@ def _write_json(self, data):
with open(self.json_path, "w") as f:
json.dump(data, f)

def get_heatmap(self, recommendations):
def get_heatmap(self):
recommendations = self._load_json()
rec_cols = set()
rec_table = []
rec_list = []
Expand Down Expand Up @@ -124,6 +131,48 @@ def _get_column_values(self, properties):
else:
return None

def _accept_match(self, col_name=None, match_name=None):
if self.selected_row is None:
return
col_name = self.selected_row["Column"].values[0]
match_name = self.selected_row["Recommendation"].values[0]
recommendations = self._load_json()
for idx, d in enumerate(recommendations):
candidate_name = d["Candidate column"]
if candidate_name != col_name:
continue
for top_k_name, top_k_score in d["Top k columns"]:
if top_k_name == match_name:
recommendations[idx] = {
"Candidate column": candidate_name,
"Top k columns": [[top_k_name, top_k_score]],
}
self._write_json(recommendations)
self.get_heatmap()
return

def _reject_match(self):
if self.selected_row is None:
return
col_name = self.selected_row["Column"].values[0]
match_name = self.selected_row["Recommendation"].values[0]
recommendations = self._load_json()
for idx, d in enumerate(recommendations):
candidate_name = d["Candidate column"]
if candidate_name != col_name:
continue
new_top_k = []
for top_k_name, top_k_score in d["Top k columns"]:
if top_k_name != match_name:
new_top_k.append([top_k_name, top_k_score])
recommendations[idx] = {
"Candidate column": candidate_name,
"Top k columns": new_top_k,
}
self._write_json(recommendations)
self.get_heatmap()
return

def get_clusters(self):
words = self.rec_table_df["Column"].to_numpy()
lev_similarity = -1 * np.array(
Expand All @@ -136,19 +185,52 @@ def get_clusters(self):
)
affprop.fit(lev_similarity)

print(f"Number of clusters: {np.unique(affprop.labels_).shape[0]}\n")
logger.debug(f"Number of clusters: {np.unique(affprop.labels_).shape[0]}\n")
cluster_names = []
clusters = {}
for cluster_id in np.unique(affprop.labels_):
exemplar = words[affprop.cluster_centers_indices_[cluster_id]]
cluster = np.unique(words[np.nonzero(affprop.labels_ == cluster_id)])
cluster_str = ", ".join(cluster)
print(" - *%s:* %s" % (exemplar, cluster_str))
logger.debug(" - *%s:* %s" % (exemplar, cluster_str))
cluster_names.append(exemplar)
clusters[exemplar] = cluster
self.clusters = clusters

def _plot_heatmap(self, clusters=[], subschemas=[], threshold=0.5):
def _plot_heatmap_base(self, heatmap_rec_list):
single = alt.selection_point(name="single")
base = (
alt.Chart(heatmap_rec_list)
.mark_rect(size=100)
.encode(
y=alt.X("Column:O", sort=None),
x=alt.X(f"Recommendation:O", sort=None),
color=alt.condition(single, "Value:Q", alt.value("lightgray")),
# color="Value:Q",
tooltip=[
alt.Tooltip("Column", title="Column"),
alt.Tooltip("Recommendation", title="Recommendation"),
alt.Tooltip("Value", title="Value"),
],
)
.add_params(single)
)
return pn.pane.Vega(base)

def _plot_selected_row(self, heatmap_rec_list, selection):
if not selection:
return "## No selection"
selected_row = heatmap_rec_list.iloc[selection]
column = selected_row["Column"].values[0]
rec = selected_row["Recommendation"].values[0]
# value = selected_row["Value"]
# self._accept_match(column, rec)
self.selected_row = selected_row
return pn.widgets.DataFrame(selected_row)

def _plot_pane(
self, clusters=[], subschemas=[], threshold=0.5, acc_click=0, rej_click=0
):
heatmap_rec_list = self.rec_list_df[self.rec_list_df["Value"] >= threshold]
if clusters:
clustered_cols = []
Expand All @@ -165,21 +247,15 @@ def _plot_heatmap(self, clusters=[], subschemas=[], threshold=0.5):
heatmap_rec_list["Recommendation"].isin(subschema_rec_cols)
]

base = (
alt.Chart(heatmap_rec_list)
.mark_rect()
.encode(
y=alt.X("Column:O", sort=None),
x=alt.X(f"Recommendation:O", sort=None),
color="Value:Q",
tooltip=[
alt.Tooltip("Column", title="Column"),
alt.Tooltip("Recommendation", title="Recommendation"),
alt.Tooltip("Value", title="Value"),
],
)
heatmap_pane = self._plot_heatmap_base(heatmap_rec_list)
return pn.Column(
heatmap_pane,
pn.bind(
self._plot_selected_row,
heatmap_rec_list,
heatmap_pane.selection.param.single,
),
)
return pn.pane.Vega(base)

def plot_heatmap(self):
select_cluster = pn.widgets.MultiChoice(
Expand All @@ -192,15 +268,35 @@ def plot_heatmap(self):
name="Threshold", start=0, end=1.0, step=0.01, value=0.5, width=220
)

acc_button = pn.widgets.Button(name="Accept Match", button_type="success")

rej_button = pn.widgets.Button(name="Decline Match", button_type="danger")

def on_click_accept_match(event):
self._accept_match()

def on_click_reject_match(event):
self._reject_match()

acc_button.on_click(on_click_accept_match)
rej_button.on_click(on_click_reject_match)

heatmap_bind = pn.bind(
self._plot_heatmap, select_cluster, select_rec_groups, thresh_slider
self._plot_pane,
select_cluster,
select_rec_groups,
thresh_slider,
acc_button.param.clicks,
rej_button.param.clicks,
)

column_left = pn.Column(
"# Column",
select_cluster,
select_rec_groups,
thresh_slider,
acc_button,
rej_button,
styles=dict(background="WhiteSmoke"),
)

Expand Down
Loading

0 comments on commit 3815686

Please sign in to comment.