In-Context Retrieval-Augmented Language Models

This repo contains the code for reproducing the experiments on WikiText-103 from AI21 Labs' paper In-Context Retrieval-Augmented Language Models (In-Context RALM), published in the Transactions of the Association for Computational Linguistics (TACL), volume 11, 2023.

Our code is mainly based on the Transformers and Pyserini libraries.
We test it on Python 3.8.

Setup

To install the required libraries in our repo, run:

pip install -r requirements.txt

If you need a PyTorch build that matches your CUDA version, install it before running the command above.

Retrieval

BM25

Our BM25 preparation script uses Pyserini, which requires Java 11; see the Pyserini installation guide.
If you have Java 11 installed, make sure your JAVA_HOME environment variable is set to the correct path. On a Linux system, the correct path might look something like /usr/lib/jvm/java-11.
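As a quick sanity check for the JAVA_HOME setting, a small helper like the one below can confirm that the variable points at a directory containing a Java launcher. This is only an illustrative sketch: the function name find_java_binary is hypothetical and not part of this repo or of Pyserini, and it does not verify the Java version.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def find_java_binary(env: Mapping[str, str]) -> Optional[Path]:
    """Return the Java launcher under JAVA_HOME, or None if it cannot be found."""
    java_home = env.get("JAVA_HOME")
    if not java_home:
        return None
    # A JDK normally ships its launcher at $JAVA_HOME/bin/java (java.exe on Windows).
    for name in ("java", "java.exe"):
        candidate = Path(java_home) / "bin" / name
        if candidate.is_file():
            return candidate
    return None


if __name__ == "__main__":
    java = find_java_binary(os.environ)
    if java is None:
        print("JAVA_HOME is unset or does not contain bin/java")
    else:
        print(f"Found Java launcher at {java}")
```

If the helper reports nothing, fix JAVA_HOME before continuing.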
Then run:

python prepare_retrieval_data.py \
--retrieval_type sparse \
--tokenizer_name $MODEL_NAME \
--max_length 1024 \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--index_name wikipedia-dpr \
--forbidden_titles_path ralm/retrievers/wikitext103_forbidden_titles.txt \
--stride 4 \
--output_file $RETRIEVAL_FILE \
--num_tokens_for_query 32 \
--num_docs 16 
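To make the roles of --stride and --num_tokens_for_query concrete, here is a minimal sketch of how retrieval queries could be laid out over a token sequence. It assumes, following the paper's description of retrieval stride and query length, that a query is issued every stride tokens and is built from the last num_tokens_for_query tokens of the prefix. The function query_spans is hypothetical and is not the logic of prepare_retrieval_data.py; in particular, skipping position 0 is a choice made for this sketch.

```python
from typing import List, Tuple


def query_spans(num_tokens: int, stride: int = 4, query_len: int = 32) -> List[Tuple[int, int, int]]:
    """List (position, query_start, query_end) for each retrieval call.

    The query for position p is tokens[query_start:query_end], that is, the
    last `query_len` tokens before p (fewer near the start of the text).
    """
    spans = []
    # Position 0 is skipped here because there is no prefix to build a query from.
    for pos in range(stride, num_tokens, stride):
        spans.append((pos, max(0, pos - query_len), pos))
    return spans


if __name__ == "__main__":
    for pos, start, end in query_spans(num_tokens=48):
        print(f"retrieve at token {pos}: query = tokens[{start}:{end}]")
```

Under this reading, a smaller stride means more retrieval calls per document, which is why the preparation step can take a while with --stride 4.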

Evaluation

List of Language Models

In the paper, we report results for the following models (replace $MODEL_NAME with one of these).
Note that the larger models may need model parallelism (on a 40GB A100, we used model parallelism for OPT-30B and OPT-66B).
See details below on how to apply this option.

  • GPT-2: gpt2, gpt2-medium, gpt2-large, gpt2-xl
  • GPT-Neo and GPT-J: EleutherAI/gpt-neo-1.3B, EleutherAI/gpt-neo-2.7B, EleutherAI/gpt-j-6B
  • OPT: facebook/opt-125m, facebook/opt-350m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b

Evaluate models without retrieval

To run evaluation on models without retrieval, please use the following command (you can increase stride to 32 for faster evaluation):

python eval_lm.py \
--model_name $MODEL_NAME \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
[--model_parallelism]
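The effect of --stride on evaluation speed can be illustrated with a sliding-window layout in the style of the common strided-perplexity recipe. This is a sketch under that assumption, not the code in eval_lm.py; eval_windows is a hypothetical helper. Each window covers at most max_length tokens, and after the first window only the newly covered tokens are scored.

```python
from typing import List, Tuple


def eval_windows(num_tokens: int, max_length: int = 1024, stride: int = 4) -> List[Tuple[int, int, int]]:
    """Return (window_start, window_end, num_newly_scored_tokens) per forward pass."""
    windows = []
    prev_end = 0
    for begin in range(0, num_tokens, stride):
        end = min(begin + max_length, num_tokens)
        # Only tokens past the previous window's end are scored in this pass.
        windows.append((begin, end, end - prev_end))
        prev_end = end
        if end == num_tokens:
            break
    return windows


if __name__ == "__main__":
    for stride in (4, 32):
        passes = len(eval_windows(num_tokens=10_000, max_length=1024, stride=stride))
        print(f"stride={stride}: {passes} forward passes")
```

A larger stride reduces the number of forward passes, at the cost of giving some tokens less preceding context.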

Evaluate models with retrieval

To run models with retrieval, use the $RETRIEVAL_FILE output from the prepare_retrieval_data.py script:

python eval_lm.py \
--model_name $MODEL_NAME \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
[--model_parallelism] \
--retrieved_file $RETRIEVAL_FILE

Note: Our main retrieval flow assumes you want to use the top-scored passage from your retrieval file (--ranking_strategy first).

Reranking

Currently, we support two reranking strategies: logprob (the zero-shot method described in subsection 6.1 of the paper) and oracle (which shows the potential gains from reranking).

To rerank, first run the retrieval script with --num_docs 16 (or however many documents you want to rerank). If multiple GPUs are available, data parallelism is applied automatically (each GPU conditions on different retrieved documents). Then run:

python eval_lm.py \
--model_name $MODEL_NAME \
--dataset_path wikitext \
--dataset_name wikitext-103-v1 \
--dataset_split [validation, test] \
--output_dir $OUTPUT_DIR \
--stride 4 \
--max_length 1024 \
[--model_parallelism] \
--retrieved_file $RETRIEVAL_FILE \
--ranking_strategy [logprob, oracle] \
--num_docs_to_rank 16 \
--ranking_logprob_past_tokens 16
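The difference between the ranking strategies can be sketched as a choice of which tokens are used to score each candidate document. The code below is an illustration based on the paper's description, not the implementation in eval_lm.py: rerank and score_fn are hypothetical names, and score_fn stands in for a language-model call that returns the log-probability of target tokens given a document and context.

```python
from typing import Callable, List, Sequence

# Hypothetical scorer: score_fn(doc, context_tokens, target_tokens) -> log-probability.
ScoreFn = Callable[[str, Sequence[int], Sequence[int]], float]


def rerank(
    docs: List[str],
    prefix: Sequence[int],
    upcoming: Sequence[int],
    score_fn: ScoreFn,
    strategy: str = "logprob",
    past_tokens: int = 16,
) -> str:
    """Pick one document from `docs` according to the ranking strategy."""
    if strategy == "first":
        # Keep the retriever's top-scored passage.
        return docs[0]
    if strategy == "logprob":
        # Zero-shot: score the last `past_tokens` tokens of the prefix (assumes past_tokens > 0).
        context, target = prefix[:-past_tokens], prefix[-past_tokens:]
    elif strategy == "oracle":
        # Upper bound: score the upcoming tokens, which are unknown at generation time.
        context, target = prefix, upcoming
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return max(docs, key=lambda doc: score_fn(doc, context, target))
```

The oracle variant peeks at the tokens being predicted, so it is useful only for estimating how much a better reranker could help.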

Question Answering Experiments

To run our QA experiments on Natural Questions, start by downloading the dataset augmented with DPR results:

wget https://dl.fbaipublicfiles.com/dpr/data/retriever_results/single/nq-test.json.gz
gzip -d ./nq-test.json.gz

To run our QA experiments on TriviaQA, install gsutil and copy the DPR-augmented dataset:

gsutil cp gs://ai21-publishing-public-models/in-context-ralm/trivia-test-dpr-results.json ./trivia-test.json

Then run the evaluation script:

python eval_qa.py \
--model_name $MODEL_NAME \
--dataset_path [nq-test.json,trivia-test.json] \
--output_dir $OUTPUT_DIR \
--num_docs [0,1,2] \
[--model_parallelism]

Here, num_docs is the number of retrieved documents to include in-context: num_docs=0 is the closed-book setting, and num_docs>=1 is the open-book setting.
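As a rough illustration of what num_docs controls, the sketch below selects the top-ranked passages from one record of a DPR-style results file. It assumes each record has a "ctxs" list whose entries carry "title" and "text" fields, as in the DPR retriever results; select_contexts is a hypothetical helper and does not reproduce the prompt construction in eval_qa.py. The record shown is toy data.

```python
from typing import Any, Dict, List, Tuple


def select_contexts(example: Dict[str, Any], num_docs: int) -> List[Tuple[str, str]]:
    """Return (title, text) pairs for the first `num_docs` retrieved passages."""
    ctxs = example.get("ctxs", [])[:num_docs]
    return [(ctx["title"], ctx["text"]) for ctx in ctxs]


if __name__ == "__main__":
    # Toy record in the assumed DPR-style layout (not real dataset content).
    record = {
        "question": "toy question?",
        "answers": ["toy answer"],
        "ctxs": [
            {"title": "Doc A", "text": "first passage"},
            {"title": "Doc B", "text": "second passage"},
        ],
    }
    for n in (0, 1, 2):
        print(n, select_contexts(record, n))
```

With num_docs=0 the list is empty, which corresponds to the closed-book setting.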

Citation

If you find our paper or code helpful, please cite our paper:

@article{ram-etal-2023-context,
    title = "In-Context Retrieval-Augmented Language Models",
    author = "Ram, Ori  and
      Levine, Yoav  and
      Dalmedigos, Itay  and
      Muhlgay, Dor  and
      Shashua, Amnon  and
      Leyton-Brown, Kevin  and
      Shoham, Yoav",
    journal = "Transactions of the Association for Computational Linguistics",
    volume = "11",
    year = "2023",
    address = "Cambridge, MA",
    publisher = "MIT Press",
    url = "https://aclanthology.org/2023.tacl-1.75",
    doi = "10.1162/tacl_a_00605",
    pages = "1316--1331",
}