-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #67 from Maitreyapatel/examples
Examples of recheck usecases
- Loading branch information
Showing
13 changed files
with
218 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Use any new dataset | ||
|
||
To add a new dataset, we just need to define the configuration files for the dataset and modify the configuration file for the model to adapt it for the given dataset. | ||
|
||
Note: To learn on how to define the new dataset config, please refer to the documentation here. | ||
|
||
## Install `reliability-checklist` | ||
|
||
```bash | ||
pip install git+https://github.com/Maitreyapatel/reliability-checklist | ||
|
||
python -m spacy download en_core_web_sm | ||
python -c "import nltk;nltk.download('wordnet')" | ||
``` | ||
|
||
## Steps to follow | ||
|
||
Goal: Use the `snli` dataset instead of `mnli` dataset from existing recipes. | ||
|
||
As we already have models pre-defined for `mnli` task, we do not need to explicitly define them again. | ||
|
||
1. Create the configs folder to store the any new configurations. | ||
|
||
```bash | ||
mkdir -p configs/datamodule | ||
``` | ||
|
||
2. Put your `snli.yaml` inside the `./configs/datamodule` folder. | ||
|
||
3. Run following command to evalaute your model on this new dataset: | ||
|
||
```bash | ||
recheck task=mnli datamodule=snli 'hydra.searchpath=[file://./configs/]' | ||
``` | ||
|
||
**Note:** This will throw following error: `KeyError: "Column hypothesis_parse not in the dataset. Current columns in the dataset: ['premise', 'hypothesis', 'label', 'primary_key']"` | ||
|
||
4. Let's remove the INV_PASS augmentation because it requires the `hypothesis_parse` from input dataset. | ||
|
||
```bash | ||
recheck task=mnli datamodule=snli ~augmentation.inv_augmentation 'hydra.searchpath=[file://./configs/]' | ||
``` | ||
|
||
**Note:** This will again throw following the error: `KeyError: -1` Because of the dataset inconsistency on HuggingFace Space. But this is how you can add new dataset support. But make sure that your dataset is clean and should not have such inconsistencies as `snli` from huggingface. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
_target_: reliability_checklist.datamodules.common_datamodule.GeneralDataModule | ||
data_dir: ${paths.data_dir} | ||
batch_size: 1 | ||
num_workers: 0 | ||
pin_memory: False | ||
tokenizer_data: ${custom_model.tokenizer} | ||
model_type: ${custom_model.model_type} | ||
data_processing: ${custom_model.data_processing} | ||
dataset_specific_args: | ||
label_conversion: null | ||
label2id: | ||
0: "entailment" | ||
1: "neutral" | ||
2: "contradiction" | ||
cols: ["premise", "hypothesis"] | ||
name: snli | ||
split: test | ||
remove_cols: [] | ||
label_col: "label" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Define new evaluation metric | ||
|
||
This is an important situation, what if we want to add a new metric but we don't understand the `reliability-checklist` infrastructure? | ||
|
||
**Good news:** You don't need to have perfect knoweldge of the `reliability-checklist`. In this example, we will learn how to add new evaluation metric by following the `reliability-checklist` standards. | ||
|
||
In this example, let's assume that we want to add the F1 Score in our evaluation pipeline. Now follow below steps to implement this feature: | ||
|
||
1. Let's first create the necessary folders. | ||
|
||
```bash | ||
mkdir -p ./configs/callbacks | ||
mkdir ./src/ | ||
``` | ||
|
||
2. Create the new metric class inside the `./src/new_metric.py` which inherites the `reliability-checklist` wrapper `MonitorBasedMetric`. Now, you need to define at least three functions which contains the main logic of your code. | ||
|
||
- `init_logic` to define the variables to store the information in the from of dict. | ||
- `batch_logic` define the batch logic that returns the dict having values to the same variables defined above. | ||
- `end_logic` to get the final evaluation number from the `saved` input dictionary. This returns the two variables `results` and `extra`. `results` should be another dict containing the `metric_name: value`. While, you have a choice to store any extra variables/data from `extra` variable. | ||
- `save_logic` can be used to store custom results or plots. | ||
|
||
3. Define the `f1score.yaml` inside the `./configs/callbacks`. | ||
|
||
4. Define the `new_metric.yaml` inside the `./configs/callbacks` to load the default callbacks and f1score. | ||
|
||
5. Now add this directory into python search path and initiate the inference by defining `callbacks=new_eval`: | ||
|
||
```bash | ||
export PYTHONPATH="${PYTHONPATH}:${pwd}" | ||
recheck task=mnli trainer=gpu callbacks=new_eval augmentation=default 'hydra.searchpath=[file://./configs/]' +trainer.limit_test_batches=10 | ||
``` | ||
|
||
Note: above example only performs inference on 10 data instances without any augmentation. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
f1score_metric: | ||
_target_: src.new_metric.F1Metric | ||
monitor: "all" | ||
results_dir: ${paths.result_dir} ## do not change | ||
override: null | ||
radar: ${callbacks.radar_data} ## do not change | ||
max_possible: 1.0 ## define the maximum possible value | ||
inverse: false | ||
average: weighted ## newly defined variable according to f1score: 'macro', 'micro', 'weighted', null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
defaults: | ||
- general_evals.yaml ## keep this as it is to make sure we use other default evals | ||
- f1score.yaml ## your new eval file name |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
from sklearn.metrics import f1_score | ||
|
||
from reliability_checklist.callbacks.evals.discriminative import MonitorBasedMetric | ||
|
||
|
||
class F1Metric(MonitorBasedMetric): | ||
def __init__( | ||
self, | ||
monitor="all", | ||
name="f1score", | ||
results_dir="", | ||
override=None, | ||
radar=True, | ||
max_possible=1.0, | ||
inverse=False, | ||
average=None, | ||
): | ||
super().__init__(monitor, name, results_dir, override, radar, max_possible, inverse) | ||
self.average = average | ||
|
||
def init_logic(self) -> dict: | ||
return {"y_pred": [], "y_true": []} | ||
|
||
def batch_logic(self, outputs, batch): | ||
result = { | ||
"y_true": outputs["p2u_outputs"]["p2u"]["labels"].cpu().numpy(), | ||
"y_pred": np.argmax(outputs["p2u_outputs"]["logits"].cpu().numpy(), axis=1), | ||
} | ||
return result | ||
|
||
def end_logic(self, saved) -> dict: | ||
result = {"f1score": f1_score(saved["y_true"], saved["y_pred"], average=self.average)} | ||
extra = None | ||
return result, extra | ||
|
||
def save_logic(self, monitor, trainer, result, extra) -> None: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Use any pre-trained model | ||
|
||
To add a new model for evaluation, we just need to define the configuration files for this model. | ||
|
||
Note: To learn on how to define the new model config, please refer to the documentation here. | ||
|
||
## Install `reliability-checklist` | ||
|
||
```bash | ||
pip install git+https://github.com/Maitreyapatel/reliability-checklist | ||
|
||
python -m spacy download en_core_web_sm | ||
python -c "import nltk;nltk.download('wordnet')" | ||
``` | ||
|
||
## Steps to follow | ||
|
||
Assuming that you know how to create the new model config and suppose that we have created a new config file `distillbert_base.yaml` for the existing `mnli` dataset by following the steps from here. | ||
|
||
1. Create the configs folder to store the any new configurations. | ||
|
||
```bash | ||
mkdir -p configs/custom_models | ||
``` | ||
|
||
2. Put your `distillbert_base.yaml` inside the `./configs/custom_models` folder. | ||
|
||
3. Copy existing parrot augmentation file inside `data` folder to save the time. | ||
|
||
4. Run following command to evalaute your model on this new dataset: | ||
|
||
```bash | ||
recheck task=mnli custom_model=distillbert_base 'hydra.searchpath=[file://./configs/]' | ||
``` |
33 changes: 33 additions & 0 deletions
33
examples/add_model/configs/custom_model/distillbert_base.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
## below parameters are self-explanatory | ||
model_name: "typeform/distilbert-base-uncased-mnli" | ||
model_type: "discriminative" ## you can only have following types: "encode-decode","decoder-only","bart","discriminative","shared","hybrid","t5", | ||
huggingface_class: null | ||
decoder_model_name: null | ||
model_path: null ## provide the path if you have custom trained model using transformers library | ||
tie_embeddings: false | ||
label: null | ||
tie_encoder_decoder: false | ||
pipeline: null | ||
|
||
additional_model_inputs: null ## specify the additional pre-defined input to model like bean_search for generative models | ||
|
||
tokenizer: | ||
model_name: ${..model_name} ## only specify the name from huggingface if it's different than the actual model | ||
label2id: ## this will vary based on the evaluation data, please refer to the your selected dataset config | ||
contradiction: 2 | ||
neutral: 1 | ||
entailment: 0 | ||
args: | ||
truncation: true | ||
padding: "max_length" | ||
|
||
## use following dataset pre-processing steps | ||
data_processing: | ||
header: null ## prompt header for input data? | ||
footer: null ## prompt header for signling output? | ||
separator: " [SEP] " ## what is separator token? leave `null` for generative models | ||
columns: | ||
null | ||
## you should define this only for generative or for prompt eng. models as shown below | ||
# premise: null | ||
# hypothesis: null |
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters