Skip to content

Commit

Permalink
[BUGFIX] argilla: review datasest import with new export flow (#5756)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR reviews and fixes error when importing datasets from exported
datasets

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm My changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Dec 16, 2024
1 parent ebb0fa2 commit e3938e8
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 33 deletions.
1 change: 1 addition & 0 deletions argilla/src/argilla/datasets/_io/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def from_disk(
name=dataset_model.name, workspace_id=workspace.id
)
dataset = cls.from_model(model=dataset_model, client=client)
dataset.get()
else:
# Create a new dataset and load the settings and records
if not os.path.exists(settings_path):
Expand Down
30 changes: 22 additions & 8 deletions argilla/src/argilla/datasets/_io/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,11 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
for col in responses_columns:
question_name = col.split(".")[0]
if col.endswith("users"):
response_questions[question_name]["users"] = hf_dataset[col]
user_ids.update({UUID(user_id): UUID(user_id) for user_id in set(sum(hf_dataset[col], []))})
response_questions[question_name]["users"] = hf_dataset[col] or []
for users in hf_dataset[col]:
if users is None:
continue
user_ids.update({UUID(user_id): user_id for user_id in users})
elif col.endswith("responses"):
response_questions[question_name]["responses"] = hf_dataset[col]
elif col.endswith("status"):
Expand All @@ -240,7 +243,15 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
user_ids[unknown_user_id] = my_user.id

# Create a mapper to map the Hugging Face dataset to a Record object
mapping = {col: col for col in hf_dataset.column_names if ".suggestion" in col}
mapping = {}
for col in hf_dataset.column_names:
if ".suggestion" in col:
mapping[col] = col
elif col.startswith("metadata.") and col.replace("metadata.", "") in dataset.schema:
mapping[col] = col.replace("metadata.", "")
elif col.startswith("vector.") and col.replace("vector.", "") in dataset.schema:
mapping[col] = col.replace("vector.", "")

mapper = IngestedRecordMapper(dataset=dataset, mapping=mapping, user_id=my_user.id)

# Extract responses and create Record objects
Expand All @@ -249,14 +260,17 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
for idx, row in enumerate(hf_dataset):
record = mapper(row)
for question_name, values in response_questions.items():
response_values = values["responses"][idx]
response_users = values["users"][idx]
response_status = values["status"][idx]
response_values = values["responses"][idx] or []
response_users = values["users"][idx] or []
response_status = values["status"][idx] or []

used_users = set()
for value, user_id, status in zip(response_values, response_users, response_status):
user_id = user_ids[UUID(user_id)]
if user_id in response_users:
if user_id in used_users:
continue
response_users[user_id] = True

used_users.add(user_id)
response = Response(
user_id=user_id,
question_name=question_name,
Expand Down
20 changes: 14 additions & 6 deletions argilla/src/argilla/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def __init__(
status (Union[ResponseStatus, str]): The status of the response as "draft", "submitted", "discarded".
"""

if isinstance(status, str):
status = ResponseStatus(status)

if question_name is None:
raise ValueError("question_name is required")
if value is None:
if value is None and status == ResponseStatus.submitted:
raise ValueError("value is required")
if user_id is None:
raise ValueError("user_id is required")

if isinstance(status, str):
status = ResponseStatus(status)

self._record = _record
self.question_name = question_name
self.value = value
Expand Down Expand Up @@ -253,7 +253,7 @@ def _compute_user_id_from_responses(responses: List[Response]) -> Optional[UUID]
@staticmethod
def __responses_as_model_values(responses: List[Response]) -> Dict[str, Dict[str, Any]]:
"""Creates a dictionary of response values from a list of Responses"""
return {answer.question_name: {"value": answer.value} for answer in responses}
return {answer.question_name: {"value": answer.value} for answer in responses if answer.value is not None}

@classmethod
def __model_as_responses_list(cls, model: UserResponseModel, record: "Record") -> List[Response]:
Expand All @@ -276,4 +276,12 @@ def __ranking_from_model_value(cls, value: List[Dict[str, Any]]) -> List[str]:

@classmethod
def __ranking_to_model_value(cls, value: List[str]) -> List[Dict[str, str]]:
return [{"value": v} for v in value]
values = []
for v in value or []:
if isinstance(v, dict):
values.append(v)
elif isinstance(v, str):
values.append({"value": v})
else:
raise RecordResponsesError(f"Invalid value for ranking question: {v}")
return values
22 changes: 3 additions & 19 deletions argilla/src/argilla/settings/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
FieldSettings,
)
from argilla.settings._common import SettingsPropertyBase
from argilla.settings._metadata import MetadataField, MetadataType
from argilla.settings._vector import VectorField


try:
from typing import Self
Expand Down Expand Up @@ -296,21 +295,6 @@ def _field_from_model(model: FieldModel) -> Field:
raise ArgillaError(f"Unsupported field type: {model.settings.type}")


def _field_from_dict(data: dict) -> Union[Field, VectorField, MetadataType]:
def _field_from_dict(data: dict) -> Field:
"""Create a field instance from a field dictionary"""
field_type = data["type"]

if field_type == "text":
return TextField.from_dict(data)
elif field_type == "image":
return ImageField.from_dict(data)
elif field_type == "chat":
return ChatField.from_dict(data)
elif field_type == "custom":
return CustomField.from_dict(data)
elif field_type == "vector":
return VectorField.from_dict(data)
elif field_type == "metadata":
return MetadataField.from_dict(data)
else:
raise ArgillaError(f"Unsupported field type: {field_type}")
return _field_from_model(FieldModel(**data))

0 comments on commit e3938e8

Please sign in to comment.