diff --git a/.github/workflows/requirements-dev.txt b/.github/workflows/requirements-dev.txt new file mode 100644 index 0000000..e0733ed --- /dev/null +++ b/.github/workflows/requirements-dev.txt @@ -0,0 +1,5 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT + +pre-commit diff --git a/.github/workflows/static_checks.yaml b/.github/workflows/static_checks.yaml new file mode 100644 index 0000000..84aa47c --- /dev/null +++ b/.github/workflows/static_checks.yaml @@ -0,0 +1,76 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT + +name: Static code checks + +on: # yamllint disable-line rule:truthy + pull_request: + push: + branches: + - '**' + tags-ignore: + - '**' + +env: + LICENSE: MIT + FETCH_DEPTH: 1 + FULL_HISTORY: 0 + SKIP_WORD_PRESENCE_CHECK: 0 + +jobs: + static-code-check: + if: endsWith(github.event.repository.name, 'private') + + name: Run static code checks + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + + steps: + - name: Setup history + if: github.ref == 'refs/heads/oss' + run: | + echo "FETCH_DEPTH=0" >> $GITHUB_ENV + echo "FULL_HISTORY=1" >> $GITHUB_ENV + + - name: Setup version + if: github.ref == 'refs/heads/melco' + run: | + echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV + + - name: Check out code + uses: actions/checkout@v3 + with: + fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history + + - name: Set up environment + run: git config user.email github-bot@merl.com + + - name: Set up python + uses: actions/setup-python@v4 + with: + python-version: 3 + cache: 'pip' + cache-dependency-path: '.github/workflows/requirements-dev.txt' + + - name: Install python packages + run: pip install -r .github/workflows/requirements-dev.txt + + - name: Ensure lint and pre-commit steps have been run + uses: pre-commit/action@v3.0.0 + + - name: Check files + uses: merl-oss-private/merl-file-check-action@v1 + with: + license: ${{ env.LICENSE }} + full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above + skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }} + + - name: Check license compatibility + if: github.ref != 'refs/heads/melco' + uses: merl-oss-private/merl_license_compatibility_checker@v1 + with: + input-filename: requirements.txt + license: ${{ env.LICENSE }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..30e783d --- /dev/null +++ b/.gitignore @@ -0,0 +1,168 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL). +# +# SPDX-License-Identifier: MIT + +# Python .gitignore from https://github.com/github/gitignore/blob/main/Python.gitignore +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Custom ignores +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c14f6b3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,64 @@ +# Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +# Pre-commit configuration. See https://pre-commit.com + +default_language_version: + python: python3 + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: ['--maxkb=5000'] + + - repo: https://gitlab.com/bmares/check-json5 + rev: v1.0.0 + hooks: + - id: check-json5 + + - repo: https://github.com/homebysix/pre-commit-macadmin + rev: v1.12.3 + hooks: + - id: check-git-config-email + args: ['--domains', 'merl.com'] + + - repo: https://github.com/psf/black + rev: 22.12.0 + hooks: + - id: black + args: + - --line-length=120 + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"] + + # Uncomment to use pyupgrade (https://github.com/asottile/pyupgrade) to automatically upgrade syntax for newer python + # - repo: https://github.com/asottile/pyupgrade + # rev: v3.3.1 + # hooks: + # - id: pyupgrade + + # To stop flake8 error from causing a failure, use --exit-zero. By default, pre-commit will not show the warnings, + # so use verbose: true to see them. + - repo: https://github.com/pycqa/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + # Black compatibility, Eradicate options + args: ["--max-line-length=120", "--extend-ignore=E203", + "--eradicate-whitelist-extend", "eradicate:\\s*no", + "--exit-zero"] + verbose: true + additional_dependencies: [ + # https://github.com/myint/eradicate, https://github.com/wemake-services/flake8-eradicate + "flake8-eradicate" + ] diff --git a/.reuse/dep5 b/.reuse/dep5 new file mode 100644 index 0000000..8232133 --- /dev/null +++ b/.reuse/dep5 @@ -0,0 +1,9 @@ +Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ + +Files: .vscode/* +Copyright: 2023 Mitsubishi Electric Research Laboratories (MERL) +License: MIT + +Files: images/smart101-banner2.png dataset/*.csv dataset/*.txt +Copyright: 2023 Mitsubishi Electric Research Laboratories (MERL) +License: MIT diff --git a/.vscode/README_VSCode.md b/.vscode/README_VSCode.md new file mode 100644 index 0000000..8d8b201 --- /dev/null +++ b/.vscode/README_VSCode.md @@ -0,0 +1,14 @@ + +# VS Code recommended extensions and settings + +These files provide recommended extensions and workspace settings for VS Code for python development. The recommended extensions are: + +* [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python"): Official python extension from Microsoft +* [Python Type Hint](https://marketplace.visualstudio.com/items?itemName=njqdev.vscode-python-typehint): Type hint completion for Python +* [autoDocstring](https://marketplace.visualstudio.com/items?itemName=njpwerner.autodocstring): Generates python docstrings automatically + +If these extensions are not already globally installed, they will be recommended to you for installation when you open the project in VS Code. diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..2d42587 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,7 @@ +{ + "recommendations": [ + "ms-python.python", + "njqdev.vscode-python-typehint", + "njpwerner.autodocstring" + ] +} diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..b5520a9 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,32 @@ +{ + "editor.rulers": [ + 120 + ], + "[python]": { + "editor.tabSize": 4 + }, + "[markdown]": { + "editor.wordWrap": "bounded", + "editor.wordWrapColumn": 120 + }, + "files.eol": "\n", + "files.insertFinalNewline": true, + "files.trimFinalNewlines": true, + "files.trimTrailingWhitespace": true, + "editor.formatOnSave": true, + "python.formatting.provider": "black", + "python.formatting.blackArgs": [ + "--line-length=120" + ], + "python.linting.flake8Enabled": true, + "python.linting.enabled": true, + "python.linting.flake8Args": [ + "--max-line-length=120", + "--extend-ignore=E203" + ], + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..ea8e37e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,10 @@ + + +# Contributing + +Sorry, but we do not currently accept contributions in the form of pull requests to this repository. However, you are +welcome to post issues (bug reports, feature requests, questions, etc). diff --git a/DEPENDENCIES.md b/DEPENDENCIES.md new file mode 100644 index 0000000..d27caab --- /dev/null +++ b/DEPENDENCIES.md @@ -0,0 +1,32 @@ + + + +# Dependent pre-trained models +Our code used the following publicly available pre-trained models for visual processing, language modeling, and vision-and-language reasoning. We also provide below the licenses associated with these models, and the download links that we used. Note that we do not use or modify the code behind these pre-trained models, and one may use other implementations of these models, if needed. + +| Name | License | Link | +|:-------------------|:---------------|------| +| BERT/HuggingFace | Apache-2.0 | https://huggingface.co/docs/transformers/model_doc/bert | +| GPT2/HuggingFace | MIT | https://huggingface.co/gpt2 | +| MAE/HuggingFace | Apache-2.0 | https://huggingface.co/facebook/vit-mae-large | +| CrossTransformer | MIT | https://github.com/lucidrains/vit-pytorch | +| CLIP/OpenAI | MIT | https://github.com/openai/CLIP | +| FLAVA/HuggingFace | BSD-3-Clause | https://huggingface.co/facebook/flava-full | + +We also used the following models that are part of the torchvision toolbox of PyTorch (released under `BSD-3-Clause`) +| Name | License | Link | +|:-------------------|:---------------|------| +| AlexNet/VGG | BSD-3-Clause | https://pytorch.org/vision/stable/models.html | +| ResNet-50 | BSD-3-Clause | https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html | +| ResNet-18 | BSD-3-Clause | https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html | +| ViT-16 | BSD-3-Clause | https://pytorch.org/vision/main/models/vision_transformer.html | +| Swin_b, Swin_t | BSD-3-Clause | https://pytorch.org/vision/main/models/swin_transformer.html | + +We used GloVe language embeddings from the torchtext toolbox of PyTorch. +| Name | License | Link | +|:-------------------|:---------------|------| +| GloVe from TorchText | BSD-3-Clause | https://pypi.org/project/torchtext/ | diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..44ed5d2 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..cc7f88c --- /dev/null +++ b/README.md @@ -0,0 +1,118 @@ + + + + +![](./images/smart101-banner2.png) + +## Overview + +Recent times have witnessed an increasing number of applications of deep neural networks towards solving tasks that require superior cognitive abilities, e.g., playing Go, generating art, ChatGPT, etc. Such a dramatic progress raises the question: how generalizable are neural networks in solving problems that demand broad skills? To answer this question, we propose SMART: a Simple Multimodal Algorithmic Reasoning Task (and the associated SMART-101 dataset) for evaluating the abstraction, deduction, and generalization abilities of neural networks in solving visuo-linguistic puzzles designed specifically for children of younger age (6--8). Our dataset consists of 101 unique puzzles; each puzzle comprises a picture and a question, and their solution needs a mix of several elementary skills, including pattern recognition, algebra, and spatial reasoning, among others. To train deep neural networks, we programmatically augment each puzzle to 2,000 new instances; each instance varied in appearance, associated natural language question, and its solution. To foster research and make progress in the quest for artificial general intelligence, we are publicly releasing our PyTorch implementation. + +This repository contains training and testing code reported in the CVPR 2023 paper Are Deep Neural Networks SMARTer than Second Graders? ***by Anoop Cherian, Kuan-Chuan Peng, Suhas Lohit, Kevin A. Smith, and Joshua B. Tenenbaum***. + +## Code Setup +``` + conda create --name SMART python=3.9 + conda activate SMART + pip install -r requirements.txt + pip install git+https://github.com/openai/CLIP.git +``` + +## SMART-101 Data Setup +``` + # Download the SMART-101 dataset from https://zenodo.org/record/7775984 + # To download you can use: wget https://zenodo.org/record/7775984/files/SMART101-release-v1.zip?download=1 -P + # cd + # unzip SMART101-release-v1.zip -d >/dev/null +``` + +After the unzip, `` will have the directory structure: `/SMART101-release-v1/SMART101-Data` -- this is the location where the 101 puzzles are stored. + +***Known Issue:*** You will receive an error on `/SMART101-release-v1/SMART101-Data/62/puzzle_63.csv` when running the code. To resolve this problem, you will need to rename `/SMART101-release-v1/SMART101-Data/62/puzzle_63.csv` to `/SMART101-release-v1/SMART101-Data/62/puzzle_62.csv`. + +## Train and Test Command Lines +Our implementation provides a variety of options to train and test diverse state-of-the-art neural network models. Please see `main.py` for all options. Below, we list a few command lines to get started and the arguments that you may change for learning using other backbones. All the backbones are downloaded or implementd in `net.py` or `net_clip.py` (where the latter is only used for `CLIP` feature extraction). + +### Command lines +***To train a ResNet-50 + BERT backbone*** while also fine-tuning the ResNet-50 model for 100 epochs with a classifier head and puzzle specific output heads (the best option used in our paper), use the command line: + +``` +python main.py --model_name resnet50 --num_workers 8 --num_epochs 100 --loss_type classifier --batch_size 64 --log_freq 10 --train_backbone --word_embed bert --data_root /SMART101-release-v1/SMART101-Data/ --puzzles all --split_type standard +``` +Here, `--train_backbone` is for training the image backbone model, `--puzzles all` says to train on all puzzles, and `--split_type standard` specifies `instance split` to use. The above command will produce the results in a folder `/results/` where `` is a random seed number for a "run" of the above code (e.g., you may specify `--seed 1234` in the commandline above and the default `` is `./data/v2/`). The best validation model associated with this run will be saved in: `/checkpoints/ckpt__.pth`. + +***To evaluate a checkpoint*** (without training/finetuning) for a seed (that either you specified in the training code above or is randomly created when the seed is not given), use the command: + +``` +python main.py --model_name resnet50 --num_workers 8 --num_epochs 100 --loss_type classifier --batch_size 64 --log_freq 10 --train_backbone --word_embed bert --data_root /SMART101-release-v1/SMART101-Data/ --puzzles all --test --seed 1234 +``` + +When the training and evaluation are done, the code reports the Solution accuracy $S_{acc}$, Option accuracy $O_{acc}$, and error variance $err$, where $err$ is the average weighted relative prediction error, which is the average $\ell_1$ distance between the predicted and the ground truth answers, each puzzle error inversely weighted by the cardinality of its answer range. We also report the category-wise accuracy among the eight skill sets as well as, the average accuracy over all instances of a root puzzle used in the evaluation. + +### Possible Arguments: +* --split_type: {'standard', 'exclude', 'fewshot', 'puzzle'} corresponding to instance split (IS), answer split (AS), few-shot split (FS), and puzzle split (PS), respectively. +* --model_name: {'resnet50', 'resnet18', 'mae', 'clip', 'flava', 'vit', 'swin_b', 'swin_t', 'alexnet', 'vgg', 'cross_transformer'} +* --word_embed: {'standard', 'bert', 'gpt', 'glove'}, where `standard` uses learned embeddings. +* --loss_type: {'classifier', 'regression'} + +You may also use a contrastive learning image backbone (e.g., SimSiam). For this, use the option `--model_name resnet50 --pretrained ./` for a `resnet50` based pretrained model. + +### Other Useful Commandlines and tips: +* --baselines: will print the baseline (greedy and uniform) accuracies on all puzzles in train, val, and test. +* --split_type fewshot --fsK : will specify the number of fewshow samples to use +* --data_tot : speficies the total number of puzzle instances to use for each root puzzle (you may ask the method to use less than 2000 instances) +* --puzzles id1,id2,...: You can specify a comma separated list of puzzles to use to do the train and test on only those puzzles (default is --puzzles 'all'). +* See globalvars.py to see more hyperparamters (not reachable via commandlines), including the set of puzzles used in "puzzle/fewshot split". + +### Outputs: +The code generates outputs in the folder ``. There are many intermediate data files and performance statistics produced, namely: +* `/vocab_puzzle_all.pkl`: the word vocabulary used in the learned embeddings +* `/results//acc_perf_scores.png`: puzzle-wise $S_{acc}$ performances +* `/results//opt_acc_perf_scores.png`: puzzle-wise $O_{acc}$ performances +* `/results//cmd_line.txt`: command line used in the current run + +Please refer to main.py to see other arguments to the code. + +### Trained models: +The best validation model associated with a run is saved at `/checkpoints/ckpt__.pth`. + +## Citation + +If you use this code, please cite the following paper: + +```BibTeX +@article{cherian2022deep, + title={Are Deep Neural Networks SMARTer than Second Graders?}, + author={Cherian, Anoop and Peng, Kuan-Chuan and Lohit, Suhas and Smith, Kevin and Tenenbaum, Joshua B}, + journal={arXiv preprint arXiv:2212.09993}, + year={2022} +} +``` + +## Contact + +Anoop Cherian: cherian@merl.com, Kuan-Chuan Peng: kpeng@merl.com, or Suhas Lohit: slohit@merl.com. + +## Contributing + +See [CONTRIBUTING.md](CONTRIBUTING.md) for our policy on contributions. + +## License + +Released under `MIT` license, as found in the [LICENSE.md](LICENSE.md) file. See [DEPENDENCIES.md](DEPENDENCIES.md) for details on publicly available pre-trained models that our software uses, their licenses, and the download links. + +All files: + +``` +Copyright (C) 2023 Mitsubishi Electric Research Laboratories (MERL). + +SPDX-License-Identifier: MIT +``` diff --git a/baselines.py b/baselines.py new file mode 100644 index 0000000..0137b0e --- /dev/null +++ b/baselines.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import os + +import matplotlib.pyplot as plt +import numpy as np + +import globvars as gv +import utils + + +def plot_baseline_perf(args, baseline_key, key, tot, split, suffix): + plt.figure(figsize=(30, 4)) + ax = plt.gca() + bpid = np.array(baseline_key) + bclassids = np.arange(gv.NUM_CLASSES_PER_PUZZLE[key]) + x = np.histogram(bpid, np.arange(gv.NUM_CLASSES_PER_PUZZLE[key] + 1))[0] + x = x / x.sum() + ax.bar(bclassids, x) + ax.set_xticks(bclassids) + plt.savefig( + "%s/stats/ans_distr/%s/%s/ans_distr_%s_%d_%s_%s.png" % (args.save_root, suffix, split, key, tot, split, suffix) + ) + plt.close() + return x + + +def get_baseline_performance(args, qa_info, split, tot, log=False): + + topK = lambda x: np.sort(x)[-2:].sum() / 2.0 + baseline = {} + baseline_opts = {} + for t in range(len(qa_info)): + pid = qa_info[t]["puzzle_id"] + if int(pid) not in gv.SEQ_PUZZLES and int(pid) != 58: + if pid not in baseline: + baseline[pid] = [] + baseline_opts[pid] = [] + baseline[pid].append(qa_info[t]["AnswerValue"]) + baseline_opts[pid].append(ord(qa_info[t]["Answer"]) - ord("A")) + if not os.path.exists(os.path.join(args.save_root, "stats/ans_distr/sacc/train/")): + os.makedirs(os.path.join(args.save_root, "stats/ans_distr/sacc/train/")) + os.makedirs(os.path.join(args.save_root, "stats/ans_distr/sacc/val")) + os.makedirs(os.path.join(args.save_root, "stats/ans_distr/sacc/test")) + os.makedirs(os.path.join(args.save_root, "stats/ans_distr/oacc/train")) + os.makedirs(os.path.join(args.save_root, "stats/ans_distr/oacc/val")) + os.makedirs(os.path.join(args.save_root, "stats/ans_distr/oacc/test")) + tot_baseline_sacc_greedy = 0.0 + tot_baseline_sacc_bestK = 0.0 + tot_rand_sacc = 0.0 + tot_baseline_oacc_greedy = 0.0 + tot_baseline_oacc_bestK = 0.0 + tot_rand_oacc = 0.0 + baseline_sacc = {} + baseline_oacc = {} + baseline_rand_sacc = {} + overall_baseline_sacc = {} + overall_baseline_oacc = {} + overall_baseline_rand = {} + for key in baseline.keys(): + x = plot_baseline_perf(args, baseline[key], key, tot, split, "sacc") + baseline_sacc[key] = (x.argmax(), x.max(), topK(x), len(x)) # the class and the value. + x = plot_baseline_perf(args, baseline_opts[key], key, tot, split, "oacc") + baseline_oacc[key] = (x.argmax(), x.max(), topK(x), len(x)) # the class and the value. + baseline_rand_sacc[key] = 1 / gv.NUM_CLASSES_PER_PUZZLE[key] + + tot_baseline_sacc_greedy += baseline_sacc[key][1] + tot_baseline_sacc_bestK += baseline_sacc[key][2] + tot_rand_sacc += 1 / gv.NUM_CLASSES_PER_PUZZLE[key] + + tot_baseline_oacc_greedy += baseline_oacc[key][1] + tot_baseline_oacc_bestK += baseline_oacc[key][2] + tot_rand_oacc += 1 / 5.0 + + if True: # log: + print( + "baseline %s class = %d freq = %f bestK_acc=%f percent num_classes=%d" + % (key, baseline_sacc[key][0], baseline_sacc[key][1], baseline_sacc[key][2], baseline_sacc[key][3]) + ) + overall_baseline_sacc[key] = baseline_sacc[key][1] + overall_baseline_oacc[key] = baseline_oacc[key][1] + overall_baseline_rand[key] = baseline_rand_sacc[key] + print("\n\n") + tot_keys = len(baseline.keys()) + print( + "overall baseline (%d puzzles) Greedy: top-1 sacc/oacc = %0.4f/%0.4f Greedy: top-K sacc/oacc=%0.4f/%0.4f random sacc=%0.4f " + % ( + len(baseline.keys()), + tot_baseline_sacc_greedy / tot_keys, + tot_baseline_oacc_greedy / tot_keys, + tot_baseline_sacc_bestK / tot_keys, + tot_baseline_oacc_bestK / tot_keys, + tot_rand_sacc / tot_keys, + ) + ) + + base_sacc_list = np.zeros( + gv.num_puzzles + 1, + ) + base_oacc_list = np.zeros( + gv.num_puzzles + 1, + ) + rand_sacc_list = np.zeros( + gv.num_puzzles + 1, + ) + for key in baseline.keys(): + base_sacc_list[int(key)] = overall_baseline_sacc[key] + base_oacc_list[int(key)] = overall_baseline_oacc[key] + rand_sacc_list[int(key)] = overall_baseline_rand[key] + + # print category-wise performances. # copied from utils. + puzzles = utils.read_dataset_info(gv.VILPS_DATASET_INFO_FILE) + cls_mean = lambda x, idx: np.array([x[int(ii)] for ii in idx]).mean() + get_int_set = lambda x: set([int(ii) for ii in x]) + class_avg_perf = {} + classes = ["counting", "math", "logic", "path", "algebra", "measure", "spatial", "pattern"] + print(classes) + print("Greedy %s" % (split)) + for kk in classes: + idx_list = np.array(list(get_int_set(puzzles[kk]).intersection(get_int_set(baseline.keys())))) + class_avg_perf[kk] = ( + cls_mean(base_sacc_list, idx_list), + cls_mean(base_oacc_list, idx_list), + cls_mean(rand_sacc_list, idx_list), + ) + print("%0.3f/%0.3f & " % (class_avg_perf[kk][0], class_avg_perf[kk][1]), end=" ") + print("\nUniform %s" % (split)) + for kk in classes: + print("%0.3f/%0.3f & " % (class_avg_perf[kk][2], 0.2), end=" ") + print("\n\n") + + plt.figure(figsize=(30, 4)) + ax = plt.gca() + ax.bar(np.arange(1, gv.num_puzzles + 1), 100.0 * base_sacc_list[1:]) + ax.set_xticks(np.arange(1, gv.num_puzzles + 1)) # , [str(i) for i in np.arange(1,num_puzzles+1)]) + plt.savefig( + "%s/stats/ans_distr/%s/%s/base_sacc_perf_with_greedy_choices_%d.png" % (args.save_root, "sacc", split, tot) + ) + plt.close() + + plt.figure(figsize=(30, 4)) + ax = plt.gca() + ax.bar(np.arange(1, gv.num_puzzles + 1), 100.0 * base_oacc_list[1:]) + ax.set_xticks(np.arange(1, gv.num_puzzles + 1)) # , [str(i) for i in np.arange(1,num_puzzles+1)]) + plt.savefig( + "%s/stats/ans_distr/%s/%s/base_oacc_perf_with_greedy_choices_%d.png" % (args.save_root, "oacc", split, tot) + ) + plt.close() + + np.save("%s/stats/baseline_%d_%s.npy" % (args.save_root, tot, split), [baseline_sacc, baseline_oacc]) + return baseline_sacc, baseline_oacc diff --git a/build_vocab.py b/build_vocab.py new file mode 100644 index 0000000..9c48d5f --- /dev/null +++ b/build_vocab.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import nltk + +# make sure nltk works fine. +try: + nltk.data.find("tokenizers/punkt") +except LookupError: + print("downloading nltk/punkt tokenizer") + nltk.download("punkt") + +import argparse +import glob +import os +import pickle +from collections import Counter + +from utils import save_file + + +class Vocabulary(object): + """Simple vocabulary wrapper.""" + + def __init__(self): + self.word2idx = {} + self.idx2word = {} + self.idx = 0 + + def add_word(self, word): + if not word in self.word2idx: + self.word2idx[word] = self.idx + self.idx2word[self.idx] = word + self.idx += 1 + + def __call__(self, word): + if not word in self.word2idx: + return self.word2idx[""] + return self.word2idx[word] + + def __len__(self): + return len(self.word2idx) + + +def build_vocab(text_rows, threshold): + """Build a simple vocabulary wrapper.""" + + print("total QA pairs", len(text_rows)) + counter = Counter() + + for text in text_rows: + tokens = nltk.tokenize.word_tokenize(text.lower()) + counter.update(tokens) + + counter = sorted(counter.items(), key=lambda item: item[1], reverse=True) + save_file(dict(counter), "dataset/VideoQA/word_count.json") + # If the word frequency is less than 'threshold', then the word is discarded. + words = [item[0] for item in counter if item[1] >= threshold] + print(len(words)) + # Create a vocab wrapper and add some special tokens. + vocab = Vocabulary() + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + vocab.add_word("") + + # Add the words to the vocabulary. + for i, word in enumerate(words): + vocab.add_word(word) + + return vocab + + +def read_csv(csvfilename): + import csv + + qa_info = [] + with open(csvfilename, newline="") as csvfile: + datareader = csv.DictReader(csvfile) + for row in datareader: + qa_info.append(row["Question"]) + qa_info.append(row["A"] + " " + row["B"] + " " + row["C"] + " " + row["D"] + " " + row["E"]) + return qa_info + + +def process_text_for_puzzle(args): + vocab_path = os.path.join(args.save_root, "vocab_puzzle_" + args.puzzle_ids_str + ".pkl") + if os.path.exists(vocab_path): + print("loading vocab %s" % (vocab_path)) + with open(vocab_path, "rb") as f: + vocab = pickle.load(f) + else: + text_rows = [] + for puzzle_id in args.puzzle_ids: + print("reading puzzle %s" % (puzzle_id)) + text_files = glob.glob(os.path.join(args.data_root, str(puzzle_id), "puzzle_%s.csv" % (puzzle_id))) + for t in range(len(text_files)): + rows = read_csv(text_files[t]) + text_rows = text_rows + rows + vocab = build_vocab(text_rows, threshold=3) + with open(vocab_path, "wb") as f: + pickle.dump(vocab, f) + print("generating new vocab for %s: num_words=%d" % (args.puzzle_ids_str, len(vocab))) + return vocab + + +def main(args): + vocab = build_vocab(args.caption_path, args.threshold) + vocab_path = args.vocab_path + print("Total vocabulary size: {}".format(len(vocab))) + print("Saved the vocabulary wrapper to '{}'".format(vocab_path)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--anno_path", type=str, default="dataset/nextqa/train.csv", help="path for train annotation file" + ) + parser.add_argument( + "--vocab_path", type=str, default="dataset/VideoQA/vocab.pkl", help="path for saving vocabulary wrapper" + ) + parser.add_argument("--threshold", type=int, default=1, help="minimum word count threshold") + args = parser.parse_args() + main(args) diff --git a/data_loader.py b/data_loader.py new file mode 100644 index 0000000..80573f0 --- /dev/null +++ b/data_loader.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import os +import warnings + +import numpy as np +import torch + +warnings.filterwarnings("ignore") +import pdb +import pickle + +import nltk +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms import Compose, Normalize, Resize, ToTensor + +import baselines +import globvars as gv +import utils + + +class SMART_Data(Dataset): + def __init__(self, args): + vocab_path = args.vocab_path + self.max_qlen = 110 + self.max_olen = 4 # max option length + self.use_word_embed = False + self.word_embed = None + self.im_side = 224 + self.preprocess = args.preprocess + self.no_question = args.no_question + self.no_image = args.no_image + + with open(vocab_path, "rb") as f: + self.vocab = pickle.load(f) + print("vocabulary size = %d" % (len(self.vocab))) + + if args.preprocess is None: # VL models, will do preprocess later. + self.transform = Compose( + [ + Resize(224), # if the images are of higher resolution. we work with pre-resized 224x224 images. + # RandomCrop(224), + ToTensor(), + Normalize(torch.Tensor([0.5]), torch.Tensor([0.5])), + ] + ) + elif args.model_name in ["flava", "mae"]: # this will do feature extractin later. + self.transform = Compose( + [ + Resize(300), + ToTensor(), + ] + ) + else: + self.transform = args.preprocess + + def apply_transform(self, im_path): + if self.no_image: # create a dummy image. + im = Image.fromarray((np.random.rand(self.im_side, self.im_side, 3) * 255).astype("uint8")) + else: + im = Image.open(im_path).convert("RGB") + return self.transform(im) + + def quest_encode(self, question): + tokens = nltk.tokenize.word_tokenize(question.lower()) + q_enc = np.zeros((self.max_qlen,), dtype="long") + if not self.no_question: + enc_tokens = ( + [self.vocab("")] + [self.vocab(tokens[t]) for t in range(len(tokens))] + [self.vocab("")] + ) + q_enc[: min(self.max_qlen, len(enc_tokens))] = np.array(enc_tokens) + return q_enc + + def ans_encode(self, answer): + return ord(answer) - ord("A") + + def opts_encode(self, opts, key): + opts = opts.lower() + tokens = nltk.tokenize.word_tokenize(opts) + enc_tokens = [self.vocab(tokens[t]) for t in range(len(tokens))] + opt_enc = np.zeros((self.max_olen,), dtype="long") + opt_enc[: min(self.max_olen, len(enc_tokens))] = np.array(enc_tokens) + return opt_enc + + def split_fewshot_puzzles(self, puzzle_ids, split_ratio, split_name, split_type): + if split_name == "train": + split_pids = self.split_puzzles(puzzle_ids, split_ratio, "train", split_type) + other_pids = self.split_puzzles(puzzle_ids, split_ratio, "test", split_type) + other_pids = other_pids + self.split_puzzles(puzzle_ids, split_ratio, "val", split_type) + return split_pids, other_pids + else: + split_pids = self.split_puzzles(puzzle_ids, split_ratio, split_name, split_type) + other_pids = None + return split_pids, other_pids + + def split_puzzles(self, puzzle_ids, split_ratio, split_name, split_type): + if split_type == "puzzle" or split_type == "fewshot": + if split_name == "train": + val_test = gv.PS_VAL_IDX + gv.PS_TEST_IDX + val_test = set([str(ii) for ii in val_test]) + puzzle_ids = list(set(puzzle_ids).difference(val_test)) + print("number of train puzzles = %d" % (len(puzzle_ids))) + elif split_name == "val": + puzzle_ids = [str(ii) for ii in gv.PS_VAL_IDX] + print("number of val puzzles = %d" % (len(puzzle_ids))) + else: + puzzle_ids = [str(ii) for ii in gv.PS_TEST_IDX] + print("number of test puzzles = %d" % (len(puzzle_ids))) + else: + splits = np.array([int(spl) for spl in split_ratio.split(":")]).cumsum() + n = len(puzzle_ids) + if split_name == "train": + st = 0 + en = int(np.floor(n * splits[0] / 100.0)) + puzzle_ids = puzzle_ids[st:en] + elif split_name == "val": + st = int(np.ceil(n * splits[0] / 100.0)) + en = int(np.floor(n * splits[1] / 100.0)) + puzzle_ids = puzzle_ids[st:en] + else: + st = int(np.ceil(n * splits[1] / 100.0)) + puzzle_ids = puzzle_ids[st:] + print("puzzles for %s =" % (split_name)) + print(puzzle_ids) + return puzzle_ids + + def split_data(self, info, split_ratio, split_name, split_type="standard"): + """ + split_type=standard is to use the split_ratio in the instance order + split_type=exclude is to exclude answers from the split, e.g., train on all answers except say 1, and test 1 + split_type=puzzle is to split the puzzles into the respective ratios. so we don't have to do anything here. + """ + if split_type == "standard" or split_type == "puzzle" or split_type == "fewshot": + splits = np.array([int(spl) for spl in split_ratio.split(":")]).cumsum() + n = len(info) + if split_name == "train": + st = 0 + en = int(np.floor(n * splits[0] / 100.0)) + info = info[st:en] + elif split_name == "val": + st = int(np.ceil(n * splits[0] / 100.0)) + en = int(np.floor(n * splits[1] / 100.0)) + info = info[st:en] + else: + st = int(np.ceil(n * splits[1] / 100.0)) + info = info[st:] + elif split_type == "exclude": + pid = info[0]["puzzle_id"] + if int(pid) in gv.SEQ_PUZZLES or int(pid) == 58: + # we don't do exclude splits for seq_puzzles are as they are most likely always unique + info = self.split_data(info, split_ratio, split_name, split_type="standard") + else: + ans_distr = [] + for t in range(len(info)): + ans_distr.append(info[t]["AnswerValue"]) + ans_distr = np.array(ans_distr) + bclassids = np.arange(gv.NUM_CLASSES_PER_PUZZLE[pid]) + x = np.histogram(ans_distr, bclassids)[0] + x = x / x.sum() + + # select reasonable answers. + valid_ans_idx = np.where(x > 0.01) + x_cls = bclassids[valid_ans_idx] + x = x[valid_ans_idx] + median_class = x_cls[x <= np.median(x)][-1] + try: + train_inst = np.array(info)[ans_distr != median_class] + test_inst = np.array(info)[ans_distr == median_class] + except: + print(pid) + pdb.set_trace() + + n = len(train_inst) + splits = np.array([int(spl) for spl in split_ratio.split(":")]) + splits[0] = splits[0] + splits[2] + splits = splits.cumsum()[:2] + + if split_name == "train": + st = 0 + en = int(np.floor(n * splits[0] / 100.0)) + info = train_inst[st:en].tolist() + elif split_name == "val": + st = int(np.ceil(n * splits[0] / 100.0)) + en = int(np.floor(n * splits[1] / 100.0)) + info = train_inst[st:en].tolist() + else: + info = test_inst.tolist() + else: + raise "Unknown puzzle split type!!" + + return info + + +class SMART_TrainData(SMART_Data): + def __init__(self, args, split): + super().__init__(args) + self.data_root = args.data_root + self.num_tot = args.data_tot # how many instances per puzzles should we use? + self.diff = args.train_diff + self.word_embed = args.word_embed + self.fewshot_K = args.fsK + self.qa_info = [] + train_pids = None + + puzzle_ids = ( + self.split_puzzles(args.puzzle_ids, args.split_ratio, split, args.split_type) + if args.split_type == "puzzle" + else args.puzzle_ids + ) + if args.split_type == "fewshot": + train_pids, fewshot_other_pids = self.split_fewshot_puzzles( + args.puzzle_ids, args.split_ratio, split, args.split_type + ) + for puzzle_id in puzzle_ids: + puzzle_root = puzzle_id + "/" + gv.puzzle_diff_str[self.diff] + "/" + csv_file = "puzzle_%s%s.csv" % (puzzle_id, gv.puzzle_diff[self.diff]) + qa_info = utils.read_csv(os.path.join(self.data_root, puzzle_root, csv_file), puzzle_id) + if args.split_type == "fewshot" and puzzle_id in fewshot_other_pids: + qa_info = qa_info[: self.fewshot_K] + else: + qa_info = qa_info[: self.num_tot] + for t in range(len(qa_info)): + qa_info[t]["AnswerValue"] = utils.get_val(qa_info[t], qa_info[t]["Answer"]) + self.qa_info = self.qa_info + self.split_data(qa_info, args.split_ratio, split, args.split_type) + gv.MAX_VAL = max(gv.MAX_VAL, gv.NUM_CLASSES_PER_PUZZLE[puzzle_id]) + if args.baselines: + self.baseline_perf = baselines.get_baseline_performance(args, self.qa_info, split, self.num_tot, log=True) + print("num_train=%d max_answer_value=%d" % (len(self.qa_info), gv.MAX_VAL)) + print("split=%s puzzle_ids=" % (split), end=" ") + print(puzzle_ids) + + def __getitem__(self, idx): + info = self.qa_info[idx] + pid = info["puzzle_id"] + puzzle_root = pid + "/" + gv.puzzle_diff_str[self.diff] + "/" + im = self.apply_transform(gv.osp(self.data_root, puzzle_root, "img", info["image"])) + qa = self.quest_encode(info["Question"]) + opts = 0 + lbl = self.ans_encode(info["Answer"]) + answer_value = info["AnswerValue"] + answer = np.zeros( + gv.MAX_DECODE_STEPS, + ) + if int(pid) not in gv.SEQ_PUZZLES: + answer[0] = answer_value + else: + try: + answer[: len(answer_value)] = answer_value + except: + print(info) + pdb.set_trace() + return im, torch.tensor(qa), torch.tensor(opts), torch.tensor(lbl), torch.tensor(answer), torch.tensor(int(pid)) + + def __len__(self): + return len(self.qa_info) + + +class SMART_ValData(SMART_Data): + def __init__(self, args, split): + super().__init__(args) + self.data_root = args.data_root + self.num_tot = args.data_tot + self.word_embed = args.word_embed + self.fewshot_K = args.fsK + self.qa_info = [] + + self.diff = args.test_diff if split == "test" else args.train_diff + puzzle_ids = ( + self.split_puzzles(args.puzzle_ids, args.split_ratio, split, args.split_type) + if args.split_type == "puzzle" + else args.puzzle_ids + ) + if args.split_type == "fewshot": + puzzle_ids, fewshot_other_pids = self.split_fewshot_puzzles( + args.puzzle_ids, args.split_ratio, split, args.split_type + ) + + for puzzle_id in puzzle_ids: + puzzle_root = puzzle_id + "/" + gv.puzzle_diff_str[self.diff] + "/" + csv_file = "puzzle_%s%s.csv" % (puzzle_id, gv.puzzle_diff[self.diff]) + qa_info = utils.read_csv(os.path.join(self.data_root, puzzle_root, csv_file), puzzle_id) + if args.split_type == "fewshot": + qa_info = qa_info[ + self.fewshot_K : self.num_tot + ] # we use the fewshot_K for training. so use the rest for evaluation. + else: + qa_info = qa_info[: self.num_tot] + for t in range(len(qa_info)): + qa_info[t]["AnswerValue"] = utils.get_val(qa_info[t], qa_info[t]["Answer"]) + self.qa_info = self.qa_info + self.split_data(qa_info, args.split_ratio, split, args.split_type) + gv.MAX_VAL = max(gv.MAX_VAL, gv.NUM_CLASSES_PER_PUZZLE[puzzle_id]) + print("num_val = %d max_answer_value=%d" % (len(self.qa_info), gv.MAX_VAL)) + if args.baselines: + self.baseline_perf = baselines.get_baseline_performance(args, self.qa_info, split, self.num_tot, log=True) + print("split=%s puzzle_ids=" % (split), end=" ") + print(puzzle_ids) + + def __getitem__(self, idx): + info = self.qa_info[idx] + pid = info["puzzle_id"] + puzzle_root = info["puzzle_id"] + "/" + gv.puzzle_diff_str[self.diff] + "/" + im = self.apply_transform(gv.osp(self.data_root, puzzle_root, "img", info["image"])) + qa = self.quest_encode(info["Question"]) + + _ = [utils.str_replace_(info, key) for key in ["A", "B", "C", "D", "E"]] + opts = [utils.get_val(info, key, is_one_of_option=True) for key in ["A", "B", "C", "D", "E"]] + lbl = self.ans_encode(info["Answer"]) + answer_value = info["AnswerValue"] + answer = np.zeros( + gv.MAX_DECODE_STEPS, + ) + if int(pid) not in gv.SEQ_PUZZLES: + answer[0] = answer_value + else: + answer[: len(answer_value)] = answer_value + return im, torch.tensor(qa), opts, torch.tensor(lbl), torch.tensor(answer), torch.tensor(int(info["puzzle_id"])) + + def __len__(self): + return len(self.qa_info) + + +def SMART_collate_fn(data): + """we use it only for val and test to load the options as a list""" + concat = lambda data_list: torch.cat([x.unsqueeze(0) for x in data_list]) + im, qa, opts, lbl, answer, puzzle_ids = zip(*data) + im = concat(im).float() + qa = concat(qa) + lbl = concat(lbl) + answer = concat(answer) + puzzle_ids = concat(puzzle_ids) + return im, qa, opts, lbl, answer, puzzle_ids diff --git a/dataset/SMART_info_v2.csv b/dataset/SMART_info_v2.csv new file mode 100644 index 0000000..48b0d96 --- /dev/null +++ b/dataset/SMART_info_v2.csv @@ -0,0 +1,102 @@ +id,Question,image,A,B,C,D,E,Answer option,difficulty,type,level,source,note: difficulty (1/2/3: easy/normal/hard) +1,Who caught the fish? ,mk2017_1.png,Adam,Basil,Charlie,David,Edgar,D,2,path,1,kangaroo/2017,8/26/2022 ready for review +2,"In the picture, there are stars with 5 points, stars with 6 points, and stars with 7 points. How many stars that have only 5 points are there? ",mk2017_2.png,2,3,4,5,9,C,1,counting,1,kangaroo/2017,8/12/2022 revision ready for review; sugestions: polygons + stars (done) +3,The entire pie seen in the picture is divided among several children. Each child receives a piece of pie with three cherries on top. How many children are there? ,mk2017_3.png,3,4,5,6,8,B,1,counting,1,kangaroo/2017,8/12/2022 ready for review; suggestions: change the text of the question (avoid ingredient) (done) +4,Into how many parts do the scissors cut the rope in the picture?,mk2017_4.png,5,6,7,8,9,A,2,counting,1,kangaroo/2017,9/2/2022 ready for review; Same as 97 (but could use a different curve function) +5,How many bricks are missing from the igloo?,mk2017_6.png,6,7,8,9,10,A,2,counting,1,kangaroo/2017,8/12/2022 ready for review +6,"Four of the numbers 1,3,4,5, and 7 are used, one in each square, so that the equality is correct. Which of the numbers is not used?",mk2017_8.png,1,3,4,5,7,C,1,math,1,kangaroo/2017,8/12/2022 revision ready for review; suggestions: replace different icons with some primitive shapes (done) +7,"In the country of jewelries, you can trade three sapphires for one ruby (picture 1). For one sapphire, you can get two flowers (picture 2). How many flowers can you get for two rubies? ",mk2017_9.png,6,8,10,12,14,D,1,algebra,1,kangaroo/2017,8/12/2022 revision ready for review; suggestions: use the same instance of each type (done) +8,How many triangles are there in the picture below?,mk2017_11.png,8,9,10,11,12,D,3,counting,1,kangaroo/2017,9/23/2022 ready for review; KP: haven't found a closed-form solution; suggestions: think about how to control the level of triangles? +9,"Brian and William are standing in line. Brian knows that there are 7 people in front of him. William knows that there is a total of 11 people in the line. If Brian is just in front of William, how many of the people in the line are behind William?",None,2,3,4,5,6,A,1,algebra,1,kangaroo/2017,8/5/2022 ready for review +10,"In the table, the correct additions were performed in the squares according to the pattern shown. What number should replace the question mark?",mk2017_17.png,10,11,12,13,15,B,1,algebra,1,kangaroo/2017,8/5/2022 ready for review +11,"Father hangs the laudry outside on a clothesline. He wants to use as few pins as possible. For 3 towels, he needs 4 pins, as shown. How many pins does he need for 9 towels?",mk2012_q3,9,10,12,16,18,B,1,math,1,kangaroo/2017,8/5/2022 ready for review +12,At the entrance of the zoo there are 12 children in line. Lucy is the 7th from the front and Sam is teh second from the back. How many children are there between Lucy and Sam in the line?,mk2019_q5.png,2,3,4,5,6,B,1,algebra,1,kangaroo/2019,8/12/2022 ready for review +13,Jorge pairs his socks so that the numbers match. How many pairs can he make?,mk2019_q6.png,3,4,5,6,8,C,1,counting,1,kangaroo/2019,8/12/2022 ready for review +14,Maya the Bee was collecting pollen from all of the flowers that are inside the rectangle but outside the triangle. From how many flowers did she collect pollen?,mk2019_q7.png,9,10,13,17,20,A ,2,counting,1,kangaroo/2019,9/16/2022 ready for review +15,Look at the picture and answer the question.,mk2019_q8,5,6,7,8,9,C,1,algebra,1,kangaroo/2019,8/19/2022 ready for review +16,You have to close two of the five gates so that the mouse cannot reach the cheese. Which gates should you close? ,mk2019_q9,1 and 2,2 and 3,3 and 4,3 and 5,4 and 5,E,3,path,1,kangaroo/2019,9/23/2022 ready for review; KP ref: https://rosettacode.org/wiki/Maze_generation +17,Patricia folds a sheet of paper twice and then cuts it as shown. How many pieces of paper does she end up with? ,mk2019_q10,2,3,4,5,6,B,2,spatial,1,kangaroo/2019,9/12/2022 ready for review +18,Five square cards are stacked on a table as shown. The cards are removed one by one from the top of the stack. In what order are the cards removed? ,mk2019_q11,5-2-3-1-4,5-2-3-4-1,4-5-2-3-1,5-3-2-1-4,1-2-3-4-5,A,2,order,1,kangaroo/2019,9/2/2022 ready for review; suggestions: makesure there's no duplicate pssible answer +19,A cat and a bowl of milk are in opposite corners of the board. The cat can only move as shown by the arrows. In how many ways can the cat reach the milk?,mk2019_q12.png,2,3,4,5,6,E,1,path,1,kangaroo/2019,8/12/2022 ready for review +20,A floor is covered with identical rectangular tiles as shown. The shorter side of each tile is 1m. What is the length of the side with the question mark?,mk2019_q13.png,6,8,10,11,12,E,2,measure,1,kangaroo/2019,9/16/2022 ready for review; suggestions: make sure that the ratio of tile width/height is observable from the image +21,A train from Kang station to Aroo station leaves at 6:00 in the morning and passes by three other stations without stopping. The numbers show the travel times between two stations in hours. The train arrives at Aroo station at 11:00 at night on the same day. What is the travel time between Aroo station and the previous station? ,mk2019_q14.png,2 hours ,3 hours,4 hours,5 hours,6 hours,D,1,measure,1,kangaroo/2019,8/12/2022 ready for review; sggestions: add possibility to move leftward (done) +22,Tim and Tom built a sandcastle and decorated it with a flag. They stuck half of the flagpole into the highest point of the castle. The upper tip of the flagpole was 80 cm above the ground and the lower tip was 20 cm above the ground. How tall was the sandcastle? ,mk2019_q15.png,40 cm,45 cm,50 cm,55 cm,60 cm,C,2,measure,1,kangaroo/2019,8/19/2022 ready for review +23,How many legs do these animals have altogether?,mk2012_q1,5,10,12,14,20,D,1,counting,1,kangaroo/2012,8/12/2022 revision ready for review; suggestions: replace imagenet sketch with icon50 data; fix the note part in csv (done) +24,Which of the bolded paths is the longest?,mk2012_q2,A,B,C,D,E,E,2,counting,1,kangaroo/2012,8/26/2022 ready for review +25,The clock shows the time when Stephen leaves school. Lunch at school starts 3 hours before school ends. At what time does lunch start? ,mk2012_q4.png,1:00,2:00,5:00,11:00,12:00,D,2,measure,1,kangaroo/2012,8/19/2022 ready for review +26,"Stars, clovers, gifts, and trees repeat regularly on a game board. Some juice was spilled on the board. As a result, some of the pictures cannot be seen. These are the white spaces in the picture below. How many stars were on the board before the juice was spilled? ",mk2012_q5,3,6,8,9,20,D,2,measure,1,kangaroo/2012,8/26/2022 ready for review +27,"Sparoow Jack jumps on a fence from one post to another. Each jump takes him 1 second. He makes 4 jumps ahead and then 1 jump back. Then he again makes 4 jumps ahead and 1 back, and so on. In how many seconds does Jack get from start to finish? ",mk2012_q6.png,10,11,12,13,14,E,1,algebra,1,kangaroo/2012,8/19/2022 ready for review; suggestions: add instances with fewer number of posts (learners need to count from images) +28,What number is covered by the flower?,mk2012_q7,1,2,3,4,5,D,1,algebra,1,kangaroo/2012,8/5/2022 ready for review +29,There are coins on the board. We want to have 2 coins in each column and 2 coins in each row. How many coins need to be removed?,mk2012_q8,0,1,2,3,4,C,1,counting,1,kangaroo/2012,8/26/2022 ready for review +30,"In a box, there are three boxes, and each one of these boxes countains three smaller boxes. How many boxes are there in total? ",,9,10,12,13,15,D,1,math,1,kangaroo/2012,8/5/2022 ready for review +31,"In Old McDonald's barn there is one horse, two cows, and three pigs. How many more cows does Old McDonald's barn need so that the numnber of all the animals is twice the number of cows? ",mk2017_q20,0,1,2,3,4,C,1,math,1,kangaroo/2017,8/5/2022 ready for review +32,Which letter on the board is not in the word - KOALA?,mk2016_q1,R,L,K,N,O,D,2,path,1,kangaroo/2016,"8/12/2022 ready for review; suggestions: capital letter; avoid ""N"" and ""Z"" , ""M"" and ""W"", ""I"" and ""H"" confusion (done)" +33,How many ropes are in the picture? ,mk2016_q2,2,3,4,5,6,B,2,counting,1,kangaroo/2016,8/26/2022 ready for review +34,Michael built a house using matches as shown in the picture. How many matches did he use?,mk2016_q3,19,18,17,15,13,D,1,counting,1,kangaroo/2016,9/12/2022 ready for review; suggestions: sample match colors from a template set +35,Which point of the labrynth can we reach starting from point O?,mk2016_q4,A,B,C,D,E,C,3,path,1,kangaroo/2016,9/23/2022 ready for review; KP ref: https://github.com/razimantv/mazegenerator; suggestions: consider easier complexity and make the figure pretty +36,"Lisa's hens lay white eggs and brown eggs. Lisa puts six eggs in the box shown below. Two brown eggs cannot touch each other. At most, how many brown eggs can Lisa put in the box? ",mk2016_q5,1,2,3,4,5,C,1,counting,1,kangaroo/2016,KP: 9/23/2022 ready for review; suggestions: make the box non rectangle (or rectangular box but not dense); reduce the number of slots; start from dense config and randomly remove egg slots +37,"In Baby Roo's house, each room is connected to any neighboring room by a door. Baby Roo wants to get from room A to room B. What is the least number of doors that he needs to go through?",mk2016_q6,3,4,5,6,7,B,3,path,1,kangaroo/2016,KP: 10/7/2022 ready for review +38,"There are twelve rooms in a building and each room has two windows and one light. Last evening, eighteen windows were lit. In how many rooms was the light off?",,2,3,4,5,6,B,1,math,1,kangaroo/2016,SL: Done +39,"Mary is walking along the road and she reads only the letters located on her right side. Moving from point 1 to point 2, what is the word she will read? ",mk2016_q7,KNAO,KNGO,KNR,ARGO,KAO,A,3,path,1,kangaroo/2016,KP: 10/7/2022 ready for review +40,Amy used six equal small squares to build the figure shown in the picture. What is the least number of equal small squares she should add to the picture in order to obtain a larger square? ,mk2016_q8,6,8,9,10,12,D,1,logic,1,kangaroo/2016,KP: 9/23/2022 ready for review +41,"Five sparrows sat on a wire as shown in the picture. Each sparrow chirped only once to each bird it saw on the side it faced. For example, the second sparrow chirped one time. In total, how many times did they chirp? ",mk2016_q9,6,8,9,10,12,D,1,counting,1,kangaroo/2016,"SL: first version done. suggestions: fix the capital issue at the beginning of the question. ""gobbleed"" ""only once"" | fixed errors" +42,There are five ladybugs shown to the left. How many spots are there on all the ladybugs together?,mk2015_q1,17,18,19,20,21,C,1,counting,1,kangaroo/2015,"SL: Done; suggestions: fix the space in the question; random shape (instead of just dots), translation, and rotation" +43,The picture shows six numbers. What is the sum of the numbers outside the square?,mk2015_q2,12,11,23,33,10,E,1,math,1,kangaroo/2015,KP: 9/23/2022 ready for review +44,Eric has 10 identical metal strips. He used screws to connect pairs of them together into five long strips. Which strip is the shortest? ,mk2015_q3,A,B,C,D,E,B,2,spatial,1,kangaroo/2015,KP: 10/7/2022 ready for review +45,Marth built six towers using gray cubes and white cubes as shown in the picture. She made each tower using five cubes. Cubes of the same color do not touch. How many white cubes did she use?,mk2015_q4,10,11,12,18,30,C,2,counting,1,kangaroo/2015,KP: 9/30/2022 ready for review +46,"Emil placed the numbers 1,2,3,4, and 5 correctly in the boxes in the diagram below. What number did he place in the box with the question mark?",mk2015_q5,1,2,3,4,5,E,1,math,1,kangaroo/2015,SL: Done +47,"Vera invited 13 guests to her birthday party. She had 2 pizzas, and each of them was cut into 8 slices. Each person at the party ate one slice of pizza. How many slices of pizza were left over? ",,5,4,3,2,1,D,1,algebra,1,kangaroo/2015,SL: Done; suggestions: check grammar in the questions (singar and plural) - fixed errors +48,"In one jump, Jake the Kangaroo jumps from one circle to the neighboring circle along a line, as shown in the picture. He cannot jump into any circle than once. He starts at circle S and needs to make exactly 4 jumps to get to circle F. In how many different ways can Jake do this? ",mk2015_q6,3,4,5,6,7,D,2,path,1,kangaroo/2015,SL: Done; suggestions: make the color bright; path in the note - suggestions incorporated +49,"The numbers 3,5,7,8, and 9 were written in the squares of the cross (see the picture) so that the sum of the numbers in the row is equal to the sum of the numbers in the column. Which number was written in the central square? ",mk2015_q7,3,5,7,8,9,D,1,algebra,1,kangaroo/2015,"SL: Done, ready for review; suggestions: consider putting ""?"" somewhere else (same as 63; optional)" +50,How many more small gray squares are there than small white squares?,mk2014_q1,6,7,8,9,10,D,1,counting,1,kangaroo/2014,SL: Done +51,Put the animals in order from the smallest to the largest. Give the number of the animal in the middle?,mk2014_q2,1,2,3,4,5,B,1,order,1,kangaroo/2014,KP: 10/7/2022 ready for review +52,"A square was made out of 25 small squares, but some of these small squares are now missing. How many small squares are missing?",mk2014_q3,6,7,8,10,12,D,1,counting,1,kangaroo/2014,KP: 9/23/2022 ready for review +53,The kangaroo is inside how many circles?,mk2014_q4,1,2,3,4,5,C,1,counting,1,kangaroo/2014,SL: Done. +54,"Walking from K to O along the lines, pickup the letters KANGAROO in the correct order. What is the length of the shortest walk in meters (1m = 1 meter)? ",mk2014_q5,16,17,18,19,20,C,1,path,1,kangaroo/2014,KP: 10/7/2022 ready for review +55,Seven sticks lie on top of each other. Stick 2 is at the bottom. Stick 6 is at the top. Which stick is in the middle? ,mk2014_q6,1,3,4,5,7,B,1,spatial,1,kangaroo/2014,KP: 10/7/2022 ready for review +56,How many frogs did the three pelicans catch altogether? ,mk2014_q7,1,2,4,9,12,D,1,logic,1,kangaroo/2014,KP: 9/30/2022 ready for review +57,The chess board is damaged. How many black squares are missing on the right side of the line? ,mk2014_q8,11,12,13,14,15,B,1,counting,1,kangaroo.2014,KP: 9/30/2022 ready for review +58,What should you put in the square on the bottom to get a correct diagram?,mk2014_q9,-38,/8,-45,x6,/6,E,1,math,1,kangaroo/2014,KP: 9/30/2022 ready for review +59,"Put the digits 2,3,4, and 5 in the squares and calculate the sum to get the largest possible value. What is that value?",mk2014_q10,68,77,86,95,97,D,1,math,1,kangaroo/2014,KP: 9/30/2022 ready for review +60,"To get the product of 2 x 3 x 15, Bill has to press the keys of his calculator seven times; see picture. Bill wants to multiply all the numbers from 3 to 21 using his calculator. At least how many times will he press the keys of his calculator? ",mk2014_q11,19,31,37,50,60,D,1,counting,1,kangaroo/2014,"SL: Done, image path fixed" +61,Each participant in a cooking contest baked one tray of cookies like the rectangular one shown. What is the smallest number of trays of cookies needed to make a plate like the oval one shown below? ,mk2021_q10,1,2,3,4,5,C,1,algebra,1,kangaroo/2021,AC: Done +62,"Stan has five toys: a ball, a set of blocks, a game, a puzzle, and a car. He puts each toy on a different shelf of the bookase. The ball is higher than the blocks and lower than the car. The game is directly above the ball. On which shelf can the puzzle not be placed?",mk2021_q11,1,2,3,4,5,C,1,logic,1,kangaroo/2021,AC: Done +63,"Roo wrote each of the numbers 1,2,3,4, and 5 in one of the circles in such a way that the sum of the numbers in the row is equal to the sum of the numbers in the column. What number can be written in the circle with the question mark?",mk2020_q1,only 5,"2,3,or 4",only 3,only 1 or 3,"1,3,or 5",E,1,algebra,1,kangaroo/2020,"SL: Done, ready for review; suggestions: consider putting ""?"" somewhere else" +64,This pizza was divided into equal parts. How many parts have been taken? (v2),mk2018_q3,1,2,3,4,5,D,1,logic,1,kangaroo/2018,"8/12/2022 ready for review; KP: for Anoop's implementation, since the texture of each slice is not the same, the question can be confusing though" +65,"The braid in the figure is made using three threads. One thread is green, one is blue, and one is red. What colors are the three threads?",mk2020_q3,"1 is blue, 2 is green, and 3 is red","1 is green, 2 is red, 3 is blue","1 is red, 2 is blue, and 3 is green","1 is green, 2 is blue, and 3 is red","1 is blue, 2 is red, 3 is green",D,3,path,1,kangaroo/2020,KP: not sure how to generate the figure +66,"An arrow pointing from one person to another means that the first person is taller than the second. For example, person B is taller than person A. Who is the shortest? ",mk2020_q4,person A,person B,person C,person D,person E,C,1,path,1,kangaroo/2020,AC: Done; suggestions: fix the overlap between label and monkey +67,The kangaroo goes up 3 steps each time the rabbit goes down 2 steps. On which step do they meet?,mk2020_q5,3,4,5,6,7,D,2,math,1,kangaroo/2020,KP: 10/7/2022 ready for review +68,"This card is lying on the table (see picture 1). It is flipped over its top edge and then flipped over its left edge, as shown in picture 2. What does the card look like after the two flips? ",mk2020_q6,A,B,C,D,E,B,1,spatial,1,kangaroo/2020,AC: Done +69,Jose has two cards of the same size. Card 1 has four holes cut out. Jose places card 1 directly on top of card 2. What does Jose see? ,mk2020_q7,A,B,C,D,E,A,1,spatial,1,kangaroo/2020,AC:Done +70,Tom has 9 cards as shown in Figure 1. He puts the cards on the board so that each horizontal line and each vertical line contains three cards with three different shapes and three different numbers of shapes. He has already placed three cards as shown in Figure 2. Which card does he put on the gray square? ,mk2020_q8,A,B,C,D,E,D,1,spatial,1,kangaroo/2020,AC:Done +71,"Two identical trains, each with 31 cars, are traveling in opposite directions. When car number 19 of one train is opposite car number 19 of the other, which car is opposite car number 12?",mk2020_q9,7,12,21,26,31,D,1,math,1,kangaroon/2020,AC:Done +72,"Mary wants to write numbers 1,2,3,4,5, and 6 inside the six squares of the figure. She want a different number in each square. She wants both the sum of the numbers in the blue squares and the sum of the numbers in the yellow squares to be 10. What number must she write in the square with the question mark?",mk2020_q10,1,2,3,4,5,A,1,algebra,1,kangaroo/2020,AC:Done +73,"A village with 12 houses has four straight roads and four circular roads. The map shows 11 of the houses. On each straight road there are 3 houses. On each circular road, there are also 3 houses. Where on the map should the 12th house be put? ",mk2020_q11,A,B,C,D,E,C,2,path,1,kangaroo/2020,AC:Done +74,"Six different numbers chosen from 1 to 9 are written on the faces of a cube, one number on each face. The sums of the numbers on each pair of opposite faces are equal. Which number could be on the face opposite the face with the number 5?",mk2020_q12,3,5,6,7,9,C,1,logic,1,kangaroo/2020,AC:Done +75,Mary made a shape using some white cubes and 14 gray cubes. How many of these gray cubes cannot be seen in the picture? ,mk2020_q13,1,3,5,6,8,D,1,spatial,1,kangaroo/2020,AC:Done +76,A number is written on each petal of two flowers. One petal is hidden. The sums of the numbers on the two flowers are equal. What number is written on the hidden petal? ,mk2020_q14,0,3,5,7,1,C,2,algebra,1,kangaroo/2020,AC:Done +77,What do you get when you switch the colors?,mk2018_q1,A,B,C,D,E,E,1,pattern,1,kangaroo/2018,AC: Done +78,Mary had some 4-ray stars like the one shown. She glued them together as shown in the picture on the bottom. At least how many stars did she use?,mk2018_q2,5,6,7,8,9,D,1,pattern,1,kangaroo/2018,AC: Done +79,This pizza was divided into equal parts. How many parts have been taken?,mk2018_q3,1,2,3,4,5,D,1,counting,1,kangaroo/2018,"8/12/2022 ready for review; KP: for Anoop's implementation, since the texture of each slice is not the same, the question can be confusing though" +80,How many kangaroos must be moved from one park to the other in order to get the same number of kangaroos in each park?,mk2018_q4,4,5,6,8,9,B,1,counting,1,kangaroo/2018,AC: Done +81,Little Theodor assembled a stacking toy as in the picture. How many rings will he see when looking at it from above? ,mk2018_q5,1,2,3,4,5,C,1,logic,1,kangaroo/2018,AC: Done +82,"Juana, the friendly witch, has 5 broomsticks in her garage. Each broomstick is marked with a letter at the end of its handle. Juana removes the broomsticks one by one without moving the others. Which broomstick will she remove last? ",mk2018_q6,A,B,C,D,E,B,2,order,1,kangaroo/2018,AC: Done; KP: the images and htmls are not named correctly in the shared folder +83,"Peter drew a pattern twice, as in the picture. Which point will he reach when he draws the third pattern? ",mk2018_q7,A,B,C,D,E,D,1,pattern,1,kangaroo/2018,AC: Done +84,"Lisa has 4 puzzle pieces, but she only needs 3 for her puzzle frame. Which one will be left over?",mk2018_q8,A,B,C,D,C or D,A,2,pattern,1,kangaroo/2018,AC: Done; suggestion: consider removing the idash line of some puzzle pieces +85,"On her first turn, Diana got 6 points with three arrows on the target, as shown in the left part of the picture. On her second turn, she got 8 points, as shown in the middle picture. How many points did she get on her third turn?",mk2018_q9,8,10,12,14,16,C,2,algebra,1,kangaroo/2018,AC: Done +86,How many times does a right hand appear in the picture? ,mk2018_q10,3,4,5,6,7,C,1,pattern,1,kangaroo/2018,AC: Done +87,"The number of dwarfs that can fit under a mushroom is equal to the number of dots on the mushroom cap. The picture below shows one side of each mushroom. The number of dots on the other side is the same. If 30 dwarfs are seeking shelter from the rain, how many dwarfs will get wet?",mk2018_q11,2,3,4,5,6,A,1,logic,1,kangaroo/2018,AC: Done; suggestions: make sure that the dot is larger than s specific size +88,1 ice cream cone costs 1 dollar. There is a sale so you can buy 6 ice cream cones for 5 dollars. How many ice cream cones at most can you buy with 36 dollars?,mk2018_q12,36,30,42,43,45,D,1,math,1,kangaroo/2018,AC: Done +89,"How many different numbers greater than 10 and smaller than 25 with all different digits can we make by using the digits 2, 0, 1, and 8?",None,4,5,6,7,8,A,1,counting,1,kangaroo/2018,AC: Done +90,"A pirate has two chests. There are 10 coins in the chest on the left and the other chest is empty. Starting tomorrow, the pirate will put 1 coin in the chest on the left and 2 coins in the chest on the right every day. In how many days will the two chests have the same number of coins? ",mk2018_q13,5,8,10,12,never,C,1,math,1,kangaroo/2018,AC: Done +91,"Alice has 3 white, 2 black, and 2 gray pieces of paper. She cuts every non-black piece of paper in half. Then she cuts every non-white piece of paper in half. How many pieces of paper will she have? ",None,14,16,17,18,20,D,1,logic,1,kangaroo/2018,AC: Done +92,"A student had some sticks with a length of 5 cms and a width of 1 cm. Using the sticks, he made the fence below. What is the length of the fence? ",mk2018_q14,20cm,21cm,22cm,23cm,25cm,B,2,measure,1,kangaroo/2018,AC: Done; suggestions: consider putting <--?--> at the bottom +93,The road from Anna's house to Mary's house is 16 km long. The road from Mary's house to John's house is 20 km long and the road from the crossroad to Mary's house is 9 km long. How long is the road from Anna's house to John's house?,mk2018_q15,7 km,9 km,11 km,16 km,18 km,E,1,measure,1,kangaroo/2018,AC: Done; suggestions: (instance_id=16) consider putting each house close to the canvas border to avoid occlusion of the path; overlap of name/number/figure +94,The picture shows two mushrooms. What is the difference between their heights?,mk2021_q1,4,5,6,11,17,B,1,measure,1,kangaroo/2021,AC: Done; suggestions: consider running the code with larger pool of images +95,Four identical pieces of paper are placed as shown. Michael wants to punch a hole that goes through all four pieces. At which point should Michael punch the hole?,mk2021_q2,A,B,C,D,E,D,1,spatial,1,kangaroo/2021,AC: Done +96,These children are standing in a line. Some are facing forward and others are facing backward. How many children are holding another child's hand with their right hand?,mk2021_q4,2,3,4,5,6,E,1,spatial,1,kangaroo/2021,AC: Done; suggestons: consider N chains of children; margin between children such that their hands touch +97,Edmund cut a ribbon as shown in the picture. How many pieces of the ribbon did he end up with?,mk2021_q5,9,10,11,12,13,D,1,counting,1,kangaroo/2021,AC: Done; suggestions: move the scissors' location such that they don't touch the rope +98,Rose the cat walks along the wall. She starts at point B and follows the diretion of the arrows shown in the picture. The cat walks a total of 20 meters. Where does she end up? ,mk2021_q6,A,B,C,D,E,D,1,math,1,kangaroo/2021,AC: Done; suggestions: check the alignment of images and questions; image copy problem +99,"Julia has two pots with flowers, as shown. She keeps the flowers exactly where they are. She buys more flowers and puts them in the pots. After that, each pot has the same number of each type of flower. What is the smallest number of flowers she needs to buy? ",mk2021_q7,2,4,6,8,10,C,1,counting,1,kangaroo/2021,AC: Done; suggestons: make sure that the icons do not occlude much +100,"Tom encodes words using the board shown. For example, the word PIZZA has the code A2 A4 C1 C1 B2. What word did Tom encode as B3 B2 C4 D2?",mk2021_q8,MAZE,MASK,MILK,MATE,MATH,E,1,logic,1,kangaroo/2021,AC: Done +101,The cards shown are placed into two boxes. The sums of the numbers in each box are the same. Which number must be in the box with the number 4?,mk2021_q9,only 3,only 5,only 6,only 5 or 6,impossible,C,1,algebra,1,kangaroo/2021,AC: Done diff --git a/dataset/icon-classes.txt b/dataset/icon-classes.txt new file mode 100644 index 0000000..6c1de12 --- /dev/null +++ b/dataset/icon-classes.txt @@ -0,0 +1,51 @@ +airplane +arrow_directions +ball +biking +bird +blade +boat +books +building +bunny_ears +cartwheeling +clock +cloud +disk +drinks +emotion_face +envelope +family +fast_train +feline +flag +flower +footwear +golfing +hand +hat +heart +holding_hands +japanese_ideograph +kiss +lock +mailbox +marine_animals +medal +money +monkey +moon +mountain +numbers +phone +prohibit_sign +README.txt +star +surfing +tree +umbrella +vehicle +water_polo +worker +wrestling +writing_utensil diff --git a/globvars.py b/globvars.py new file mode 100644 index 0000000..a00535b --- /dev/null +++ b/globvars.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import os +import pdb + +import nltk +import numpy as np +import torch + +import utils + + +class GPT2: + # https://github.com/huggingface/transformers/issues/1458 + def __init__(self): + super(GPT2, self).__init__() + from transformers import GPT2LMHeadModel, GPT2Tokenizer + + self.model = GPT2LMHeadModel.from_pretrained("gpt2").to("cuda") + self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + self.word_dim = 768 + + def embeds(self, word_tk): + tkidx = self.tokenizer.encode(word_tk, add_prefix_space=True) + emb = self.model.transformer.wte.weight[tkidx, :] + return emb # .numpy() + + def get_word_dim(self): + return self.word_dim + + def word_embed(self, sentence): + with torch.no_grad(): + tokens = nltk.tokenize.word_tokenize(sentence.lower()) + word_feats = torch.row_stack([self.embeds(tk) for tk in tokens]) + return word_feats + + +class BERT: + # https://huggingface.co/docs/transformers/model_doc/bert + def __init__(self): + super(BERT, self).__init__() + from transformers import BertModel, BertTokenizer + + self.model = BertModel.from_pretrained("bert-base-uncased").to("cuda") + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + self.word_dim = 768 + + def get_word_dim(self): + return self.word_dim + + def word_embed(self, sentence): + with torch.no_grad(): + inputs = self.tokenizer(sentence, return_tensors="pt", padding=True).to("cuda") + outputs = self.model(**inputs) + word_feats = outputs.last_hidden_state + return torch.tensor(word_feats.squeeze()).cuda() + + +class GloVe: + def __init__(self): + super(GloVe, self).__init__() + import torchtext + + self.model = torchtext.vocab.GloVe(name="6B", dim=300) + self.word_dim = 300 + + def get_word_dim(self): + return self.word_dim + + def word_embed(self, sentence): + tokens = nltk.tokenize.word_tokenize(sentence.lower()) + word_feats = np.row_stack([self.model[tk] for tk in tokens]) + return torch.tensor(word_feats).cuda() + + +def globals_init(args): + global puzzle_diff, puzzle_diff_str, osp, rand, MAX_VAL, MAX_DECODE_STEPS, max_qlen + global num_puzzles, seed, icon_class_ids, signs + global SEQ_PUZZLES, NUM_CLASSES_PER_PUZZLE, device, SMART_DATASET_INFO_FILE + global word_dim, word_embed + global puzzles_not_included, num_actual_puzz + global PS_VAL_IDX, PS_TEST_IDX + + device = "cuda" + puzzle_diff = {"easy": ""} # {'easy': 'e', 'medium': 'm', 'hard': 'h'} + puzzle_diff_str = {"easy": ""} + osp = os.path.join + rand = lambda: np.random.rand() > 0.5 + MAX_VAL = 0 + MAX_DECODE_STEPS = 10 # number of steps to decode the LSTM. + num_puzzles = 101 + max_qlen = 110 + seed = 10 + icon_dataset_path = "./dataset/icon-classes.txt" #'/homes/cherian/train_data/NAR/SMART/SMART_cpl/puzzles/anoops/resources/icons-50/Icons-50/' + icon_class_ids = utils.get_icon_dataset_classes(icon_dataset_path) # os.listdir(icon_dataset_path) # puzzle 1 + signs = np.array(["+", "-", "x", "/"]) # puzzle 58 + NUM_CLASSES_PER_PUZZLE = {} + SEQ_PUZZLES = [16, 18, 35, 39, 63, 100] + SMART_DATASET_INFO_FILE = "./dataset/SMART_info_v2.csv" + num_actual_puzz = 102 + puzzles_not_included = set([]) + PS_VAL_IDX = [7, 43, 64] + PS_TEST_IDX = [94, 95, 96, 97, 98, 99, 101, 61, 62, 65, 66, 67, 69, 70, 71, 72, 73, 74, 75, 76, 77] + + if not os.path.exists(args.save_root): + os.makedirs(args.save_root) + + # if gpt2 + if args.word_embed == "glove": + Embed = GloVe() + word_dim = Embed.get_word_dim() + word_embed = Embed.word_embed + elif args.word_embed == "gpt": + Embed = GPT2() + word_dim = Embed.get_word_dim() + word_embed = Embed.word_embed + elif args.word_embed == "bert": + Embed = BERT() + word_dim = Embed.get_word_dim() + word_embed = Embed.word_embed + else: + print("word embedding used is %s" % (args.word_embed)) diff --git a/images/smart101-banner2.png b/images/smart101-banner2.png new file mode 100644 index 0000000..bac15ee Binary files /dev/null and b/images/smart101-banner2.png differ diff --git a/losses.py b/losses.py new file mode 100644 index 0000000..9875251 --- /dev/null +++ b/losses.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import torch.nn as nn + +import globvars as gv + + +class Criterion(nn.Module): + def __init__(self, args): + super(Criterion, self).__init__() + self.monolithic = args.monolithic # just one classifier + self.loss_type = args.loss_type + if args.loss_type == "classifier": + self.criterion = nn.CrossEntropyLoss() + elif args.loss_type == "regression": + self.criterion = nn.L1Loss() + + def compute_loss(self, a, b, pids): + if self.monolithic: + loss = self.criterion(a, b[:, 0]) + else: + loss = 0 + for key in a.keys(): + idx = pids == int(key) + if int(key) not in gv.SEQ_PUZZLES: + loss += self.criterion( + a[key], b[idx, 0] + ) # risky if idx and key entries are not matched. but then we will encouter an exception. + else: + seq_loss = 0 + for i in range(len(a[key])): + seq_loss += self.criterion(a[key][i], b[idx, i]) # .long() + seq_loss /= len(a[key]) + loss += seq_loss + loss = loss / len(a.keys()) + return loss + + def forward(self, a, b, pids=None): + if self.loss_type == "classifier": + loss = self.compute_loss(a, b.long(), pids) + elif self.loss_type == "regression": + loss = self.compute_loss(a, b.float(), pids) + else: + raise "Unknown loss type: use classifer/regression" + return loss diff --git a/main.py b/main.py new file mode 100644 index 0000000..e58f127 --- /dev/null +++ b/main.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import os + +import numpy as np +import torch + +os.environ["TOKENIZERS_PARALLELISM"] = "1" + +import warnings + +warnings.filterwarnings("ignore") +import argparse +import copy +import time + +import torch.nn.functional as F +from tqdm import tqdm + +import build_vocab as vocab_utils +import data_loader as dl +import globvars as gv +import losses +import net +import utils + + +def reset_state(args): + # global seed + gv.seed = np.random.randint(10000) if args.seed == -1 else args.seed + args.seed = gv.seed + manualSeed = gv.seed # + np.random.seed(manualSeed) + torch.manual_seed(manualSeed) + torch.cuda.manual_seed(manualSeed) + torch.cuda.manual_seed_all(manualSeed) + torch.backends.cudnn.deterministic = True + print("seed = %d" % (gv.seed)) + + +def train(args, dataloader, im_backbone): + criterion = losses.Criterion(args) + if args.model_name == "flava": + model = net.SMART_VL_Net(args, VL_backbone=im_backbone) + elif args.model_name == "clip": + import net_clip + + model = net_clip.SMART_VL_CLIP_Net(args, VL_backbone=im_backbone) + else: + model = net.SMART_Net(args, im_backbone=im_backbone) + + model = model.cuda() + parameters = model.parameters() + if not args.no_meta: + anshead_parameters = list(model.ans_decoder.parameters()) + + def normalize(err, pids): + """this function divides the error by the gt number of classes for each puzzle.""" + pids = np.array(pids) + for t in range(len(err)): + err[t] = err[t] / gv.NUM_CLASSES_PER_PUZZLE[str(pids[t])] + return err + + def get_result(out, ltype): + if ltype == "classifier": + pred_max = F.softmax(out, dim=1).argmax(dim=1).cpu() + elif ltype == "regression": + pred_max = torch.floor(out).long().cpu()[:, 0] + else: + raise "unknown loss type" + + return pred_max + + def save_model(args, net, acc, epoch, location): + state = { + "net": net.state_dict(), + "acc": acc, + "epoch": epoch, + } + if not os.path.isdir(location): + os.mkdir(location) + loc = os.path.join(location, "ckpt_%s_%s_%s.pth" % (args.model_name, args.word_embed, args.seed)) + print("saving checkpoint at %s" % (loc)) + torch.save(state, loc) + + def train_loop(epoch, train_loader, optimizer): + model.train() + tot_loss = 0.0 + for i, (im, q, _, a, av, pids) in tqdm(enumerate(train_loader)): + im = im.cuda() + q = q.cuda() + a = a.cuda() + av = av.cuda() + if args.no_meta: + out = model(im, q, puzzle_ids=pids) + loss = criterion(out, av, pids) + optimizer.zero_grad() + loss.backward() + optimizer.step() + else: + # meta learning updates. + loss_list = [None] * args.num_meta_updates + for k in range(args.num_meta_updates): + out = model(im, q, puzzle_ids=pids) + loss = criterion(out, av, pids) + anshead_optimizer.zero_grad() + grad = torch.autograd.grad(loss, anshead_parameters, allow_unused=True, retain_graph=True) + for (gr, pr) in zip(grad, anshead_parameters): + if gr is not None: + pr = pr - args.lr * gr + loss_list[k] = loss # the last loss. + meta_loss = loss_list[-1] / args.num_meta_updates + optimizer.zero_grad() + meta_loss.backward() + optimizer.step() # meta update. + tot_loss += loss.item() + + tot_loss /= float(i) + return tot_loss + + def val_loop(val_loader, model): + model.eval() + acc_mean = 0 + cnt = 0 + err_mean = 0 + opt_mean = 0 + puzzle_acc = {} + with torch.no_grad(): + for i, (im, q, o, a, av, pids) in enumerate(val_loader): + im = im.cuda() + q = q.cuda() + o = np.array(o) + out = model(im, q, puzzle_ids=pids) + + if not args.monolithic: + upids = torch.unique(pids) + acc = 0 + error = 0 + opts_acc = 0 + for t in upids: + idx = pids == t + tt = t.item() + + if t not in gv.SEQ_PUZZLES: + pred_max = get_result(out[str(tt)], args.loss_type) + pacc = (pred_max == av[idx, 0]).sum() + perror = normalize(np.abs(pred_max - av[idx, 0]), pids).sum() + oacc = utils.get_option_sel_acc(pred_max, o[idx], a[idx], av[idx], t).sum() + else: + pred_ans = [] + pacc = 1 + for k in range(gv.MAX_DECODE_STEPS): + pred_max = get_result(out[str(tt)][k], args.loss_type) + pred_ans.append(pred_max) + pacc = pacc * (pred_max == av[idx][:, k]) + pacc = pacc.sum() + perror = 0 + oacc = utils.get_option_sel_acc(np.column_stack(pred_ans), o[idx], a[idx], av[idx], t).sum() + + if str(tt) in puzzle_acc.keys(): + puzzle_acc[str(tt)][0] += pacc + puzzle_acc[str(tt)][1] += oacc + puzzle_acc[str(tt)][2] += idx.sum() + else: + puzzle_acc[str(tt)] = [pacc, oacc, idx.sum()] + # we use the ansewr value here. + opts_acc += oacc + acc += pacc + error += perror + else: # for monolothic architecture, i.e. using only one output head (e.g., in puzzle/FS split) + av = av[:, 0] + if args.loss_type == "classifier": + pred = F.softmax(out, dim=1) + pred_max = pred.argmax(dim=1).cpu() + elif args.loss_type == "regression": + pred_max = torch.floor(out).long().cpu() + + acc = (pred_max == av).float().sum() + opt = utils.get_option_sel_acc(pred_max, o, a, av, -1) + opts_acc = opt.sum() + error = normalize(torch.abs(pred_max - av).float(), pids).sum() + + # compute accuracy per puzzle.() + for t in [int(s) for s in pids]: + if str(t) in puzzle_acc.keys(): + puzzle_acc[str(t)][0] += (pred_max == av)[pids == t].sum() + puzzle_acc[str(t)][1] += opt[pids == t].sum() + puzzle_acc[str(t)][2] += (pids == t).sum() + else: + puzzle_acc[str(t)] = [ + (pred_max == av)[pids == t].sum(), + opt[pids == t].sum(), + (pids == t).sum(), + ] + + opt_mean += opts_acc + acc_mean += acc + err_mean += error + cnt += len(av) + + return acc_mean / float(cnt), err_mean / float(cnt), opt_mean / float(cnt), puzzle_acc + + def test_loop(test_loader, model): + acc, err, opt, puzzle_acc = val_loop(test_loader, model) + utils.print_puzz_acc(args, puzzle_acc, log=True) + print( + "***** Final Test Performance: S_acc = %0.2f O_acc = %0.2f Prediction Variance = %0.2f " + % (acc * 100, opt * 100, err) + ) + + if args.test: + net.load_pretrained_models(args, args.model_name, model=model) + test_loop(dataloader["test"], model) + return + + if args.optimizer == "adam": + optimizer = torch.optim.Adam(parameters, lr=args.lr, betas=(0.9, 0.99)) + if not args.no_meta: + anshead_optimizer = torch.optim.Adam(anshead_parameters, lr=args.lr, betas=(0.9, 0.99)) + else: + optimizer = torch.optim.SGD(parameters, lr=args.lr) + if not args.no_meta: + anshead_optimizer = torch.optim.SGD(anshead_parameters, lr=args.lr) + + train_loader = dataloader["train"] + val_loader = dataloader["valid"] + test_loader = dataloader["test"] + + # training loop + best_model = None + best_acc = 0 + no_improvement = 0 + num_thresh_epochs = 20 + # stop training if there is no improvement after this. + print("starting training...") + for epoch in range(args.num_epochs): + tt = time.time() + model.train() + loss = train_loop(epoch, train_loader, optimizer) + tt = time.time() - tt + + if epoch % 1 == 0: + model.eval() + acc, err, oacc, puz_acc = val_loop(val_loader, model) + if acc >= best_acc: + best_epoch = epoch + best_acc = acc + best_model = copy.deepcopy(model) + save_model(args, best_model, acc, epoch, args.location) + no_improvement = 0 + else: + no_improvement += 1 + if no_improvement > num_thresh_epochs: + print("no training improvement... stopping the training.") + utils.print_puzz_acc(args, puz_acc, log=args.log) + break + if epoch % args.log_freq == 0: + print( + "%d) Time taken=%f Epoch=%d Train_loss = %f S_acc = %f O_acc=%f Variance = %f Best S_acc (epoch) = %f (%d)\n" + % (gv.seed, tt, epoch, loss, acc * 100, oacc * 100, err, best_acc * 100, best_epoch) + ) + utils.print_puzz_acc(args, puz_acc, log=args.log) + + if epoch % args.log_freq == 0: + acc, err, oacc, puz_acc = val_loop(test_loader, model) + print( + "puzzles %s: val: s_acc/o_acc/var = %f/%f/%f (%d)" + % (args.puzzles, acc * 100, oacc * 100, err, best_epoch) + ) + + test_loop(test_loader, best_model) + + +def get_data_loader(args, split, batch_size=100, shuffle=True, num_workers=6, pin_memory=True): + if split == "train": + dataset = dl.SMART_TrainData(args, split) + collate_fn = None + else: + dataset = dl.SMART_ValData(args, split) + collate_fn = dl.SMART_collate_fn + data_loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + ) + return data_loader + + +if __name__ == "__main__": + device = "cuda" + + parser = argparse.ArgumentParser(description="SMART dataset") + parser.add_argument( + "--puzzles", default="all", type=str, help="comma separated / all / puzzle groups (counting,math etc.)" + ) + parser.add_argument("--batch_size", default=64, type=int, help="batch size (16)") + parser.add_argument("--num_epochs", default=100, type=int, help="epoch") + parser.add_argument("--lr", default=0.001, type=float, help="learning rate (0.001)") + parser.add_argument("--test_file", type=str, help="csv file for train") + parser.add_argument( + "--data_root", + type=str, + default="/homes/cherian/train_data/NAR/SMART/SMART_cpl/VLPS_v2/224/", + help="location of the csv files, and location of the images, relative location is provided in the csv file.", + ) + parser.add_argument("--train_diff", type=str, default="easy", help="easy/medium/hard") + parser.add_argument("--test_diff", type=str, default="easy", help="easy/medium/hard") + parser.add_argument( + "--split_ratio", + type=str, + default="80:5:15", + help="how to split train and val, when both use the same instance list.", + ) + parser.add_argument("--save_root", type=str, default="./data/v2/", help="location to save intermediate files.") + parser.add_argument("--vocab_path", type=str, default="none", help="location to save intermediate files.") + parser.add_argument("--num_workers", type=int, default=16, help="number of workers") + parser.add_argument("--pretrained", type=str, help="should use a pretrained model?") + parser.add_argument("--optimizer", type=str, default="adam", help="optimizer to use") + parser.add_argument("--loss_type", type=str, default="regression", help="classifier/regression") + parser.add_argument("--model_name", type=str, help="model to use resnet50/resnet18/...") + parser.add_argument("--seed", type=int, default=-1, help="seed to use") + parser.add_argument("--data_tot", type=int, default=2000, help="how many instances to use for train+val+test") + parser.add_argument("--use_clip_text", action="store_true", help="should use clip text embeddings?") + parser.add_argument("--no_meta", action="store_true", help="do not use meta learning for optimization?") + parser.add_argument("--log", action="store_true", help="should print detailed log of accuracy?") + parser.add_argument("--baselines", action="store_true", help="run the baselines from answer distributions?") + parser.add_argument( + "--monolithic", action="store_true", help="use a single head for all puzzles (except the sequential ones)?" + ) + parser.add_argument( + "--split_type", type=str, default="standard", help="type of data split: stanard/exclude/puzzle/fewshot" + ) + parser.add_argument("--word_embed", type=str, default="standard", help="standard/gpt/glove") + parser.add_argument( + "--use_single_image_head", action="store_true", help="use a single image head for all the puzzles?" + ) + parser.add_argument( + "--fsK", type=int, default=100, help="how many samples should we use to train in a fewshot setting?" + ) + parser.add_argument("--log_freq", type=int, default=50, help="log frequency?") + parser.add_argument("--test", action="store_true", help="evaluate a model?") + parser.add_argument("--train_backbone", action="store_true", help="train the image backbone?") + parser.add_argument("--no_question", action="store_true", help="do not use questions?") + parser.add_argument("--no_image", action="store_true", help="do not use images?") + parser.add_argument("--num_meta_updates", type=int, default=1, help="number of meta updates?") + parser.add_argument( + "--feat_size", type=int, default=128, help="intermediate feature size for image and language features?" + ) + + args = parser.parse_args() + + if args.split_type == "puzzle": # use only a single head and single output head for PS. + args.monolithic = True + args.use_single_image_head = True + args.no_meta = True # we do not use meta learning for puzzle split. + + if args.monolithic: # in this case, we use a single output head, but do not include sequential puzzles. + args.no_meta = True + + if args.test: + assert args.seed > -1 # when evaluating we need to use the seed to take the checkpoint. + + gv.globals_init(args) + + args.puzzle_ids_str, args.puzzle_ids = utils.get_puzzle_ids(args) + args.location = os.path.join(args.save_root, "checkpoints") + args.log_path = os.path.join(args.save_root, "log") + + reset_state(args) + gv.NUM_CLASSES_PER_PUZZLE = utils.get_puzzle_class_info( + args + ) # initialize the global with the number of outputs for each puzzle. + + vocab = vocab_utils.process_text_for_puzzle(args) + if args.vocab_path == "none": + args.vocab_path = os.path.join(args.save_root, "vocab_puzzle_" + args.puzzle_ids_str + ".pkl") + + im_backbone, preprocess = net.load_pretrained_models(args, args.model_name, model=None) + args.preprocess = preprocess + + train_loader = get_data_loader( + args, "train", batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers + ) + val_loader = get_data_loader(args, "val", batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + test_loader = get_data_loader(args, "test", batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + dataloader = { + "train": train_loader, + "valid": val_loader, + "test": test_loader, + } + + utils.backup_code_and_start_logger(args, args.log_path, args.seed) + + print(args) + print("num_puzzles=%d" % (len(args.puzzle_ids))) + + train(args, dataloader, im_backbone) diff --git a/net.py b/net.py new file mode 100644 index 0000000..f57a65b --- /dev/null +++ b/net.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import os +import warnings + +import torch +import torch.nn as nn + +warnings.filterwarnings("ignore") +import pdb +import pickle + +import clip +import torch.nn.functional as F +from PIL import Image +from torchvision import models + +import globvars as gv + + +# Vision and Language pretrained models. e.g., FLAVA model. +class SMART_VL_Net(nn.Module): + def __init__(self, args, VL_backbone): + super(SMART_VL_Net, self).__init__() + vocab_path = args.vocab_path + with open(vocab_path, "rb") as f: + self.vocab = pickle.load(f) + + self.num_opts = 5 + self.out_dim = args.feat_size # the intermediate feature size. + self.h_sz = 256 + self.feat_size = 768 + self.dummy_question = None + self.model_name = args.model_name + self.use_clip_text = args.use_clip_text + self.loss_type = args.loss_type + self.monolithic = args.monolithic + self.use_single_image_head = args.use_single_image_head + self.train_backbone = args.train_backbone + + if args.loss_type == "classifier" or args.loss_type == "puzzle_tails": + self.max_val = gv.MAX_VAL + 1 + elif args.loss_type == "regression": + self.max_val = 1 + + self.processor = args.preprocess + self.VL_backbone = VL_backbone + self.create_puzzle_head(args) + + self.q_MLP = nn.Sequential( + nn.Linear(self.feat_size, self.h_sz), + nn.ReLU(), + nn.Linear(self.h_sz, self.out_dim), + nn.ReLU(), + ) + + self.qv_MLP = nn.Sequential( + nn.Linear(self.feat_size, self.h_sz), + nn.ReLU(), + nn.Linear(self.h_sz, self.out_dim), + nn.ReLU(), + ) + + self.qv_fusion = nn.Sequential( + nn.Linear(self.out_dim * 2, self.out_dim), # for flava its *2. + nn.ReLU(), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + ) + if self.monolithic: + self.qvo_fusion = nn.Sequential(nn.Linear(self.out_dim, self.max_val)) + else: + self.create_puzzle_tail(args) + + def create_puzzle_head(self, args): + if args.use_single_image_head: + self.im_encoder = nn.Sequential( + nn.Linear(self.feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim) + ) + else: + self.puzzle_ids = args.puzzle_ids + im_encoder = [nn.Sequential(nn.Linear(self.out_dim, 1))] + for i in range(1, gv.num_puzzles + 1): + im_encoder.append( + nn.Sequential( + nn.Linear(self.feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim) + ) + ) + self.im_encoder = nn.ModuleList(im_encoder) + + def create_puzzle_tail(self, args): + self.puzzle_ids = args.puzzle_ids + ans_decoder = [ + nn.Sequential(nn.Linear(self.out_dim, 1)) + ] # start with a dummy as we are 1-indexed wrt puzzle ids. + for pid in range(1, gv.num_puzzles + 1): + num_classes = gv.NUM_CLASSES_PER_PUZZLE[str(pid)] if args.loss_type == "classifier" else 1 + if int(pid) not in gv.SEQ_PUZZLES: + ans_decoder.append( + nn.Sequential( + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, num_classes), + ) + ) + else: + ans_decoder.append(nn.LSTM(self.out_dim, num_classes, num_layers=1, batch_first=True)) + self.ans_decoder = nn.ModuleList(ans_decoder) + + def process(self, images, text): + inputs = self.processor( + text=text, + images=images, + return_tensors="pt", + max_length=77, + padding=True, + return_codebook_pixels=True, + return_image_mask=True, + ) + inputs["input_ids_masked"] = inputs["input_ids"].detach().clone() + inputs["bool_masked_pos"] = torch.zeros_like(inputs["bool_masked_pos"]) + inputs = inputs.to("cuda") + return inputs + + def encode_image(self, im_feat, pids=None): + if self.use_single_image_head: + y = self.im_encoder(im_feat) + else: + y = torch.zeros(len(im_feat), im_feat.shape[1], self.out_dim).cuda() + for t in range(len(self.puzzle_ids)): + idx = pids == int(self.puzzle_ids[t]) + idx = idx.cuda() + if idx.sum() > 0: + y[idx] = F.relu(self.im_encoder[int(self.puzzle_ids[t])](im_feat[idx])) + return y + + def encode_image_and_text(self, qv_feat): + x = F.relu(self.qv_MLP(qv_feat)) + return x + + def encode_text(self, q_feat): + x = F.relu(self.q_MLP(q_feat)) + return x + + def decode_image(self, im_list): + """convert torch tensor images back to Image bcos VL FLAVA model works with images.""" + im_list = (im_list.permute(0, 2, 3, 1) * 255).cpu().numpy().astype("uint8") + im_list = [Image.fromarray(im_list[ii]) for ii in range(len(im_list))] # convert im + return im_list + + def decode_text(self, text): + tt = text.cpu() + text = [ + " ".join([self.vocab.idx2word[int(j)] for j in tt[i][1 : torch.nonzero(tt[i])[-1]]]) for i in range(len(tt)) + ] + return text + + def seq_decoder(self, decoder, feat): + """run the LSTM decoder sequentially for k steps""" + out = [None] * gv.MAX_DECODE_STEPS + hx = None + for k in range(gv.MAX_DECODE_STEPS): + try: + out[k], hx = decoder(feat, hx) + except: + pdb.set_trace() + return out + + def decode_individual_puzzles(self, feat, pids): + upids = torch.unique(pids) + out_feats = {} + for t in range(len(upids)): + idx = pids == upids[t] + key = str(upids[t].item()) + if upids[t] not in gv.SEQ_PUZZLES: + out_feats[key] = self.ans_decoder[int(key)](feat[idx]) + else: + out_feats[key] = self.seq_decoder(self.ans_decoder[int(key)], feat[idx]) + return out_feats + + def forward(self, im, q=None, puzzle_ids=None): + im = self.decode_image(im) + q_text = self.decode_text(q) + inputs = self.process(im, q_text) + if self.train_backbone: + outputs = self.VL_backbone(**inputs) + else: + with torch.no_grad(): + outputs = self.VL_backbone(**inputs) + + im_feat = outputs.image_embeddings # Batch size X (Number of image patches + 1) x Hidden size => 2 X 197 X 768 + q_feat = outputs.text_embeddings # Batch size X (Text sequence length + 1) X Hidden size => 2 X 77 X 768 + # qv_feat_mm = outputs.multimodal_embeddings # Batch size X (Number of image patches + Text Sequence Length + 3) X Hidden size => 2 X 275 x 768 + # Multimodal embeddings can be used for multimodal tasks such as VQA + + im_feat = self.encode_image(im_feat, puzzle_ids) + q_feat = self.encode_text(q_feat) + + qv_feat = self.qv_fusion(torch.cat([im_feat.mean(1), q_feat.mean(1)], dim=1)) + + if self.monolithic: + qv_feat = qv_feat.unsqueeze(1) + qvo_feat = self.qvo_fusion(qv_feat).squeeze() + else: + qvo_feat = self.decode_individual_puzzles(qv_feat, puzzle_ids) + + return qvo_feat + + +# Vision backbones and language backbones. +class SMART_Net(nn.Module): + def __init__(self, args, im_backbone=None): + super(SMART_Net, self).__init__() + vocab_path = args.vocab_path + with open(vocab_path, "rb") as f: + self.vocab = pickle.load(f) + + self.num_opts = 5 + self.out_dim = args.feat_size # 64 # + self.h_sz = 256 # 256 #128 # + self.dummy_question = None + self.model_name = args.model_name + self.use_clip_text = args.use_clip_text + self.loss_type = args.loss_type + self.monolithic = args.monolithic + self.use_single_image_head = args.use_single_image_head + self.train_backbone = args.train_backbone + self.word_embed = args.word_embed + + if args.loss_type == "classifier" or args.loss_type == "puzzle_tails": + self.max_val = gv.MAX_VAL + 1 + elif args.loss_type == "regression": + self.max_val = 1 + + # image backbones. + if args.model_name[:6] == "resnet": + self.im_feat_size = im_backbone.fc.weight.shape[1] + modules = list(im_backbone.children())[:-1] + self.im_cnn = nn.Sequential(*modules) + elif args.model_name in ["alexnet", "vgg"]: + im_backbone.classifier[-1] = nn.Identity() + self.im_cnn = im_backbone + self.im_encoder = nn.Linear(im_backbone.classifier[-3].weight.shape[1], self.out_dim) + elif args.model_name in ["swin_t"]: + self.im_feat_size = 768 + self.im_cnn = im_backbone + self.im_cnn.head = nn.Identity() + elif args.model_name in ["swin_b"]: + self.im_feat_size = 1024 + self.im_cnn = im_backbone + self.im_cnn.head = nn.Identity() + elif args.model_name in ["vit"]: + self.im_feat_size = 768 + self.im_cnn = im_backbone + self.im_cnn.heads.head = nn.Identity() + elif args.model_name in ["mae"]: + self.preprocess = args.preprocess + self.im_cnn = lambda x: self.process_MAE(x) # inputs = feature_extractor(images=image, return_tensors="pt") + self.im_backbone = im_backbone + self.im_feat_size = 768 + elif args.model_name in ["cross_transformer"]: # when using a vision transformer model. + from vit_pytorch.crossformer import CrossFormer + + self.im_cnn = CrossFormer( + num_classes=256, # number of output classes + dim=(64, 128, 256, 512), # dimension at each stage + depth=(2, 2, 8, 2), # depth of transformer at each stage + global_window_size=(8, 4, 2, 1), # global window sizes at each stage + local_window_size=7, # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages) + ) + + self.im_feat_size = 256 + else: + raise "unknown model_name %s" % (args.model_name) + + self.create_puzzle_head(args) + + # language backbones + if self.use_clip_text: + self.q_encoder, _ = clip.load("ViT-B/32", device="cuda") + self.clip_dim = 512 + self.q_MLP = nn.Sequential( + nn.Linear(self.clip_dim, self.h_sz), nn.ReLU(), nn.Linear(self.h_sz, self.out_dim) + ) + else: + if args.word_embed == "standard": + self.q_emb = nn.Embedding(len(self.vocab), self.h_sz, max_norm=1) + self.q_lstm = nn.LSTM(self.h_sz, self.h_sz, num_layers=2, batch_first=True, bidirectional=True) + else: + word_dim = gv.word_dim + self.q_emb = nn.Identity() + self.q_lstm = nn.GRU(word_dim, self.h_sz, num_layers=1, batch_first=True, bidirectional=True) + self.q_MLP = nn.Linear(self.h_sz * 2, self.out_dim) + + self.o_encoder = nn.Sequential( + nn.Embedding(len(self.vocab), self.out_dim, max_norm=1), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + ) + self.qv_fusion = nn.Sequential( + nn.Linear(self.out_dim * 2, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + ) + if self.monolithic: + self.qvo_fusion = nn.Sequential(nn.Linear(self.out_dim, self.max_val)) + else: + self.create_puzzle_tail(args) + + def process_MAE(self, x): + x = self.decode_image(x) # get from tensor to PIL images + inputs = self.preprocess(images=x, return_tensors="pt").to("cuda") + outputs = self.im_backbone(**inputs) + return outputs.last_hidden_state.mean(1) + + def create_puzzle_head(self, args): + if args.use_single_image_head: + self.im_encoder = nn.Sequential( + nn.Linear(self.im_feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim) + ) + else: + self.puzzle_ids = args.puzzle_ids + im_encoder = [nn.Sequential(nn.Linear(self.out_dim, 1))] + for i in range(1, gv.num_puzzles + 1): + im_encoder.append( + nn.Sequential( + nn.Linear(self.im_feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim) + ) + ) + self.im_encoder = nn.ModuleList(im_encoder) + + def create_puzzle_tail(self, args): + self.puzzle_ids = args.puzzle_ids + ans_decoder = [ + nn.Sequential(nn.Linear(self.out_dim, 1)) + ] # start with a dummy as we are 1-indexed wrt puzzle ids. + for pid in range(1, gv.num_puzzles + 1): # self.puzzle_ids: + num_classes = gv.NUM_CLASSES_PER_PUZZLE[str(pid)] if args.loss_type == "classifier" else 1 + if int(pid) not in gv.SEQ_PUZZLES: + ans_decoder.append( + nn.Sequential( + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, num_classes), + ) + ) + else: + ans_decoder.append(nn.LSTM(self.out_dim, num_classes, num_layers=1, batch_first=True)) + self.ans_decoder = nn.ModuleList(ans_decoder) + + def decode_image(self, im_list): + """convert torch tensor images back to Image bcos VL FLAVA model works with images.""" + # im_list = (im_list +1)/2. # this is in range [0, 1]. + im_list = (im_list.permute(0, 2, 3, 1) * 255).cpu().numpy().astype("uint8") + im_list = [Image.fromarray(im_list[ii]) for ii in range(len(im_list))] # convert im + return im_list + + def save_grad_hook(self): + self.vis_grad = None + + def bwd_hook(module, in_grad, out_grad): + self.vis_grad = out_grad + + return bwd_hook + + def save_fwd_hook(self): + self.vis_conv = None + + def fwd_hook(__, _, output): + self.vis_conv = output + + return fwd_hook + + def encode_image(self, im, pids=None): + if self.train_backbone: + x = self.im_cnn(im).squeeze() + else: + with torch.no_grad(): + x = self.im_cnn(im).squeeze() + + if len(x.shape) == 1: + x = x.unsqueeze(0) + + if self.use_single_image_head: + y = self.im_encoder(x) + else: + y = torch.zeros(len(im), self.out_dim).cuda() + for t in range(len(self.puzzle_ids)): + idx = pids == int(self.puzzle_ids[t]) + idx = idx.cuda() + if idx.sum() > 0: + y[idx] = F.relu(self.im_encoder[int(self.puzzle_ids[t])](x[idx])) + + return y + + def decode_text(self, text): + get_range = lambda x: range(1, x) if x < 70 else range(x - 70 + 4, x) + tt = text.cpu() + text = [ + " ".join([self.vocab.idx2word[int(j)] for j in tt[i][get_range(torch.nonzero(tt[i])[-1])]]) + for i in range(len(tt)) + ] + return text + + def encode_text(self, text): + if self.word_embed == "standard": + x = self.q_emb(text) + x, (h, _) = self.q_lstm(x.float()) + x = F.relu(self.q_MLP(x.mean(1))) + elif self.word_embed == "gpt" or "bert" or "glove": + text = self.decode_text(text) + q_enc = torch.zeros(len(text), gv.max_qlen, gv.word_dim).cuda() + for ii, tt in enumerate(text): + q_feat = gv.word_embed(tt) + q_enc[ii, : min(gv.max_qlen, len(q_feat)), :] = q_feat + x, (h, _) = self.q_lstm(q_enc.float()) + x = F.relu(self.q_MLP(x.mean(1))) + else: + x = gv.word_embed(text) + + return x + + def seq_decoder(self, decoder, feat): + """run the LSTM decoder sequentially for k steps""" + out = [None] * gv.MAX_DECODE_STEPS + hx = None + for k in range(gv.MAX_DECODE_STEPS): + try: + out[k], hx = decoder(feat, hx) + except: + pdb.set_trace() + return out + + def decode_individual_puzzles(self, feat, pids): + upids = torch.unique(pids) + out_feats = {} + for t in range(len(upids)): + idx = pids == upids[t] + key = str(upids[t].item()) + if upids[t] not in gv.SEQ_PUZZLES: + out_feats[key] = self.ans_decoder[int(key)](feat[idx]) + else: + out_feats[key] = self.seq_decoder(self.ans_decoder[int(key)], feat[idx]) + return out_feats + + def forward(self, im, q=None, puzzle_ids=None): + im_feat = self.encode_image(im, puzzle_ids) + q_feat = self.encode_text(q) + qv_feat = self.qv_fusion(torch.cat([im_feat, q_feat], dim=1)) + if self.monolithic: + qv_feat = qv_feat.unsqueeze(1) + qvo_feat = self.qvo_fusion(qv_feat).squeeze() + else: + qvo_feat = self.decode_individual_puzzles(qv_feat, puzzle_ids) + return qvo_feat + + +def load_pretrained_models(args, model_name, model=None): + + if args.test and model is not None: + model_path = os.path.join(args.location, "ckpt_%s_%s_%s.pth" % (args.model_name, args.word_embed, args.seed)) + print("test: loading checkpoint %s ..." % (model_path)) + checkpoint = torch.load(model_path) + model.load_state_dict(checkpoint["net"], strict=True) + return + + preprocess = None + if args.model_name in ["resnet18"]: + model = models.__dict__[args.model_name](pretrained=True) + elif args.model_name in ["resnet50"]: # use_resnet: + from torchvision.models import ResNet50_Weights, resnet50 + + weights = ResNet50_Weights.DEFAULT + model = resnet50(weights=weights) + preprocess = weights.transforms() + elif args.model_name == "swin_t": # use_vit: + from torchvision.models import Swin_T_Weights, swin_t + + weights = Swin_T_Weights.IMAGENET1K_V1 + model = swin_t(weights=weights) + preprocess = weights.transforms() + elif args.model_name == "swin_b": # use_vit: + from torchvision.models import Swin_B_Weights, swin_b + + weights = Swin_B_Weights.IMAGENET1K_V1 + model = swin_b(weights=weights) + preprocess = weights.transforms() + elif args.model_name == "vit": + from torchvision.models import ViT_B_16_Weights, vit_b_16 + + weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1 # ViT_B_16_Weights.DEFAULT # + model = vit_b_16(weights=weights) + preprocess = weights.transforms() + elif args.model_name == "flava": + from transformers import FlavaForPreTraining, FlavaProcessor # FlavaModel, + + model = FlavaForPreTraining.from_pretrained("facebook/flava-full").eval() + preprocess = FlavaProcessor.from_pretrained("facebook/flava-full") + elif args.model_name == "clip": + model, preprocess = clip.load("ViT-B/32", device="cuda") + elif args.model_name == "mae": + from transformers import AutoFeatureExtractor, ViTMAEModel + + feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base") + model = ViTMAEModel.from_pretrained("facebook/vit-mae-base") + preprocess = feature_extractor + + else: + print("model name is %s: not loading pre-trained model." % (args.model_name)) + + if args.pretrained: + if os.path.isfile(args.pretrained): + print("=> loading checkpoint '{}'".format(args.pretrained)) + checkpoint = torch.load(args.pretrained, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint["state_dict"] + for k in list(state_dict.keys()): + # retain only encoder up to before the embedding layer + if k.startswith("module.encoder") and not k.startswith("module.encoder.fc"): + # remove prefix + state_dict[k[len("module.encoder.") :]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + + print("=> loaded pre-trained model '{}'".format(args.pretrained)) + else: + print("=> no checkpoint found at '{}'".format(args.pretrained)) + return model, preprocess diff --git a/net_clip.py b/net_clip.py new file mode 100644 index 0000000..15d0224 --- /dev/null +++ b/net_clip.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# + +import warnings + +import torch +import torch.nn as nn + +warnings.filterwarnings("ignore") +import pdb +import pickle + +import clip +import torch.nn.functional as F +from PIL import Image + +import globvars as gv + + +# Vision and Language pretrained models. +class SMART_VL_CLIP_Net(nn.Module): + def __init__(self, args, VL_backbone): + super(SMART_VL_CLIP_Net, self).__init__() + vocab_path = args.vocab_path + with open(vocab_path, "rb") as f: + self.vocab = pickle.load(f) + + self.num_opts = 5 + self.out_dim = args.feat_size + self.h_sz = 256 + self.feat_size = 512 + self.dummy_question = None + self.model_name = args.model_name + self.use_clip_text = args.use_clip_text + self.loss_type = args.loss_type + self.monolithic = args.monolithic + self.use_single_image_head = args.use_single_image_head + self.train_backbone = args.train_backbone + + if args.loss_type == "classifier" or args.loss_type == "puzzle_tails": + self.max_val = gv.MAX_VAL + 1 + elif args.loss_type == "regression": + self.max_val = 1 + + self.preprocess = args.preprocess + self.VL_backbone = VL_backbone + self.create_puzzle_head(args) + + self.q_MLP = nn.Sequential( + nn.Linear(self.feat_size, self.h_sz), + nn.ReLU(), + nn.Linear(self.h_sz, self.out_dim), + nn.ReLU(), + ) + + self.qv_fusion = nn.Sequential( + nn.Linear(self.out_dim * 2, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + ) + if self.monolithic: + self.qvo_fusion = nn.Sequential(nn.Linear(self.out_dim, self.max_val)) + else: + self.create_puzzle_tail(args) + + def create_puzzle_head(self, args): + if args.use_single_image_head: + self.im_encoder = nn.Sequential( + nn.Linear(self.feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim) + ) + else: + self.puzzle_ids = args.puzzle_ids + im_encoder = [nn.Sequential(nn.Linear(self.out_dim, 1))] + for i in range(1, gv.num_puzzles + 1): + im_encoder.append( + nn.Sequential( + nn.Linear(self.feat_size, self.out_dim), nn.ReLU(), nn.Linear(self.out_dim, self.out_dim) + ) + ) + self.im_encoder = nn.ModuleList(im_encoder) + + def create_puzzle_tail(self, args): + self.puzzle_ids = args.puzzle_ids + ans_decoder = [ + nn.Sequential(nn.Linear(self.out_dim, 1)) + ] # start with a dummy as we are 1-indexed wrt puzzle ids. + for pid in range(1, gv.num_puzzles + 1): + num_classes = gv.NUM_CLASSES_PER_PUZZLE[str(pid)] if args.loss_type == "classifier" else 1 + if int(pid) not in gv.SEQ_PUZZLES: + ans_decoder.append( + nn.Sequential( + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, self.out_dim), + nn.ReLU(), + nn.Linear(self.out_dim, num_classes), + ) + ) + else: + ans_decoder.append(nn.LSTM(self.out_dim, num_classes, num_layers=1, batch_first=True)) + self.ans_decoder = nn.ModuleList(ans_decoder) + + def process(self, im, q_text): + q_text = self.decode_text(q_text) + text = clip.tokenize(q_text, truncate=True).to("cuda") + return im, text + + def encode_image(self, im_feat, pids=None): + if self.use_single_image_head: + y = self.im_encoder(im_feat) + else: + y = torch.zeros(len(im_feat), self.out_dim).cuda() + for t in range(len(self.puzzle_ids)): + idx = pids == int(self.puzzle_ids[t]) + idx = idx.cuda() + if idx.sum() > 0: + y[idx] = F.relu(self.im_encoder[int(self.puzzle_ids[t])](im_feat[idx])) + return y + + def encode_text(self, q_feat): + x = F.relu(self.q_MLP(q_feat)) + return x + + def decode_image(self, im_list): + """convert torch tensor images back to Image bcos VL FLAVA model works with images.""" + im_list = (im_list.permute(0, 2, 3, 1) * 255).cpu().numpy().astype("uint8") + im_list = [Image.fromarray(im_list[ii]) for ii in range(len(im_list))] # convert im + return im_list + + def decode_text(self, text): + get_range = lambda x: range(1, x) if x < 70 else range(x - 70 + 4, x) + tt = text.cpu() + text = [ + " ".join([self.vocab.idx2word[int(j)] for j in tt[i][get_range(torch.nonzero(tt[i])[-1])]]) + for i in range(len(tt)) + ] + return text + + def seq_decoder(self, decoder, feat): + """run the LSTM decoder sequentially for k steps""" + out = [None] * gv.MAX_DECODE_STEPS + hx = None + for k in range(gv.MAX_DECODE_STEPS): + try: + out[k], hx = decoder(feat, hx) + except: + pdb.set_trace() + return out + + def decode_individual_puzzles(self, feat, pids): + upids = torch.unique(pids) + out_feats = {} + for t in range(len(upids)): + idx = pids == upids[t] + key = str(upids[t].item()) + if upids[t] not in gv.SEQ_PUZZLES: + out_feats[key] = self.ans_decoder[int(key)](feat[idx]) + else: + out_feats[key] = self.seq_decoder(self.ans_decoder[int(key)], feat[idx]) + return out_feats + + def forward(self, im, q=None, puzzle_ids=None): + im, text = self.process(im, q) + + if self.train_backbone: + im_feat = self.VL_backbone.encode_image(im) + q_feat = self.VL_backbone.encode_text(text) + else: + with torch.no_grad(): + im_feat = self.VL_backbone.encode_image(im) + q_feat = self.VL_backbone.encode_text(text) + + im_feat = self.encode_image(im_feat.float(), puzzle_ids) + q_feat = self.encode_text(q_feat.float()) + qv_feat = self.qv_fusion(torch.cat([im_feat, q_feat], dim=1)) + + if self.monolithic: + qv_feat = qv_feat.unsqueeze(1) + qvo_feat = self.qvo_fusion(qv_feat).squeeze() + else: + qvo_feat = self.decode_individual_puzzles(qv_feat, puzzle_ids) + + return qvo_feat diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0b46a67 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +charset-normalizer==3.1.0 +click==8.1.3 +contourpy==1.0.7 +cycler==0.11.0 +einops==0.6.0 +filelock==3.10.7 +fonttools==4.39.3 +ftfy==6.1.1 +huggingface-hub==0.13.3 +idna==3.4 +importlib-resources==5.12.0 +joblib==1.2.0 +kiwisolver==1.4.4 +matplotlib==3.7.1 +nltk==3.8.1 +numpy==1.21.5 +opencv-python==4.7.0.72 +packaging==23.0 +pandas==1.4.4 +Pillow==9.5.0 +pyparsing==3.0.9 +python-dateutil==2.8.2 +pytz==2023.3 +PyYAML==6.0 +regex==2023.3.23 +requests==2.28.2 +six==1.16.0 +tokenizers==0.13.2 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.65.0 +transformers==4.26.0 +typing_extensions==4.5.0 +urllib3==1.26.15 +vit-pytorch==1.2.0 +wcwidth==0.2.6 +zipp==3.15.0 diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..a299910 --- /dev/null +++ b/utils.py @@ -0,0 +1,437 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Mitsubishi Electric Research Laboratories (MERL) +# +# SPDX-License-Identifier: MIT +# +import json +import os +import os.path as osp +import pdb +import pickle as pkl +import sys + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from PIL import Image + +import globvars as gv + + +def fix_acc(acc_list): + """removes accuracy for puzzles in gv.puzzles_not_included""" + idx = np.array(list(set(np.arange(1, gv.num_puzzles + 1)).difference(set(gv.puzzles_not_included)))) + new_acc_list = acc_list[idx - 1] + return new_acc_list + + +def get_icon_dataset_classes(icon_path): + """returns the classes in ICONs-50 dataset""" + with open(icon_path, "r") as f: + icon_classes = f.readlines() + return [ii.rstrip() for ii in icon_classes] + + +def print_puzz_acc(args, puzz_acc, log=True): + to_int = lambda x: np.array(list(x)).astype("int") + cls_mean = lambda x, idx, pids: np.array([x[int(ii)] for ii in idx]).sum() / len( + set(to_int(idx)).intersection(set(to_int(pids))) + ) + acc_list = np.zeros( + gv.num_puzzles + 1, + ) + opt_acc_list = np.zeros( + gv.num_puzzles + 1, + ) + + if not os.path.exists(os.path.join(args.save_root, "results/%d/" % (gv.seed))): + os.makedirs(os.path.join(args.save_root, "results/%d/" % (gv.seed))) + + if len(puzz_acc.keys()) > 10: + for k, key in enumerate(puzz_acc.keys()): + acc = 100.0 * puzz_acc[key][0] / puzz_acc[key][2] + oacc = 100.0 * puzz_acc[key][1] / puzz_acc[key][2] + acc_list[int(key)] = acc + opt_acc_list[int(key)] = oacc + if log: + for t in range(1, gv.num_puzzles + 1): + print("%d opt_acc=%0.2f acc=%0.2f" % (t, opt_acc_list[t], acc_list[t]), end="\t") + if t % 5 == 0: + print("\n") + print("\n\n") + + puzzles = read_dataset_info(gv.SMART_DATASET_INFO_FILE) + class_avg_perf = {} + classes = ["counting", "math", "logic", "path", "algebra", "measure", "spatial", "pattern"] + print(classes) + for kk in classes: + idx_list = puzzles[kk] + class_avg_perf[kk] = ( + cls_mean(acc_list, idx_list, list(puzz_acc.keys())), + cls_mean(opt_acc_list, idx_list, list(puzz_acc.keys())), + ) + print("%0.1f/%0.1f & " % (class_avg_perf[kk][0], class_avg_perf[kk][1]), end=" ") + print("\n\n") + + fig = plt.figure(figsize=(30, 4)) + ax = plt.gca() + ax.bar(np.arange(1, gv.num_actual_puzz), fix_acc(acc_list[1:])) + ax.set_xticks(np.arange(1, gv.num_actual_puzz)) + ax.set_xlabel("puzzle ids", fontsize=16) + ax.set_ylabel("$O_{acc}$ %", fontsize=20) + fig.tight_layout() + plt.savefig(os.path.join(args.save_root, "results/%d/acc_perf_scores_1.png" % (gv.seed))) + plt.close() + + fig = plt.figure(figsize=(30, 4)) + ax = plt.gca() + ax.bar(np.arange(1, gv.num_actual_puzz), fix_acc(opt_acc_list[1:])) + ax.set_xticks(np.arange(1, gv.num_actual_puzz)) # , [str(i) for i in np.arange(1,num_puzzles+1)]) + ax.set_xlabel("puzzle ids", fontsize=16) + ax.set_ylabel("$S_{acc}$ %", fontsize=20) + fig.tight_layout() + plt.savefig(os.path.join(args.save_root, "results/%d/opt_acc_perf_scores_1.png" % (gv.seed))) + plt.close() + else: + for key in puzz_acc.keys(): + acc = 100.0 * puzz_acc[key][0] / puzz_acc[key][2] + opt_acc = 100.0 * puzz_acc[key][1] / puzz_acc[key][2] + if log: + print("%s opt_acc=%0.2f acc=%0.2f" % (key, opt_acc, acc)) + acc_list[int(key)] = acc + opt_acc_list[int(key)] = opt_acc + + plt.figure() + plt.bar(np.arange(gv.num_puzzles + 1), acc_list) + plt.savefig(os.path.join(args.save_root, "results/%d/acc_perf_scores.png" % (gv.seed))) + plt.close() + plt.figure() + plt.bar(np.arange(gv.num_puzzles + 1), opt_acc_list) + plt.savefig(os.path.join(args.save_root, "results/%d/opt_acc_perf_scores.png" % (gv.seed))) + plt.close() + + +def get_option_sel_acc(pred_ans, opts, answer, answer_values, pid): + """converts a predicted answer to one of the given multiple choice options. + opts is b x num_options matrix""" + + def get_op_str(ii): + return gv.signs[int(str(ii)[0]) - 1] + str(ii)[1:] if ii >= 10 else gv.signs[0] + str(ii) + + if pid in gv.SEQ_PUZZLES: + result = np.abs(answer_values - pred_ans).sum(axis=1) == 0 + elif pid in [32, 69, 82, 84, 95, 98, 51, 66, 44, 68]: + result = [pred_ans[i] == answer[i] for i in range(len(pred_ans))] + else: + try: + result = ( + np.abs(opts.astype("float") - pred_ans.unsqueeze(1).cpu().numpy()).argmin(axis=1) + == answer.cpu().numpy() + ) + except: + result = [pred_ans[i] == answer[i] for i in range(len(pred_ans))] + print("error!!") + pdb.set_trace() + return np.array(result) + + +def read_dataset_info(csvfilename): + import csv + + qa_info = {} + with open(csvfilename, newline="") as csvfile: + datareader = csv.DictReader(csvfile) + for row in datareader: + key = str(row["type"]).lower() + if key not in qa_info.keys(): + qa_info[key] = [row["id"]] + else: + qa_info[key].append(row["id"]) + assert np.array([len(qa_info[key]) for key in qa_info.keys()]).sum() == 101 + return qa_info + + +def read_csv(csvfilename, puzzle_id): + import csv + + qa_info = [] + with open(csvfilename, newline="") as csvfile: + datareader = csv.DictReader(csvfile) + for row in datareader: + row["puzzle_id"] = str(puzzle_id) + if len(row["A"]) == 0: + row["A"] = "A" + row["B"] = "B" + row["C"] = "C" + row["D"] = "D" + row["E"] = "E" + qa_info.append(row) + return qa_info + + +def pad_with_max_val(gt_list, val): + """if the number of elements in gt is less than MAX_DECODE_STEPS, we pad it with the max value in a class""" + if len(gt_list) < gv.MAX_DECODE_STEPS: + gt_list = ( + gt_list + + ( + np.ones( + gv.MAX_DECODE_STEPS - len(gt_list), + ) + * val + ).tolist() + ) + return gt_list + + +def str_replace(ans): + ans = ans.replace(" hours", "") + ans = ans.replace(" hour", "").replace(" cm", "") + ans = ans.replace(" km", "") + return ans + + +def str_replace_(info, ans_opt): + ans = info[ans_opt] + ans = ans.replace(" hours", "") + ans = ans.replace(" hour", "").replace(" cm", "") + ans = ans.replace(" km", "") + ans = ans.replace("Impossible", "0") + info[ans_opt] = ans + return ans + + +def get_val(qinfo, ans_opt, is_one_of_option=False): + """get the value of the answer option. This code also encodes the value into a number by removing extreneous strings""" + """ is_one_of_option is True, when ans_opt is one of the options, need not be the correct answer option.""" + where = lambda x, y: np.where(np.array(x) == y)[0][0] + + pid = int(qinfo["puzzle_id"]) + if pid in gv.SEQ_PUZZLES: + ans = qinfo[ans_opt] + if pid == 16: + ans_opt_val = [int(ii) for ii in ans.replace("and", ",").replace(", ,", ",").replace(" ", "").split(",")] + ans_opt_val = pad_with_max_val(ans_opt_val, 26) + elif pid == 18: + ans_opt_val = [int(ii) for ii in ans.split("-")] + ans_opt_val = pad_with_max_val(ans_opt_val, 5) + elif pid == 35: + ans_opt_val = [ + ord(ii) - ord("A") for ii in ans.replace("and", ",").replace(", ,", ",").replace(" ", "").split(",") + ] + ans_opt_val = pad_with_max_val(ans_opt_val, 5) + elif pid == 39: + ans_opt_val = [ord(ii) - ord("A") for ii in list(ans)] + ans_opt_val = pad_with_max_val(ans_opt_val, 26) + elif pid == 63: + ans_opt_val = [ + int(ii) + for ii in ans.replace("and", ",") + .replace("or", ",") + .replace(", ,", ",") + .replace("only", "") + .replace(" ", "") + .split(",") + ] + key = str(63) + if key in gv.NUM_CLASSES_PER_PUZZLE: + ans_opt_val = pad_with_max_val(ans_opt_val, gv.NUM_CLASSES_PER_PUZZLE[key] - 1) + elif pid == 100: + ans_opt_val = [ord(ii) - ord("A") for ii in list(ans)] + ans_opt_val = pad_with_max_val(ans_opt_val, 26) + ans_opt_val = np.array(ans_opt_val) + + elif pid == 58: + # puzzle 58 has answers as , e.g./4,-5, etc. + # we use +=1, -=2, x=3, /=4. so /4 will be 44, -5=25, +2= 2. + ans_opt_val = qinfo[ans_opt] + ans_opt_val = (where(gv.signs, ans_opt_val[0]) + 1) * 10 + int(ans_opt_val[1:]) + elif pid == 25: + # we need to fix the time in AM/PM format properly. + ans = qinfo[ans_opt] + ans_opt_val = int(ans.replace(":00 AM", "").replace(":00 PM", "")) + if ans.find("PM") > -1: + ans_opt_val += 12 + else: + try: + ans_opt_val = int(qinfo[ans_opt]) + except: + if len(qinfo[ans_opt]) > 0: + try: + ans_opt_val = ord(qinfo[ans_opt]) - ord("A") + except: + try: + ans_opt_val = str_replace(qinfo[ans_opt]) + ans_opt_val = ans_opt_val.replace("Impossible", "0") # puzzle 58. + if int(qinfo["puzzle_id"]) == 1: # if the puzzle id is 1, then the options are icon classes. + ans_opt_val = "_".join(ans_opt_val.split(" ")) + if ans_opt_val in gv.icon_class_ids: + ans_opt_val = where(gv.icon_class_ids, ans_opt_val) + elif ans_opt_val + "s" in gv.icon_class_ids: + ans_opt_val = where(gv.icon_class_ids, ans_opt_val + "s") + ans_opt_val = int(ans_opt_val) + except: + print(qinfo) + pdb.set_trace() + else: + ans_opt_val = ord(ans_opt) - ord("A") + if not is_one_of_option: # implies we are encoding the correct answer. + qinfo["AnswerValue"] = ans_opt_val + return ans_opt_val + + +def get_puzzle_class_info(args): + # global SEQ_PUZZLES, puzzle_diff_str, puzzle_diff + puzzle_classes = {} + for puzzle_id in args.puzzle_ids: + puzzle_root = puzzle_id + "/" + gv.puzzle_diff_str[args.train_diff] + "/" + csv_file = "puzzle_%s%s.csv" % (puzzle_id, gv.puzzle_diff[args.train_diff]) + qa_info = read_csv(os.path.join(args.data_root, puzzle_root, csv_file), puzzle_id) + + pid = int(puzzle_id) + if pid not in gv.SEQ_PUZZLES: + num_classes = np.array([get_val(qa, qa["Answer"]) for qa in qa_info]).max() + 1 + else: + if pid in [16, 39, 100]: + num_classes = 26 + 1 # if the output is a string of numbers, and the max classes is - max val. + elif pid in [18, 35]: + num_classes = 5 + 1 # the minus one is for end of items. + elif pid in [63]: + num_classes = np.array([get_val(qa, qa["Answer"]).max() for qa in qa_info]).max() + 1 + puzzle_classes[str(puzzle_id)] = num_classes + return puzzle_classes + + +class Logger(object): + def __init__(self, log_file): + self.terminal = sys.stdout + self.log = open(log_file, "a") + + def write(self, message): + self.terminal.write(message) + self.log.write(message) + + def flush(self): + # this flush method is needed for python 3 compatibility. + # this handles the flush command by doing nothing. + # you might want to specify some extra behavior here. + pass + + +def set_gpu_devices(gpu_id): + gpu = "" + if gpu_id != -1: + gpu = str(gpu_id) + os.environ["CUDA_VOSIBLE_DEVICES"] = gpu + + +def load_file(filename): + """ + load obj from filename + :param filename: + :return: + """ + cont = None + if not osp.exists(filename): + print("{} not exist".format(filename)) + return cont + if osp.splitext(filename)[-1] == ".csv": + # return pd.read_csv(filename, delimiter= '\t', index_col=0) + return pd.read_csv(filename, delimiter=",") + with open(filename, "r") as fp: + if osp.splitext(filename)[1] == ".txt": + cont = fp.readlines() + cont = [c.rstrip("\n") for c in cont] + elif osp.splitext(filename)[1] == ".json": + cont = json.load(fp) + return cont + + +def save_file(obj, filename): + """ + save obj to filename + :param obj: + :param filename: + :return: + """ + filepath = osp.dirname(filename) + if filepath != "" and not osp.exists(filepath): + os.makedirs(filepath) + else: + with open(filename, "w") as fp: + json.dump(obj, fp, indent=4) + + +def pkload(file): + data = None + if osp.exists(file) and osp.getsize(file) > 0: + with open(file, "rb") as fp: + data = pkl.load(fp) + # print('{} does not exist'.format(file)) + return data + + +def get_image(img): + img = (img - img.min()) / (img.max() - img.min() + 1e-10) + img = img * 255 + img = img.cpu().numpy() + img = img.astype("uint8") + return Image.fromarray(img) + + +def pkdump(data, file): + dirname = osp.dirname(file) + if not osp.exists(dirname): + os.makedirs(dirname) + with open(file, "wb") as fp: + pkl.dump(data, fp) + + +def get_puzzle_ids(args): + puzzles = read_dataset_info(gv.SMART_DATASET_INFO_FILE) + if args.puzzles == "all": + puzzle_ids = os.listdir(args.data_root) + puzzle_ids = np.array(puzzle_ids)[np.array([x.find(".") == -1 for x in puzzle_ids])] + puzzle_ids = puzzle_ids.tolist() + puzzle_ids_str = "all" + elif args.puzzles in puzzles: + puzzle_ids = puzzles[args.puzzles] + puzzle_ids_str = args.puzzles + else: + puzzle_ids = args.puzzle.split(",") + puzzle_ids_str = "_".join(puzzle_ids) + + if args.monolithic: + # remove sequential puzzles from the monolithic architecture. + puzzle_ids = set(puzzle_ids).difference(set([str(ii) for ii in gv.SEQ_PUZZLES])) + puzzle_ids = list(puzzle_ids) + puzzle_ids_str = puzzle_ids_str + "_monolithic" + + return puzzle_ids_str, puzzle_ids + + +def backup_code_and_start_logger(args, log_path, seed): + test = "test" if args.test else "" + log_path = os.path.join(log_path, str(seed), test) + if os.path.exists(log_path): + log_path += "." + str(np.random.randint(0, high=100)) + print("test_path = %s" % (log_path)) + if not os.path.exists(log_path): + os.makedirs(log_path) + + if not args.test: + code_path = os.path.join(log_path, "code") + if not os.path.exists(code_path): + os.mkdir(code_path) + print("saving code to %s" % (code_path)) + os.system("cp *.py %s" % (code_path)) + + with open("%s/cmd_line.txt" % (log_path), "w") as cmd: + cmd.write(str(sys.argv)) + + log_file = os.path.join(log_path, "%d.log" % (seed)) + sys.stdout = Logger(log_file) + print("logging results to %s" % (log_file)) + + +#