diff --git a/scripts/translate_preference_pairs_gtranslate.py b/scripts/translate_preference_pairs_gtranslate.py index 7b9d32f..6daabf6 100644 --- a/scripts/translate_preference_pairs_gtranslate.py +++ b/scripts/translate_preference_pairs_gtranslate.py @@ -1,21 +1,19 @@ import argparse import json import os -import random from google.cloud import translate_v2 as translate from datasets import load_dataset from tqdm import tqdm -# Steps to setup: -# 1. https://cloud.google.com/python/docs/setup#linux -# 2. https://cloud.google.com/sdk/docs/install - -os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" - -def translate_texts(texts, client, target_lang_code): - results = client.translate(texts, target_language=target_lang_code) - return [result['translatedText'] for result in results] +os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" +def translate_text_batch(texts, client, target_lang_code): + """ + Translates a batch of texts using Google Translate API. + """ + result = client.translate(texts, target_language=target_lang_code, format_="text") + # return [{"translatedText": "...."} for res in texts] + return [res['translatedText'] for res in result] def validate_columns(dataset, columns): for subset in dataset.keys(): @@ -23,62 +21,100 @@ def validate_columns(dataset, columns): if column not in dataset[subset].column_names: raise ValueError(f"Column '{column}' not found in subset '{subset}' of the dataset") +def create_batches(dataset, batch_size): + """ + Create batches of examples from the dataset. + """ + batches = [] + current_batch = [] + + for example in dataset: + current_batch.append(example) + if len(current_batch) == batch_size: + batches.append(current_batch) + current_batch = [] + + if current_batch: + batches.append(current_batch) + + return batches + +def translate_subset(dataset, columns_to_translate, target_language, client, translate_prompt_only=False, batch_size=32): + translated_data = [] -def translate_dataset( - dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=10 -): + # Create batches from the dataset + batches = create_batches(dataset, batch_size) + + for batch in tqdm(batches, desc=f"Translating subset"): + translated_batch = [] + + if translate_prompt_only: + # Collect all prompts for translation + prompts = [example['prompt'] for example in batch] + translated_prompts = translate_text_batch(prompts, client, target_language) + + for i, example in enumerate(batch): + translated_example = {'prompt': translated_prompts[i]} + # Copy other columns unchanged + for key in example.keys(): + if key != 'prompt': + translated_example[key] = example[key] + translated_example["target_language"] = target_language + translated_batch.append(translated_example) + else: + # Collect all texts for each column to be translated + for col in columns_to_translate: + texts_to_translate = [example[col] for example in batch] + translated_texts = translate_text_batch(texts_to_translate, client, target_language) + + for i, example in enumerate(batch): + if i >= len(translated_batch): + translated_batch.append({}) + translated_batch[i][col] = translated_texts[i] + + # Copy other columns as-is + for i, example in enumerate(batch): + for key in example.keys(): + if key not in translated_batch[i]: + translated_batch[i][key] = example[key] + translated_batch[i]["target_language"] = target_language + + translated_data.extend(translated_batch) + + return translated_data + +def translate_dataset(dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=32): # Initialize the Google Cloud Translate client client = translate.Client() # Validate columns validate_columns(dataset, columns_to_translate) + + dataset = dataset["filtered"] if not os.path.exists(output_dir): os.makedirs(output_dir) - for subset in dataset.keys(): - if subset == "raw": - continue - translated_data = [] - data_length = len(dataset[subset]) - - # Randomly select a subset of the data if subset_size is specified - if subset_size: - indices = random.sample(range(data_length), min(subset_size, data_length)) - dataset[subset] = dataset[subset].select(indices) - - for start_idx in tqdm(range(0, data_length, batch_size), desc=f"Translating {subset} subset"): - end_idx = min(start_idx + batch_size, data_length) - batch = dataset[subset].select(range(start_idx, end_idx)) + # Filter for "hep-" subset and non-"hep-" subset + code_dataset = dataset.filter(lambda x: x['subset'].startswith('hep-')) + non_code_dataset = dataset.filter(lambda x: not x['subset'].startswith('hep-')) - # Initialize a dictionary to hold the translated batch - translated_batch = {col: [] for col in columns_to_translate} + # Translate non-"hep-" subset: Translate all columns + non_code_translated = translate_subset(non_code_dataset, columns_to_translate, target_language, client, batch_size=batch_size) - for col in columns_to_translate: - # Translate each column in the batch - texts_to_translate = batch[col] - translated_texts = translate_texts(texts_to_translate, client, target_language) - translated_batch[col] = translated_texts - - # Add other columns as-is - other_columns = {key: batch[key] for key in batch.column_names if key not in translated_batch} - - # Combine translated and other columns into a list of examples - for i in range(len(translated_batch[columns_to_translate[0]])): - translated_example = {col: translated_batch[col][i] for col in columns_to_translate} - translated_example["target_language"] = target_language - for key in other_columns: - translated_example[key] = other_columns[key][i] - translated_data.append(translated_example) + # Translate "hep-" subset: Translate only the 'prompt' column + code_translated = translate_subset(code_dataset, columns_to_translate, target_language, client, translate_prompt_only=True, batch_size=batch_size) - # Save translated data to JSON file - dataset_name = args.dataset_name.replace("/", "_") - output_file = os.path.join(output_dir, f"{dataset_name}_{subset}_{args.target_language}_translated.json") - with open(output_file, "w", encoding="utf-8") as f: - json.dump(translated_data, f, ensure_ascii=False, indent=4) + # Combine the translated data + combined_translated = code_translated + non_code_translated - print(f"Translated data for subset '{subset}' saved to {output_file}") + # Save the translated data + dataset_name = args.dataset_name.replace("/", "_") + output_file = os.path.join(output_dir, f"{dataset_name}_{args.target_language}_translated.json") + with open(output_file, "w", encoding="utf-8") as f: + json.dump(combined_translated, f, ensure_ascii=False, indent=4) + print(f"Translated data saved to {output_file}") if __name__ == "__main__": # fmt: off @@ -88,7 +124,7 @@ def translate_dataset( parser.add_argument("--columns_to_translate", type=str, nargs="+", required=True, help="Columns to translate.") parser.add_argument("--subset_size", type=int, help="Size of the random subset to translate.") parser.add_argument("--output_dir", type=str, default="translations", help="Output directory to save translations.") - parser.add_argument("--batch_size", type=int, default=10, help="Number of texts to translate in each batch.") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for Google Translate API.") # fmt: on args = parser.parse_args() @@ -103,5 +139,5 @@ def translate_dataset( args.target_language, args.subset_size, args.output_dir, - args.batch_size + args.batch_size, )