Merge pull request #3 from pln-fing-udelar/simplify-commands
Add scripts to preprocess europarl corpus and generate alignments
bryant1410 authored Apr 26, 2022
2 parents 2aa9d2c + 98fda20 commit 840cf3f
Showing 5 changed files with 116 additions and 73 deletions.
48 changes: 6 additions & 42 deletions README.md
@@ -34,30 +34,10 @@ Alternatively, follow these steps to train it yourself.

### Prepare the Corpus

1. Download the [Europarl Spanish-English parallel corpus](https://www.statmt.org/europarl/v7/es-en.tgz).
2. Remove the sentences that don't form a pair (the sentences that correspond with an empty line).
3. Remove sentences of length 1.
4. Remove sentences that contain tags (characters "<" and ">").
5. Split the corpus into test, train and val. A good size can be 2000 sentences for test and 2000 for val. Name the
files `corpus.es`, `validation.es`, `test.es`, `corpus.en`, `validation.en` and `test.en`.
7. Run the following commands, to learn the vocabulary, tokenize the sentences, and shuffle the corpus:
Run the following script to download the [Europarl Spanish-English parallel corpus](https://www.statmt.org/europarl/v7/es-en.tgz), filter out empty, tagged, and very short or very long sentence pairs, split the corpus into training, validation and test sets, learn the vocabulary, and tokenize the sentences:

```bash
spm_train --input=corpus.en --model_prefix=en --vocab_size=32000 --character_coverage=1.0 --model_type=unigram
spm_train --input=corpus.es --model_prefix=es --vocab_size=32000 --character_coverage=1.0 --model_type=unigram
spm_encode --model=en.model --output_format=piece < corpus.en > corpus.32k.en
spm_encode --model=en.model --output_format=piece < validation.en > validation.32k.en
spm_encode --model=en.model --output_format=piece < test.en > test.32k.en
spm_encode --model=es.model --output_format=piece < corpus.es > corpus.32k.es
spm_encode --model=es.model --output_format=piece < validation.es > validation.32k.es
spm_encode --model=es.model --output_format=piece < test.es > test.32k.es
python thualign/scripts/shuffle_corpus.py --corpus corpus.32k.es corpus.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' corpus.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' validation.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' test.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' corpus.32k.es
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' validation.32k.es
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' test.32k.es
./scripts/train-mask-align/preprocess_europarl.sh
```

### Train the Model
@@ -105,33 +85,17 @@ Note you need a computer with a CUDA-capable GPU to train the model.

## Generating the Answer Alignments for NewsQA-es

Run these commands to generate the answer alignments for the NewsQA-es dataset. You should have a trained Mask-Align
Run the following script to generate the answer alignments for the NewsQA-es dataset. You should have a trained Mask-Align
model and have the `newsqa.csv` file.

```bash
./scripts/generate-alignments/remove_bad_rows.py
./scripts/generate-alignments/generate_files.py
spm_encode --model=en.model --output_format=piece < test.en > test.32k.en
spm_encode --model=es.model --output_format=piece < test.es > test.32k.es
./scripts/generate-alignments/process_answer_indexes.py
mkdir corpus-es
mv test.32k.en test.32k.es answers.en corpus-es/
./thualign/bin/generate.sh -s spanish -o output.txt
spm_decode --model=es.model --input_format=piece < output.txt > output-plain.txt
./scripts/generate-alignments/output_brackets_to_indexes.py
./scripts/generate-alignments/generate_alignments.sh
```

The following three files are generated:
The following four files are generated:

* `output-indexes.txt`: the indexes of the answers in Spanish.
* `output-answers.txt`: the answers in Spanish (in plain text).
* `output-sentences.txt`: the sentences in Spanish (not tokenized).
* `newsqa-es.csv`: a new version of `newsqa_filtered.csv` which has the columns with the answers in Spanish.

### Generate the final merged CSV file

Finally, run these commands to generate the `newsqa-es.csv` file, a new version of `newsqa_filtered.csv` which has the columns with the answers in Spanish.

```bash
sed -i '1ianswer_index_esp' output-indexes.txt
csvjoin -y 0 newsqa_filtered.csv output-indexes.txt > newsqa-es.csv
```
14 changes: 14 additions & 0 deletions scripts/generate-alignments/generate_alignments.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash
set -ex
python ./scripts/generate-alignments/remove_bad_rows.py
python ./scripts/generate-alignments/generate_files.py
spm_encode --model=en.model --output_format=piece < test.en > test.32k.en
spm_encode --model=es.model --output_format=piece < test.es > test.32k.es
python ./scripts/generate-alignments/process_answer_indexes.py
mkdir -p corpus-es
mv test.32k.en test.32k.es answers.en vocab.32k.es.txt vocab.32k.en.txt corpus-es/
./thualign/bin/generate.sh -s spanish -o output.txt
spm_decode --model=es.model --input_format=piece < output.txt > output-plain.txt
python ./scripts/generate-alignments/output_brackets_to_indexes.py
sed -i '1ianswer_index_esp' output-indexes.txt
csvjoin -y 0 newsqa_filtered.csv output-indexes.txt > newsqa-es.csv
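
The last two commands of the script give `output-indexes.txt` a header line and then paste it next to `newsqa_filtered.csv` as an extra column. As far as I understand `csvjoin`, calling it without a join column combines the files row by row. Below is a minimal Python sketch of that positional join, not the script's actual implementation. The function name `append_column` and the sample rows are hypothetical, chosen only for illustration.

```python
import csv
import io


def append_column(rows_csv: str, values: list[str], header: str) -> str:
    """Append `values` as a new column named `header`, matching rows by position."""
    reader = csv.reader(io.StringIO(rows_csv))
    out = io.StringIO()
    writer = csv.writer(out, lineterminator="\n")
    for i, row in enumerate(reader):
        if i == 0:
            # The first row is the header: add the new column name.
            writer.writerow(row + [header])
        else:
            # Data rows beyond the end of `values` get an empty cell.
            extra = values[i - 1] if i - 1 < len(values) else ""
            writer.writerow(row + [extra])
    return out.getvalue()


# Hypothetical sample data, not taken from the real dataset.
filtered = "question,answer\nwho?,0:3\nwhere?,4:9\n"
print(append_column(filtered, ["10:12", "20:25"], "answer_index_esp"))
```

Because the join is positional, the two files must keep the same row order. Any step that drops or reorders rows in only one of them would silently misalign the new column.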
60 changes: 29 additions & 31 deletions scripts/generate-alignments/remove_bad_rows.py
@@ -1,37 +1,35 @@
#!/usr/bin/env python
# This script is used to filter the rows with bad data in newsqa.csv, creating the newsqa_filtered.csv file
import csv
import os
import random
import re
import csv

### This script is used to filter the rows with bad data in newsqa.csv, creating the newsqa_filtered.csv file

csv_input = open("newsqa.csv", "r", encoding="utf8")
csv_input2 = open("newsqa.csv", "r", encoding="utf8")
csv_data = list(csv.reader(csv_input, delimiter=','))
csv_lines = csv_input2.readlines()

csv_output = open("newsqa_filtered.csv", "w", encoding="utf8")
csv_output2 = open("newsqa_bad_rows.csv", "w", encoding="utf8")

def main() -> None:
    with open("newsqa.csv", encoding="utf8") as csv_input, \
            open("newsqa.csv", encoding="utf8") as csv_input2, \
            open("newsqa_filtered.csv", "w", encoding="utf8") as csv_output, \
            open("newsqa_bad_rows.csv", "w", encoding="utf8") as csv_output2:
        for i, (row, csv_line) in enumerate(zip(csv.reader(csv_input), csv_input2)):
            if i == 0:
                csv_output.write(csv_line)
                csv_output2.write(csv_line)
            else:
                ans_start = -1
                ans_end = -1
                if indexes := re.search(r"\d+:\d+", row[5]):
                    start_end_str = indexes.group(0).split(":")
                    ans_start = int(start_end_str[0])
                    ans_end = int(start_end_str[1])
csv_output.write(csv_lines[0])
csv_output2.write(csv_lines[0])

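                # A row is kept only if none of its first six fields is blank.
                # `i` counts the header row, so the excluded rows are one higher than the
                # 0-based data-row indexes (33676, 33677, 116925, 116926) used previously.
                if not any(re.match(r"^(\s|\t|\n|\r)*$", str(row[j])) for j in range(6)) \
                        and "*" not in str(row[2]) \
                        and "Ã" not in str(row[3]) \
                        and i not in {33677, 33678, 116926, 116927} \
                        and ans_start > -1 \
                        and ans_end > -1:
                    output_file = csv_output
                else:
                    output_file = csv_output2
                output_file.write(csv_line)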
for idx, entry in enumerate(csv_data[1:]):
    ans_start = -1
    ans_end = -1
    indexes = re.search(r"\d+:\d+", entry[5])
    if indexes:
        ans_start = int(indexes.group(0).split(":")[0])
        ans_end = int(indexes.group(0).split(":")[1])

    if re.match(r"^(\s|\t|\n|\r)*$", str(entry[0])) == None and re.match(r"^(\s|\t|\n|\r)*$", str(entry[1])) == None and re.match(r"^(\s|\t|\n|\r)*$", str(entry[2])) == None and re.match(r"^(\s|\t|\n|\r)*$", str(entry[3])) == None and re.match(r"^(\s|\t|\n|\r)*$", str(entry[4])) == None and re.match(r"^(\s|\t|\n|\r)*$", str(entry[5])) == None and "*" not in str(entry[2]) and "Ã" not in str(entry[3]) and idx != 33676 and idx != 33677 and idx != 116925 and idx != 116926 and ans_start > -1 and ans_end > -1:
        csv_output.write(csv_lines[idx + 1])
    else:
        csv_output2.write(csv_lines[idx + 1])

if __name__ == "__main__":
    main()
csv_input.close()
csv_input2.close()
csv_output.close()
csv_output2.close()
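
The filtering rule in `remove_bad_rows.py` can be read as a single predicate over one parsed CSV row. The sketch below is an illustration under some assumptions rather than the script itself: the name `is_good_row` is hypothetical, the length check for short rows is an addition, and the hard-coded row numbers that the script also excludes are left out because they depend on the specific file.

```python
import re

BLANK = re.compile(r"^\s*$")            # empty or whitespace-only field
ANSWER_RANGE = re.compile(r"\d+:\d+")   # a "start:end" pair of integers


def is_good_row(row: list[str]) -> bool:
    """Return True if the row passes the content checks used for filtering."""
    if len(row) < 6:
        # Assumption: rows with fewer than six fields are treated as bad.
        return False
    if any(BLANK.match(field) for field in row[:6]):
        return False
    if "*" in row[2] or "Ã" in row[3]:
        # "Ã" is a typical sign of mis-decoded UTF-8 text.
        return False
    return ANSWER_RANGE.search(row[5]) is not None


# Placeholder field values; only their positions matter here.
print(is_good_row(["id", "story", "question", "answer", "x", "12:15"]))
print(is_good_row(["id", "story", "   ", "answer", "x", "12:15"]))
```

Keeping the rule in one function makes it possible to check it against a handful of hand-written rows before running it over the full `newsqa.csv`.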
25 changes: 25 additions & 0 deletions scripts/train-mask-align/preprocess_europarl.sh
@@ -0,0 +1,25 @@
#!/usr/bin/env bash
set -ex
wget -q https://www.statmt.org/europarl/v7/es-en.tgz
tar zxvf es-en.tgz
python scripts/train-mask-align/split_corpus.py
spm_train --input=corpus.en --model_prefix=en --vocab_size=32000 --character_coverage=1.0 --model_type=unigram
spm_train --input=corpus.es --model_prefix=es --vocab_size=32000 --character_coverage=1.0 --model_type=unigram
python scripts/train-mask-align/process_vocab.py
sed -i 's/<s>/<pad>/g' vocab.32k.es.txt
sed -i 's/<\/s>/<eos>/g' vocab.32k.es.txt
sed -i 's/<s>/<pad>/g' vocab.32k.en.txt
sed -i 's/<\/s>/<eos>/g' vocab.32k.en.txt
spm_encode --model=en.model --output_format=piece < corpus.en > corpus.32k.en
spm_encode --model=en.model --output_format=piece < validation.en > validation.32k.en
spm_encode --model=en.model --output_format=piece < test.en > test.32k.en
spm_encode --model=es.model --output_format=piece < corpus.es > corpus.32k.es
spm_encode --model=es.model --output_format=piece < validation.es > validation.32k.es
spm_encode --model=es.model --output_format=piece < test.es > test.32k.es
python thualign/scripts/shuffle_corpus.py --corpus corpus.32k.es corpus.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' corpus.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' validation.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' test.32k.en
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' corpus.32k.es
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' validation.32k.es
sed -i -e 's/<s>/<eos>/' -e 's/<s\/>/<pad>/' test.32k.es
42 changes: 42 additions & 0 deletions scripts/train-mask-align/split_corpus.py
@@ -0,0 +1,42 @@
import os
import re
import random

# For the Europarl v7 Spanish-English corpus: https://www.statmt.org/europarl/v7/es-en.tgz

corpus_es = open("./corpus.es", "w", encoding="utf-8")
corpus_en = open("./corpus.en", "w", encoding="utf-8")
validation_es = open("./validation.es", "w", encoding="utf-8")
validation_en = open("./validation.en", "w", encoding="utf-8")
test_es = open("./test.es", "w", encoding="utf-8")
test_en = open("./test.en", "w", encoding="utf-8")

file1 = open("./europarl-v7.es-en.es", "r", encoding="utf-8")
file2 = open("./europarl-v7.es-en.en", "r", encoding="utf-8")
data1 = file1.readlines()
data2 = file2.readlines()
num_sentences = min(len(data1), len(data2))

for i in range(0, num_sentences):
    num_words1 = len(data1[i].split(' '))
    num_words2 = len(data2[i].split(' '))
    if re.search(r'\w+', data1[i]) and re.search(r'\w+', data2[i]) and "<" not in data1[i] and "<" not in data2[i] and num_words1 > 1 and num_words1 < 120 and num_words2 > 1 and num_words2 < 120:
        random_number = random.uniform(0, 1)
        if random_number < 0.9978:  # about 99.78% of the pairs: training
            corpus_es.write(data1[i].lower())
            corpus_en.write(data2[i].lower())
        elif random_number < 0.9989:  # about 0.11% of the pairs: validation
            validation_es.write(data1[i].lower())
            validation_en.write(data2[i].lower())
        else:  # remaining 0.11% or so: test
            test_es.write(data1[i].lower())
            test_en.write(data2[i].lower())

file1.close()
file2.close()
corpus_es.close()
corpus_en.close()
validation_es.close()
validation_en.close()
test_es.close()
test_en.close()
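
A three-way random split like the one in `split_corpus.py` works only if the cutoffs are tested in increasing order. If the larger cutoff is checked first, the second branch can never be taken and one of the output files stays empty. The following is a minimal sketch of the assignment step on its own. The function name `assign_split` and the default cutoffs (about 0.11% each for validation and test) are illustrative assumptions.

```python
import random


def assign_split(r: float, train_cutoff: float = 0.9978, val_cutoff: float = 0.9989) -> str:
    """Map a uniform random number in [0, 1) to a split name."""
    if r < train_cutoff:
        return "corpus"
    if r < val_cutoff:
        return "validation"
    return "test"


# Seeded generator so the simulation is repeatable.
rng = random.Random(0)
counts = {"corpus": 0, "validation": 0, "test": 0}
for _ in range(100_000):
    counts[assign_split(rng.random())] += 1
print(counts)
```

Seeding the generator, as done here with `random.Random(0)`, would also make the real split reproducible across runs. The script as committed does not set a seed.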
