From 4f3a6e06ac318208102caf79b49cc5f79b945ecd Mon Sep 17 00:00:00 2001 From: Sooah Lee Date: Fri, 13 Sep 2024 10:53:26 +0900 Subject: [PATCH] Mergeback 1.9.0 to develop (#1604) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Summary ### How to test ### Checklist - [ ] I have added unit tests to cover my changes.​ - [ ] I have added integration tests to cover my changes.​ - [ ] I have added the description of my changes into [CHANGELOG](https://github.com/openvinotoolkit/datumaro/blob/develop/CHANGELOG.md).​ - [ ] I have updated the [documentation](https://github.com/openvinotoolkit/datumaro/tree/develop/docs) accordingly ### License - [ ] I submit _my code changes_ under the same [MIT License](https://github.com/openvinotoolkit/datumaro/blob/develop/LICENSE) that covers the project. Feel free to contact the maintainers if that's a concern. - [ ] I have updated the license header for each file (see an example below). ```python # Copyright (C) 2024 Intel Corporation # # SPDX-License-Identifier: MIT ``` --------- Co-authored-by: Yunchu Lee Co-authored-by: Wonju Lee --- CHANGELOG.md | 2 + requirements-core.txt | 3 + setup.py | 2 +- src/datumaro/plugins/framework_converter.py | 51 +++- tests/unit/test_framework_converter.py | 244 +++++++++++++++++++- 5 files changed, 291 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f705d5739c..161b2d54e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### New features - Add a new CLI command: datum format () +- Support language dataset for DmTorchDataset + () ### Enhancements - Change _Shape to Shape and add comments for subclasses of Shape diff --git a/requirements-core.txt b/requirements-core.txt index 1d2ce11bf3..078171ef59 100644 --- a/requirements-core.txt +++ b/requirements-core.txt @@ -64,3 +64,6 @@ json-stream # TabularValidator nltk + +# torch converter for language +portalocker diff --git a/setup.py b/setup.py index acc6925fdc..91b1b51e8c 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,7 @@ def parse_requirements(filename=CORE_REQUIREMENTS_FILE): extras_require={ "tf": ["tensorflow"], "tfds": ["tensorflow-datasets<4.9.3"], - "torch": ["torch", "torchvision"], + "torch": ["torch", "torchvision", "torchtext==0.16.0"], "default": DEFAULT_REQUIREMENTS, }, ext_modules=ext_modules, diff --git a/src/datumaro/plugins/framework_converter.py b/src/datumaro/plugins/framework_converter.py index 556005e1b7..e5a5b7f6c2 100644 --- a/src/datumaro/plugins/framework_converter.py +++ b/src/datumaro/plugins/framework_converter.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # # SPDX-License-Identifier: MIT @@ -17,6 +17,7 @@ "detection": AnnotationType.bbox, "instance_segmentation": AnnotationType.polygon, "semantic_segmentation": AnnotationType.mask, + "tabular": [AnnotationType.label, AnnotationType.caption], } @@ -88,7 +89,10 @@ def _gen_item(self, idx: int): if ann.type == TASK_ANN_TYPE[self.task] ] label = mask_tools.merge_masks((mask, label_id) for mask, label_id in masks) - + elif self.task == "tabular": + label = [ + ann.as_dict() for ann in item.annotations if ann.type in TASK_ANN_TYPE[self.task] + ] return image, label @@ -103,15 +107,58 @@ def __init__( task: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + target: Optional[str] = None, + tokenizer: Optional[tuple[Callable, Callable]] = None, + vocab: Optional[tuple[Callable, Callable]] = None, ): super().__init__(dataset=dataset, subset=subset, task=task) self.transform = transform self.target_transform = target_transform + if self.task == "tabular": + if not isinstance(target, dict): + raise ValueError( + "Target should be a dictionary with 'input' and 'output' keys." + ) + self.input_target = target.get("input") + self.output_target = target.get("output") + if not self.input_target: + raise ValueError( + "Please provide target column for tabular task which is used for input" + ) + + if not (tokenizer and vocab): + raise ValueError("Both tokenizer and vocab must be provided for tabular task") + self.tokenizer = tokenizer + self.vocab = vocab + def __getitem__(self, idx): image, label = self._gen_item(idx) + if self.task == "tabular": + text = image()[self.input_target] + + if self.output_target: + src_tokenizer, tgt_tokenizer = self.tokenizer + src_vocab, tgt_vocab = self.vocab + src_tokens = src_tokenizer(text) + src_token_ids = src_vocab(src_tokens) + + label_text = label[0]["caption"].split(f"{self.output_target}:")[-1] + tgt_tokens = tgt_tokenizer(label_text) + tgt_token_ids = tgt_vocab(tgt_tokens) + + return torch.tensor(src_token_ids, dtype=torch.long), torch.tensor( + tgt_token_ids, dtype=torch.long + ) + else: + tokens = self.tokenizer(text) + token_ids = self.vocab(tokens) + return torch.tensor(token_ids, dtype=torch.long), torch.tensor( + label[0]["label"], dtype=torch.long + ) + if len(image.shape) == 2: image = np.expand_dims(image, axis=-1) diff --git a/tests/unit/test_framework_converter.py b/tests/unit/test_framework_converter.py index 0933884293..83fd9a97c5 100644 --- a/tests/unit/test_framework_converter.py +++ b/tests/unit/test_framework_converter.py @@ -1,4 +1,4 @@ -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2023-2024 Intel Corporation # # SPDX-License-Identifier: MIT @@ -13,14 +13,16 @@ from datumaro.components.annotation import ( AnnotationType, Bbox, + Caption, Label, LabelCategories, Mask, Polygon, + Tabular, ) from datumaro.components.dataset import Dataset from datumaro.components.dataset_base import DatasetItem -from datumaro.components.media import Image +from datumaro.components.media import Image, Table, TableRow from datumaro.plugins.framework_converter import ( TASK_ANN_TYPE, DmTfDataset, @@ -36,6 +38,8 @@ try: import torch + from torchtext.data.utils import get_tokenizer + from torchtext.vocab import build_vocab_from_iterator from torchvision import datasets, transforms except ImportError: TORCH_AVAILABLE = False @@ -142,6 +146,89 @@ def fxt_dataset(): ) +@pytest.fixture +def fxt_tabular_label_dataset(): + table = Table.from_list( + [ + { + "label": 1, + "text": "I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered " + "controversial" + " I really had to see this for myself.

The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.

What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it's not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.

I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn't have much of a plot.", + } + ] + ) + return Dataset.from_iterable( + [ + DatasetItem( + id=0, + subset="train", + media=TableRow(table=table, index=0), + annotations=[Label(id=0, attributes={}, group=0, object_id=-1, label=0)], + ) + ], + categories={ + AnnotationType.label: LabelCategories.from_iterable( + [("label:1", "label"), ("label:2", "label")] + ) + }, + media_type=TableRow, + ) + + +@pytest.fixture +def fxt_tabular_caption_dataset(): + table = Table.from_list( + [ + { + "source": "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.", + "target": "Two young, White males are outside near many bushes.", + } + ] + ) + return Dataset.from_iterable( + [ + DatasetItem( + id=0, + subset="train", + media=TableRow(table=table, index=0), + annotations=[ + Caption("target:Two young, White males are outside near many bushes.") + ], + ) + ], + categories={}, + media_type=TableRow, + ) + + +@pytest.fixture +def fxt_dummy_tokenizer(): + def dummy_tokenizer(text): + return text.split() + + return dummy_tokenizer + + +@pytest.fixture +def data_iter(): + return [(1, "This is a sample text"), (2, "Another sample text")] + + +@pytest.fixture +def fxt_dummy_vocab(fxt_dummy_tokenizer, data_iter): + vocab = build_vocab_from_iterator( + map(fxt_dummy_tokenizer, (text for _, text in data_iter)), specials=[""] + ) + vocab.set_default_index(vocab[""]) + return vocab + + +@pytest.fixture +def fxt_tabular_fixture(fxt_dummy_tokenizer, fxt_dummy_vocab): + return {"target": {"input": "text"}, "tokenizer": fxt_dummy_tokenizer, "vocab": fxt_dummy_vocab} + + @pytest.mark.new @mark_requirement(Requirements.DATUM_GENERAL_REQ) class FrameworkConverterFactoryTest(TestCase): @@ -173,38 +260,49 @@ def test_create_converter_tf_importerror(self): @mark_requirement(Requirements.DATUM_GENERAL_REQ) class MultiframeworkConverterTest: @pytest.mark.parametrize( - "fxt_subset,fxt_task", + "fxt_dataset_type,fxt_subset,fxt_task", [ ( + "fxt_dataset", "train", "classification", ), ( + "fxt_dataset", "val", "multilabel_classification", ), ( + "fxt_dataset", "train", "detection", ), ( + "fxt_dataset", "val", "instance_segmentation", ), ( + "fxt_dataset", "train", "semantic_segmentation", ), + ("fxt_tabular_label_dataset", "train", "tabular"), ], ) - def test_multi_framework_dataset(self, fxt_dataset: Dataset, fxt_subset: str, fxt_task: str): + def test_multi_framework_dataset( + self, fxt_dataset_type: str, fxt_subset: str, fxt_task: str, request + ): + dataset = request.getfixturevalue(fxt_dataset_type) dm_multi_framework_dataset = _MultiFrameworkDataset( - dataset=fxt_dataset, subset=fxt_subset, task=fxt_task + dataset=dataset, subset=fxt_subset, task=fxt_task ) for idx in range(len(dm_multi_framework_dataset)): image, label = dm_multi_framework_dataset._gen_item(idx) - assert isinstance(image, np.ndarray) + if fxt_task == "tabular": + image = image() + assert isinstance(image, (np.ndarray, dict)) if fxt_task == "classification": assert isinstance(label, int) elif fxt_task == "multilabel_classification": @@ -213,6 +311,8 @@ def test_multi_framework_dataset(self, fxt_dataset: Dataset, fxt_subset: str, fx assert isinstance(label, list) if fxt_task == "semantic_segmentation": assert isinstance(label, np.ndarray) + elif fxt_task == "tabular": + assert isinstance(label, list) @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch is not installed") @pytest.mark.parametrize( @@ -261,7 +361,6 @@ def test_can_convert_torch_framework( fxt_subset: str, fxt_task: str, fxt_convert_kwargs: Dict[str, Any], - request: pytest.FixtureRequest, ): multi_framework_dataset = FrameworkConverter(fxt_dataset, subset=fxt_subset, task=fxt_task) @@ -294,7 +393,12 @@ def test_can_convert_torch_framework( if ann.type == TASK_ANN_TYPE[fxt_task] ] label = np.sum(masks, axis=0, dtype=np.uint8) - + elif fxt_task == "tabular": + label = [ + ann.as_dict() + for ann in exp_item.annotations + if ann.type in TASK_ANN_TYPE[fxt_task] + ] if fxt_convert_kwargs.get("transform", None): actual = dm_torch_item[0].permute(1, 2, 0).mul(255.0).to(torch.uint8).numpy() assert np.array_equal(image, actual) @@ -374,6 +478,130 @@ def test_can_convert_torch_framework_detection(self): assert torch_ann["bbox"] == [x1, y1, x2 - x1, y2 - y1] assert torch_ann["iscrowd"] == dm_ann["attributes"]["is_crowd"] + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch is not installed") + def test_can_convert_torch_framework_tabular_label(self, fxt_tabular_label_dataset): + class IMDBDataset(Dataset): + def __init__(self, data_iter, vocab, transform=None): + self.data = list(data_iter) + self.vocab = vocab + self.transform = transform + self.tokenizer = get_tokenizer("basic_english") + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + label, text = self.data[idx] + token_ids = [self.vocab[token] for token in self.tokenizer(text)] + + if self.transform: + token_ids = self.transform(token_ids) + + return torch.tensor(token_ids, dtype=torch.long), torch.tensor( + label, dtype=torch.long + ) + + # Prepare data and tokenizer + # First item of IMDB + first_item = ( + 1, + "I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered \"controversial\" I really had to see this for myself.

The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.

What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, even then it's not shot like some cheaply made porno. While my countrymen mind find it shocking, in reality sex and nudity are a major staple in Swedish cinema. Even Ingmar Bergman, arguably their answer to good old boy John Ford, had sex scenes in his films.

I do commend the filmmakers for the fact that any sex shown in the film is shown for artistic purposes rather than just to shock people and make money to be shown in pornographic theaters in America. I AM CURIOUS-YELLOW is a good film for anyone wanting to study the meat and potatoes (no pun intended) of Swedish cinema. But really, this film doesn't have much of a plot.", + ) + tokenizer = get_tokenizer("basic_english") + + # Build vocabulary + vocab = build_vocab_from_iterator([tokenizer(first_item[1])], specials=[""]) + vocab.set_default_index(vocab[""]) + + # Create torch dataset + torch_dataset = IMDBDataset(iter([first_item]), vocab) + + # Convert to dm_torch_dataset + dm_dataset = fxt_tabular_label_dataset + multi_framework_dataset = FrameworkConverter(dm_dataset, subset="train", task="tabular") + dm_torch_dataset = multi_framework_dataset.to_framework( + framework="torch", target={"input": "text"}, tokenizer=tokenizer, vocab=vocab + ) + + # Verify equality of items in torch_dataset and dm_torch_dataset + label_indices = dm_dataset.categories().get(AnnotationType.label)._indices + torch_item = torch_dataset[0] + dm_item = dm_torch_dataset[0] + assert torch.equal(torch_item[0], dm_item[0]), "Token IDs do not match" + + # Extract and compare labels + torch_item_label = str(torch_item[1].item()) + dm_item_label = list(label_indices.keys())[list(label_indices.values()).index(0)].split( + ":" + )[-1] + assert torch_item_label == dm_item_label, "Labels do not match" + + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch is not installed") + def test_can_convert_torch_framework_tabular_caption(self, fxt_tabular_caption_dataset): + class Multi30kDataset(Dataset): + def __init__(self, dataset, src_tokenizer, tgt_tokenizer, src_vocab, tgt_vocab): + self.dataset = list(dataset) + self.src_tokenizer = src_tokenizer + self.tgt_tokenizer = tgt_tokenizer + self.src_vocab = src_vocab + self.tgt_vocab = tgt_vocab + + def __len__(self): + return len(self.dataset) + + def _data_process(self, text, tokenizer, vocab): + tokens = tokenizer(text) + token_ids = [vocab[token] for token in tokens] + return torch.tensor(token_ids, dtype=torch.long) + + def __getitem__(self, idx): + src, tgt = self.dataset[idx] + src_tensor = self._data_process(src, self.src_tokenizer, self.src_vocab) + tgt_tensor = self._data_process(tgt, self.tgt_tokenizer, self.tgt_vocab) + return src_tensor, tgt_tensor + + # Prepare data and tokenizer + # First item of Multi30k + first_item = ( + "Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.", + "Two young, White males are outside near many bushes.", + ) + + dummy_tokenizer = str.split + + def build_single_vocab(item, tokenizer, specials): + tokens = tokenizer(item) + vocab = build_vocab_from_iterator([tokens], specials=specials) + vocab.set_default_index(vocab[""]) + return vocab + + # Build vocabularies + specials = ["", "", "", ""] + src_vocab = build_single_vocab(first_item[0], dummy_tokenizer, specials) + tgt_vocab = build_single_vocab(first_item[1], dummy_tokenizer, specials) + + # Create torch dataset + torch_dataset = Multi30kDataset( + iter([first_item]), dummy_tokenizer, dummy_tokenizer, src_vocab, tgt_vocab + ) + + # Convert to dm_torch_dataset + dm_dataset = fxt_tabular_caption_dataset + multi_framework_dataset = FrameworkConverter(dm_dataset, subset="train", task="tabular") + dm_torch_dataset = multi_framework_dataset.to_framework( + framework="torch", + target={"input": "source", "output": "target"}, + tokenizer=(dummy_tokenizer, dummy_tokenizer), + vocab=(src_vocab, tgt_vocab), + ) + + # Verify equality of items in torch_dataset and dm_torch_dataset + torch_item = torch_dataset[0] + dm_item = dm_torch_dataset[0] + + assert torch.equal(torch_item[0], dm_item[0]), "Token IDs for de do not match" + assert torch.equal(torch_item[1], dm_item[1]), "Token IDs for en do not match" + @pytest.mark.skipif(not TF_AVAILABLE, reason="Tensorflow is not installed") @pytest.mark.parametrize( "fxt_subset,fxt_task,fxt_convert_kwargs",