Merge pull request #32 from for-ai/gtranslate_updated
Updated gtranslate to skip human-eval-pack chosen/rejected columns; fixed format_ from html to text.
ShayekhBinIslam authored Sep 6, 2024
2 parents 7dcad10 + e2d1086 commit 7143426
Showing 1 changed file with 89 additions and 53 deletions.
scripts/translate_preference_pairs_gtranslate.py
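Why the format_ change matters: the translate_v2 client defaults to HTML handling, so the API entity-escapes characters such as quotes and ampersands in its output, which silently corrupts code and plain prose. A minimal sketch of the difference, assuming credentials are configured; the sample string and language code are illustrative, not from this repository:

from google.cloud import translate_v2 as translate

client = translate.Client()
source = 'print("a & b")'

# Default (HTML) handling: special characters come back entity-escaped,
# e.g. '"' as '&quot;' and '&' as '&amp;', which breaks code snippets.
escaped = client.translate(source, target_language="de")

# format_="text" (the fix in this commit) returns plain, unescaped output.
plain = client.translate(source, target_language="de", format_="text")

print(escaped["translatedText"])
print(plain["translatedText"])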
@@ -1,84 +1,120 @@
 import argparse
 import json
 import os
 import random
 from google.cloud import translate_v2 as translate
 from datasets import load_dataset
 from tqdm import tqdm
 
 # Steps to setup:
 # 1. https://cloud.google.com/python/docs/setup#linux
 # 2. https://cloud.google.com/sdk/docs/install
 
-os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "<path_to_credentials_json>"
-
-def translate_texts(texts, client, target_lang_code):
-    results = client.translate(texts, target_language=target_lang_code)
-    return [result['translatedText'] for result in results]
+os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "<path_to_credentials.json>"
+
+def translate_text_batch(texts, client, target_lang_code):
+    """
+    Translates a batch of texts using Google Translate API.
+    """
+    result = client.translate(texts, target_language=target_lang_code, format_="text")
+    # return [{"translatedText": "...."} for res in texts]
+    return [res['translatedText'] for res in result]
 
 def validate_columns(dataset, columns):
     for subset in dataset.keys():
         for column in columns:
             if column not in dataset[subset].column_names:
                 raise ValueError(f"Column '{column}' not found in subset '{subset}' of the dataset")
 
-def translate_dataset(
-    dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=10
-):
+def create_batches(dataset, batch_size):
+    """
+    Create batches of examples from the dataset.
+    """
+    batches = []
+    current_batch = []
+
+    for example in dataset:
+        current_batch.append(example)
+        if len(current_batch) == batch_size:
+            batches.append(current_batch)
+            current_batch = []
+
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
+
+def translate_subset(dataset, columns_to_translate, target_language, client, translate_prompt_only=False, batch_size=32):
+    translated_data = []
+
+    # Create batches from the dataset
+    batches = create_batches(dataset, batch_size)
+
+    for batch in tqdm(batches, desc=f"Translating subset"):
+        translated_batch = []
+
+        if translate_prompt_only:
+            # Collect all prompts for translation
+            prompts = [example['prompt'] for example in batch]
+            translated_prompts = translate_text_batch(prompts, client, target_language)
+
+            for i, example in enumerate(batch):
+                translated_example = {'prompt': translated_prompts[i]}
+                # Copy other columns unchanged
+                for key in example.keys():
+                    if key != 'prompt':
+                        translated_example[key] = example[key]
+                translated_example["target_language"] = target_language
+                translated_batch.append(translated_example)
+        else:
+            # Collect all texts for each column to be translated
+            for col in columns_to_translate:
+                texts_to_translate = [example[col] for example in batch]
+                translated_texts = translate_text_batch(texts_to_translate, client, target_language)
+
+                for i, example in enumerate(batch):
+                    if i >= len(translated_batch):
+                        translated_batch.append({})
+                    translated_batch[i][col] = translated_texts[i]
+
+            # Copy other columns as-is
+            for i, example in enumerate(batch):
+                for key in example.keys():
+                    if key not in translated_batch[i]:
+                        translated_batch[i][key] = example[key]
+                translated_batch[i]["target_language"] = target_language
+
+        translated_data.extend(translated_batch)
+
+    return translated_data
+
+def translate_dataset(dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=32):
     # Initialize the Google Cloud Translate client
     client = translate.Client()
 
     # Validate columns
     validate_columns(dataset, columns_to_translate)
 
+    dataset = dataset["filtered"]
+
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    for subset in dataset.keys():
-        if subset == "raw":
-            continue
-        translated_data = []
-        data_length = len(dataset[subset])
-
-        # Randomly select a subset of the data if subset_size is specified
-        if subset_size:
-            indices = random.sample(range(data_length), min(subset_size, data_length))
-            dataset[subset] = dataset[subset].select(indices)
-
-        for start_idx in tqdm(range(0, data_length, batch_size), desc=f"Translating {subset} subset"):
-            end_idx = min(start_idx + batch_size, data_length)
-            batch = dataset[subset].select(range(start_idx, end_idx))
-
-            # Initialize a dictionary to hold the translated batch
-            translated_batch = {col: [] for col in columns_to_translate}
-
-            for col in columns_to_translate:
-                # Translate each column in the batch
-                texts_to_translate = batch[col]
-                translated_texts = translate_texts(texts_to_translate, client, target_language)
-                translated_batch[col] = translated_texts
-
-            # Add other columns as-is
-            other_columns = {key: batch[key] for key in batch.column_names if key not in translated_batch}
-
-            # Combine translated and other columns into a list of examples
-            for i in range(len(translated_batch[columns_to_translate[0]])):
-                translated_example = {col: translated_batch[col][i] for col in columns_to_translate}
-                translated_example["target_language"] = target_language
-                for key in other_columns:
-                    translated_example[key] = other_columns[key][i]
-                translated_data.append(translated_example)
-
-        # Save translated data to JSON file
-        dataset_name = args.dataset_name.replace("/", "_")
-        output_file = os.path.join(output_dir, f"{dataset_name}_{subset}_{args.target_language}_translated.json")
-        with open(output_file, "w", encoding="utf-8") as f:
-            json.dump(translated_data, f, ensure_ascii=False, indent=4)
-
-        print(f"Translated data for subset '{subset}' saved to {output_file}")
+    # Filter for "hep-" subset and non-"hep-" subset
+    code_dataset = dataset.filter(lambda x: x['subset'].startswith('hep-'))
+    non_code_dataset = dataset.filter(lambda x: not x['subset'].startswith('hep-'))
+
+    # Translate non-"hep-" subset: Translate all columns
+    non_code_translated = translate_subset(non_code_dataset, columns_to_translate, target_language, client, batch_size=batch_size)
+
+    # Translate "hep-" subset: Translate only the 'prompt' column
+    code_translated = translate_subset(code_dataset, columns_to_translate, target_language, client, translate_prompt_only=True, batch_size=batch_size)
+
+    # Combine the translated data
+    combined_translated = code_translated + non_code_translated
+
+    # Save the translated data
+    dataset_name = args.dataset_name.replace("/", "_")
+    output_file = os.path.join(output_dir, f"{dataset_name}_{args.target_language}_translated.json")
+    with open(output_file, "w", encoding="utf-8") as f:
+        json.dump(combined_translated, f, ensure_ascii=False, indent=4)
+
+    print(f"Translated data saved to {output_file}")
 
 if __name__ == "__main__":
     # fmt: off
@@ -88,7 +124,7 @@ def translate_dataset(
     parser.add_argument("--columns_to_translate", type=str, nargs="+", required=True, help="Columns to translate.")
     parser.add_argument("--subset_size", type=int, help="Size of the random subset to translate.")
    parser.add_argument("--output_dir", type=str, default="translations", help="Output directory to save translations.")
-    parser.add_argument("--batch_size", type=int, default=10, help="Number of texts to translate in each batch.")
+    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for Google Translate API.")
     # fmt: on
 
     args = parser.parse_args()
@@ -103,5 +139,5 @@ def translate_dataset(
         args.target_language,
         args.subset_size,
         args.output_dir,
-        args.batch_size
+        args.batch_size,
     )
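The other half of the commit, skipping chosen/rejected for the HumanEvalPack rows, is carried by the translate_prompt_only flag in translate_subset above. A toy illustration of the routing decision, with an invented example row and no API call:

# Mirrors the subset routing in translate_dataset/translate_subset;
# the row below is made up for illustration.
example = {
    "subset": "hep-python",  # HumanEvalPack subsets are prefixed "hep-"
    "prompt": "Write a function that reverses a string.",
    "chosen": "def rev(s): return s[::-1]",
    "rejected": "def rev(s): return s",
}

if example["subset"].startswith("hep-"):
    # Code tasks: translate only the prompt, so the code in
    # chosen/rejected stays byte-for-byte intact.
    columns = ["prompt"]
else:
    # Everything else: translate all requested columns.
    columns = ["prompt", "chosen", "rejected"]

print(columns)  # -> ['prompt'] for this hep- row

This keeps the code in the preference pair compilable while still localizing the task description.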
