Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge recent updates into deploy #368

Merged
merged 16 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions backend/django/core/forms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from io import StringIO

import pandas as pd
from django import forms
Expand Down Expand Up @@ -48,30 +47,32 @@ def read_data_file(data_file):
try:
if data_file.content_type == "text/tab-separated-values":
data = pd.read_csv(
StringIO(data_file.read().decode("utf8", "ignore")),
data_file,
sep="\t",
encoding="utf-8",
dtype=str,
).dropna(axis=0, how="all")
elif data_file.content_type == "text/csv":
data = pd.read_csv(
StringIO(data_file.read().decode("utf8", "ignore")),
data_file,
encoding="utf-8",
dtype=str,
).dropna(axis=0, how="all")
elif data_file.content_type.startswith(
"application/vnd"
) and data_file.name.endswith(".csv"):
data = pd.read_csv(
StringIO(data_file.read().decode("utf8", "ignore")),
data_file,
encoding="utf-8",
dtype=str,
).dropna(axis=0, how="all")
elif data_file.content_type.startswith(
"application/vnd"
) and data_file.name.endswith(".xlsx"):
data = (
pd.read_excel(data_file, dtype=str)
.dropna(axis=0, how="all")
.replace(r"\n", " ", regex=True)
)
data = pd.read_excel(
data_file,
dtype=str,
).dropna(axis=0, how="all")
else:
raise ValidationError(
"File type is not supported. Received {0} but only {1} are supported.".format(
Expand All @@ -81,13 +82,14 @@ def read_data_file(data_file):
except ParserError:
# If there was an error while parsing then raise invalid file error
raise ValidationError(
"Unable to read file. Please ensure it passes all the requirments"
"Unable to read file. Please ensure it passes all the requirements"
)
except UnicodeDecodeError:
# Some files are not in utf-8, let's just reject those.
raise ValidationError(
"Unable to read the file. Please ensure that the file is encoded in UTF-8."
)

return data


Expand Down Expand Up @@ -339,6 +341,7 @@ def clean(self):
self.cleaned_data["data"] = clean_data_helper(
data_df, labels, dedup_on, dedup_fields
)

return self.cleaned_data


Expand Down Expand Up @@ -393,7 +396,7 @@ def clean_label_data_file(self):
data_file = read_data_file(data)
return clean_label_data_helper(data_file)
except pd.errors.EmptyDataError:
return pd.DataFrame({"Label": []})
return None
else:
raise ValidationError("ERROR: no file provided")

Expand Down
10 changes: 7 additions & 3 deletions backend/django/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ class Meta:
value = models.TextField(null=True, blank=True)

def __str__(self):
if self.value is None:
return f"{str(self.metadata_field)}: "
return f"{str(self.metadata_field)}: {self.value}"


Expand Down Expand Up @@ -416,10 +418,12 @@ def __str__(self):
return self.field_name

def get_unique_options(self):
unique_list = list(
set(self.labelmetadata_set.all().values_list("value", flat=True))
unique_list = (
self.labelmetadata_set.all()
.order_by("value")
.values_list("value", flat=True)
.distinct()
)
unique_list.sort()
return unique_list


Expand Down
11 changes: 10 additions & 1 deletion backend/django/core/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IRRLog,
Label,
LabelChangeLog,
MetaData,
Model,
Profile,
Project,
Expand Down Expand Up @@ -73,8 +74,16 @@ def to_representation(self, obj):
return base_representation


class MetaDataSerializer(serializers.ModelSerializer):
field_name = serializers.CharField(source="metadata_field.field_name")

class Meta:
model = MetaData
fields = ["field_name", "value"]


class DataSerializer(serializers.ModelSerializer):
metadata = serializers.StringRelatedField(many=True, read_only=True)
metadata = MetaDataSerializer(many=True, read_only=True)

class Meta:
model = Data
Expand Down
118 changes: 87 additions & 31 deletions backend/django/core/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from django.db import connection, transaction
from django.utils import timezone
from sentence_transformers import SentenceTransformer
from django.core.exceptions import ValidationError

from core import tasks
from core.models import (
Expand Down Expand Up @@ -158,20 +159,15 @@ def create_data_from_csv(df, project):
df["project"] = project.pk
df["irr_ind"] = False

# Replace tabs since that's our delimiter, remove carriage returns since copy_from doesn't like them
# escape all backslashes because it seems to fix "end-of-copy marker corrupt"
df["Text"] = (
df["Text"]
.astype(str)
.apply(
lambda x: x.replace("\t", " ")
.replace("\r", " ")
.replace("\n", " ")
.replace("\\", "\\\\")
)
df.to_csv(
stream,
sep="\t",
header=False,
index=False,
columns=columns,
escapechar="\\",
doublequote=False,
)

df.to_csv(stream, sep="\t", header=False, index=False, columns=columns)
stream.seek(0)

with connection.cursor() as c:
Expand Down Expand Up @@ -206,6 +202,16 @@ def create_labels_from_csv(df, project):
stream = StringIO()

labels = {label.name: label.pk for label in project.labels.all()}

df["Label"] = df["Label"].apply(
lambda s: s.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
)

existing_labels = set(labels.keys())
df_labels = set(df["Label"].tolist())

quote_labels = df_labels - existing_labels
df["Label"] = df["Label"].apply(lambda s: f'"{s}"' if s in quote_labels else s)
df["data_id"] = df["hash"].apply(
lambda x: Data.objects.get(hash=x, project=project).pk
)
Expand All @@ -218,7 +224,15 @@ def create_labels_from_csv(df, project):
# these data are preloaded
df["pre_loaded"] = True

df.to_csv(stream, sep="\t", header=False, index=False, columns=columns)
df.to_csv(
stream,
sep="\t",
header=False,
index=False,
columns=columns,
escapechar="\\",
doublequote=False,
)
stream.seek(0)

with connection.cursor() as c:
Expand Down Expand Up @@ -250,6 +264,8 @@ def create_metadata_objects_from_csv(df, project):
header=False,
index=False,
columns=df_meta.columns.values.tolist(),
escapechar="\\",
doublequote=False,
)
stream.seek(0)
with connection.cursor() as c:
Expand Down Expand Up @@ -328,9 +344,9 @@ def add_data(project, df):

df["hash"] = ""
for f in dedup_on_fields:
df["hash"] += df[f].astype(str) + "_"
df["hash"] += df[f].fillna("None").astype(str) + "_"

df["hash"] += df["Text"].astype(str)
df["hash"] += df["Text"].fillna("None").astype(str)
df["hash"] = df["hash"].apply(md5_hash)

df.drop_duplicates(subset=["hash"], keep="first", inplace=True)
Expand Down Expand Up @@ -818,33 +834,73 @@ def get_unlabelled_data_objs(project_id: int) -> int:
return cursor.fetchone()[0]


def create_label_metadata(project, label_data, label_list):
def create_label_metadata(project, label_data):
"""This function creates LabelMetadataField objects for each new field and
LabelMetadata objects for each label-field pair.

Args:
project: a Project object
label_data: a pandas dataframe with the label metadata fields
"""
label_metadata = [c for c in label_data if c not in ["Label", "Description"]]
label_objects = pd.DataFrame(
list(Label.objects.filter(project=project).values("id", "name"))
).rename(columns={"name": "Label", "id": "label_id"})

# for some labels we will need to add quotes and un-escape strings for the merge to work
existing_label_ids = set(label_objects["Label"].tolist())
df_label_ids = set(label_data["Label"].tolist())

need_quotes = df_label_ids - existing_label_ids
label_data["Label"] = label_data["Label"].apply(
lambda s: (
f'"{s}"'.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
if s in need_quotes
else s
)
)
df_label_ids = set(label_data["Label"].tolist())
if len(df_label_ids - existing_label_ids) > 0:
raise ValidationError(
"ERROR loading in label metadata. Something is going wrong with the label file."
)

label_data = label_data.merge(label_objects, on="Label", how="inner")

label_metadata = [
c
for c in label_data
if c not in ["Label", "Description", "label_id", "project"]
]
if len(label_metadata) > 0:
for metadata_col in label_metadata:
field_name = str(metadata_col)
label_metadata_field = LabelMetaDataField.objects.create(
project=project, field_name=metadata_col
project=project, field_name=field_name
)
all_metadata_values = label_data[metadata_col].tolist()
all_label_metadata_objects = [
LabelMetaData(
label=label_list[i],
label_metadata_field=label_metadata_field,
value=all_metadata_values[i],
)
for i in range(len(all_metadata_values))
]
LabelMetaData.objects.bulk_create(
all_label_metadata_objects, batch_size=8000
df_meta = label_data[["label_id", field_name]].rename(
columns={field_name: "value"}
)
df_meta["label_metadata_field_id"] = label_metadata_field.pk
df_meta = df_meta[["label_id", "label_metadata_field_id", "value"]]
stream = StringIO()
df_meta.to_csv(
stream,
sep="\t",
header=False,
index=False,
columns=df_meta.columns.values.tolist(),
escapechar="\\",
doublequote=False,
)
stream.seek(0)
with connection.cursor() as c:
c.copy_from(
stream,
LabelMetaData._meta.db_table,
sep="\t",
null="",
columns=df_meta.columns.values.tolist(),
)


def create_or_update_project_category(project, new_category):
Expand Down
37 changes: 31 additions & 6 deletions backend/django/core/utils/utils_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,22 @@ def clean_data_helper(
)

labels_in_data = data["Label"].dropna(inplace=False).unique()
if len(labels_in_data) > 0 and len(set(labels_in_data) - set(supplied_labels)) > 0:
if (
(supplied_labels is not None)
and len(labels_in_data) > 0
and len(set(labels_in_data) - set(supplied_labels)) > 0
):
just_in_data = set(labels_in_data) - set(supplied_labels)
raise ValidationError(
f"There are extra labels in the file which were not created in step 2: {just_in_data}"
)
# add a correction for label descriptions with weird characters
labels_in_data_fixed = [
f'"{s}"'.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
for s in just_in_data
] + list(set(labels_in_data) - just_in_data)

if len(set(labels_in_data_fixed) - set(supplied_labels)) > 0:
raise ValidationError(
f"There are extra labels in the file which were not in step 2 of project creation: {just_in_data}"
)

if "ID" in data.columns:
# there should be no null values
Expand Down Expand Up @@ -107,8 +118,22 @@ def clean_label_data_helper(data, existing_labels=[]):
if field not in data.columns:
raise ValidationError(f"File is missing required field '{field}'.")

new_labels = list(set(data["Label"].unique()) - set(existing_labels))
if len(new_labels) > 0 and len(existing_labels) > 0:
new_labels_all = set(data["Label"].unique())
new_labels = list(new_labels_all - set(existing_labels))

# try adding quotes around the "new" labels and see if they match now
if len(existing_labels) > 0 and len(new_labels) > 0:
data["Label"] = data["Label"].apply(
lambda s: (
f'"{s}"'.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
if s in new_labels
else s
)
)
new_labels_all = set(data["Label"].unique())
fixed_labels = list(new_labels_all - set(existing_labels))

if len(new_labels) > 0 and len(existing_labels) > 0 and len(fixed_labels) > 0:
raise ValidationError(
f"New labels were found in this file: {', '.join(new_labels)}"
)
Expand Down
2 changes: 1 addition & 1 deletion backend/django/core/views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def download_irr_log(request, project_pk):
for log in logs:
label_name = log.label.name if log.label else ""
writer.writerow(
[log.data.pk, log.data.text, label_name, log.profile.user, log.timestamp]
[log.data.upload_id, log.data.text, label_name, log.profile.user, log.timestamp]
)

return response
Expand Down
Loading