Merge pull request #33 from for-ai/gtranslate_updated
Updated gtranslate with batch save and load
RishabhMaheshwary authored Sep 9, 2024
2 parents 7143426 + 417eecb commit 122131c
Showing 1 changed file with 49 additions and 28 deletions.
77 changes: 49 additions & 28 deletions scripts/translate_preference_pairs_gtranslate.py
@@ -12,7 +12,7 @@ def translate_text_batch(texts, client, target_lang_code):
     Translates a batch of texts using Google Translate API.
     """
     result = client.translate(texts, target_language=target_lang_code, format_="text")
-    # return [{"translatedText": "...."} for res in texts]
+    # result = [ {"translatedText": "test"} for txt in texts]
     return [res['translatedText'] for res in result]
 
 def validate_columns(dataset, columns):
@@ -39,48 +39,81 @@ def create_batches(dataset, batch_size):
 
     return batches
 
-def translate_subset(dataset, columns_to_translate, target_language, client, translate_prompt_only=False, batch_size=32):
+def load_processed_ids(output_file):
+    """
+    Load the set of processed example ids from the output file.
+    """
+    if os.path.exists(output_file):
+        with open(output_file, "r", encoding="utf-8") as f:
+            data = json.load(f)
+            return {example["id"] for example in data}
+    return set()
+
+def save_translated_batch(translated_batch, output_file):
+    """
+    Save the translated batch to the output file.
+    """
+    if os.path.exists(output_file):
+        with open(output_file, "r", encoding="utf-8") as f:
+            existing_data = json.load(f)
+    else:
+        existing_data = []
+
+    existing_data.extend(translated_batch)
+
+    with open(output_file, "w", encoding="utf-8") as f:
+        json.dump(existing_data, f, ensure_ascii=False, indent=4)
+
+def translate_subset(dataset, columns_to_translate, target_language, client, translate_prompt_only=False, batch_size=32, output_file=None):
     translated_data = []
 
+    # Load previously processed ids
+    processed_ids = load_processed_ids(output_file)
+
     # Create batches from the dataset
     batches = create_batches(dataset, batch_size)
 
     for batch in tqdm(batches, desc=f"Translating subset"):
         translated_batch = []
 
+        # Skip batches with examples already processed
+        batch_to_process = [example for example in batch if example["id"] not in processed_ids]
+        if not batch_to_process:
+            print("Skipping current batch as it is already translated")
+            continue
+
         if translate_prompt_only:
             # Collect all prompts for translation
-            prompts = [example['prompt'] for example in batch]
+            prompts = [example['prompt'] for example in batch_to_process]
             translated_prompts = translate_text_batch(prompts, client, target_language)
 
-            for i, example in enumerate(batch):
+            for i, example in enumerate(batch_to_process):
                 translated_example = {'prompt': translated_prompts[i]}
                 # Copy other columns unchanged
                 for key in example.keys():
                     if key != 'prompt':
                         translated_example[key] = example[key]
                 translated_example["target_language"] = target_language
                 translated_batch.append(translated_example)
         else:
             # Collect all texts for each column to be translated
             for col in columns_to_translate:
-                texts_to_translate = [example[col] for example in batch]
+                texts_to_translate = [example[col] for example in batch_to_process]
                 translated_texts = translate_text_batch(texts_to_translate, client, target_language)
 
-                for i, example in enumerate(batch):
+                for i, example in enumerate(batch_to_process):
                     if i >= len(translated_batch):
                         translated_batch.append({})
                     translated_batch[i][col] = translated_texts[i]
 
             # Copy other columns as-is
-            for i, example in enumerate(batch):
+            for i, example in enumerate(batch_to_process):
                 for key in example.keys():
                     if key not in translated_batch[i]:
                         translated_batch[i][key] = example[key]
                 translated_batch[i]["target_language"] = target_language
 
         translated_data.extend(translated_batch)
 
+        # Save the translated batch
+        save_translated_batch(translated_batch, output_file)
+
     return translated_data
 
 def translate_dataset(dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=32):
@@ -99,40 +99,28 @@ def translate_dataset(dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=32):
     code_dataset = dataset.filter(lambda x: x['subset'].startswith('hep-'))
     non_code_dataset = dataset.filter(lambda x: not x['subset'].startswith('hep-'))
 
-    # Translate non-"hep-" subset: Translate all columns
-    non_code_translated = translate_subset(non_code_dataset, columns_to_translate, target_language, client, batch_size=batch_size)
-
-    # Translate "hep-" subset: Translate only the 'prompt' column
-    code_translated = translate_subset(code_dataset, columns_to_translate, target_language, client, translate_prompt_only=True, batch_size=batch_size)
-
-    # Combine the translated data
-    combined_translated = code_translated + non_code_translated
-
     # Save the translated data
     dataset_name = args.dataset_name.replace("/", "_")
     output_file = os.path.join(output_dir, f"{dataset_name}_{args.target_language}_translated.json")
-    with open(output_file, "w", encoding="utf-8") as f:
-        json.dump(combined_translated, f, ensure_ascii=False, indent=4)
-
-    print(f"Translated data saved to {output_file}")
+    # Translate non-"hep-" subset: Translate all columns
+    non_code_translated = translate_subset(non_code_dataset, columns_to_translate, target_language, client, batch_size=batch_size, output_file=output_file)
+
+    # Translate "hep-" subset: Translate only the 'prompt' column
+    code_translated = translate_subset(code_dataset, columns_to_translate, target_language, client, translate_prompt_only=True, batch_size=batch_size, output_file=output_file)

 if __name__ == "__main__":
     # fmt: off
     parser = argparse.ArgumentParser(description="Translate dataset columns using Google Cloud Translate API.")
     parser.add_argument("--dataset_name", type=str, required=True, help="Hugging Face dataset name.")
     parser.add_argument("--target_language", type=str, required=True, help="Target language code (e.g., fr).")
     parser.add_argument("--columns_to_translate", type=str, nargs="+", required=True, help="Columns to translate.")
     parser.add_argument("--subset_size", type=int, help="Size of the random subset to translate.")
     parser.add_argument("--output_dir", type=str, default="translations", help="Output directory to save translations.")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size for Google Translate API.")
     # fmt: on
 
     args = parser.parse_args()
 
     # Load dataset
     dataset = load_dataset(args.dataset_name)
 
     # Translate dataset
     translate_dataset(
         dataset,
         args.columns_to_translate,
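(The remaining lines of the translate_dataset call are unchanged context, collapsed in the view above.)

For readers skimming the change: the two new helpers turn the output JSON file into a lightweight checkpoint, so an interrupted run can be restarted without re-translating (and re-paying for) examples that were already saved. Below is a minimal, self-contained sketch of that round trip; the helper bodies are condensed from the diff above, and the ids and texts are made up:

import json
import os
import tempfile

def load_processed_ids(output_file):
    # Condensed from the commit: the set of ids already in the output file.
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            return {example["id"] for example in json.load(f)}
    return set()

def save_translated_batch(translated_batch, output_file):
    # Condensed from the commit: read-modify-write of the full JSON array.
    existing_data = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            existing_data = json.load(f)
    existing_data.extend(translated_batch)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(existing_data, f, ensure_ascii=False, indent=4)

output_file = os.path.join(tempfile.mkdtemp(), "translations.json")

# A first run saves one batch, then is interrupted.
save_translated_batch([{"id": 0, "prompt": "bonjour"}], output_file)

# A restarted run reloads the processed ids and skips finished examples,
# mirroring the batch_to_process filter in translate_subset.
processed_ids = load_processed_ids(output_file)
batch = [{"id": 0, "prompt": "hello"}, {"id": 1, "prompt": "world"}]
batch_to_process = [ex for ex in batch if ex["id"] not in processed_ids]
assert [ex["id"] for ex in batch_to_process] == [1]

One trade-off worth noting: save_translated_batch re-reads and rewrites the entire JSON array on every batch, so checkpointing cost grows with the output size. An append-only format such as JSON Lines would avoid that, at the cost of the output no longer being a single JSON document.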

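The client that translate_text_batch receives is created outside the hunks shown here. With the google-cloud-translate v2 API, which matches the client.translate(..., target_language=..., format_="text") call in the diff, it would typically be built as follows; treat this as an assumption about the unshown part of the script rather than a quote from it:

from google.cloud import translate_v2 as translate

# Assumes GOOGLE_APPLICATION_CREDENTIALS points at a service-account key
# for a project with the Cloud Translation API enabled.
client = translate.Client()

# A list in, a list of dicts out; each dict carries "translatedText".
result = client.translate(["hello", "world"], target_language="fr", format_="text")
print([res["translatedText"] for res in result])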
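Finally, assuming the argparse interface shown above, a run might be invoked like this; the dataset name and column names are illustrative, not taken from this commit. Re-running the identical command after an interruption resumes from the existing output file, since already-saved ids are skipped:

python scripts/translate_preference_pairs_gtranslate.py \
    --dataset_name my-org/preference-pairs \
    --target_language fr \
    --columns_to_translate prompt chosen rejected \
    --output_dir translations \
    --batch_size 32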