Updated gtranslate to skip code and non-code text #32

Merged · 1 commit · Sep 6, 2024
scripts/translate_preference_pairs_gtranslate.py (142 changes: 89 additions & 53 deletions)
@@ -1,84 +1,120 @@
 import argparse
 import json
 import os
 import random
 from google.cloud import translate_v2 as translate
 from datasets import load_dataset
 from tqdm import tqdm
 
 # Steps to setup:
 # 1. https://cloud.google.com/python/docs/setup#linux
 # 2. https://cloud.google.com/sdk/docs/install
 
-os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "<path_to_credentials_json>"
-
-def translate_texts(texts, client, target_lang_code):
-    results = client.translate(texts, target_language=target_lang_code)
-    return [result['translatedText'] for result in results]
+os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "<path_to_credentials.json>"
+
+def translate_text_batch(texts, client, target_lang_code):
+    """
+    Translates a batch of texts using the Google Translate API.
+    """
+    result = client.translate(texts, target_language=target_lang_code, format_="text")
+    return [res['translatedText'] for res in result]
 
 def validate_columns(dataset, columns):
     for subset in dataset.keys():
         for column in columns:
             if column not in dataset[subset].column_names:
                 raise ValueError(f"Column '{column}' not found in subset '{subset}' of the dataset")
 
+def create_batches(dataset, batch_size):
+    """
+    Create batches of examples from the dataset.
+    """
+    batches = []
+    current_batch = []
+
+    for example in dataset:
+        current_batch.append(example)
+        if len(current_batch) == batch_size:
+            batches.append(current_batch)
+            current_batch = []
+
+    if current_batch:
+        batches.append(current_batch)
+
+    return batches
+
+def translate_subset(dataset, columns_to_translate, target_language, client, translate_prompt_only=False, batch_size=32):
+    translated_data = []
+
+    # Create batches from the dataset
+    batches = create_batches(dataset, batch_size)
+
+    for batch in tqdm(batches, desc="Translating subset"):
+        translated_batch = []
+
+        if translate_prompt_only:
+            # Collect all prompts for translation
+            prompts = [example['prompt'] for example in batch]
+            translated_prompts = translate_text_batch(prompts, client, target_language)
+
+            for i, example in enumerate(batch):
+                translated_example = {'prompt': translated_prompts[i]}
+                # Copy other columns unchanged
+                for key in example.keys():
+                    if key != 'prompt':
+                        translated_example[key] = example[key]
+                translated_example["target_language"] = target_language
+                translated_batch.append(translated_example)
+        else:
+            # Collect all texts for each column to be translated
+            for col in columns_to_translate:
+                texts_to_translate = [example[col] for example in batch]
+                translated_texts = translate_text_batch(texts_to_translate, client, target_language)
+
+                for i, example in enumerate(batch):
+                    if i >= len(translated_batch):
+                        translated_batch.append({})
+                    translated_batch[i][col] = translated_texts[i]
+
+            # Copy other columns as-is
+            for i, example in enumerate(batch):
+                for key in example.keys():
+                    if key not in translated_batch[i]:
+                        translated_batch[i][key] = example[key]
+                translated_batch[i]["target_language"] = target_language
+
+        translated_data.extend(translated_batch)
+
+    return translated_data
+
-def translate_dataset(
-    dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=10
-):
+def translate_dataset(dataset, columns_to_translate, target_language, subset_size=None, output_dir="translations", batch_size=32):
     # Initialize the Google Cloud Translate client
     client = translate.Client()
 
     # Validate columns
     validate_columns(dataset, columns_to_translate)
 
+    dataset = dataset["filtered"]
+
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)
 
-    for subset in dataset.keys():
-        if subset == "raw":
-            continue
-        translated_data = []
-        data_length = len(dataset[subset])
-
-        # Randomly select a subset of the data if subset_size is specified
-        if subset_size:
-            indices = random.sample(range(data_length), min(subset_size, data_length))
-            dataset[subset] = dataset[subset].select(indices)
-
-        for start_idx in tqdm(range(0, data_length, batch_size), desc=f"Translating {subset} subset"):
-            end_idx = min(start_idx + batch_size, data_length)
-            batch = dataset[subset].select(range(start_idx, end_idx))
-
-            # Initialize a dictionary to hold the translated batch
-            translated_batch = {col: [] for col in columns_to_translate}
-
-            for col in columns_to_translate:
-                # Translate each column in the batch
-                texts_to_translate = batch[col]
-                translated_texts = translate_texts(texts_to_translate, client, target_language)
-                translated_batch[col] = translated_texts
-
-            # Add other columns as-is
-            other_columns = {key: batch[key] for key in batch.column_names if key not in translated_batch}
-
-            # Combine translated and other columns into a list of examples
-            for i in range(len(translated_batch[columns_to_translate[0]])):
-                translated_example = {col: translated_batch[col][i] for col in columns_to_translate}
-                translated_example["target_language"] = target_language
-                for key in other_columns:
-                    translated_example[key] = other_columns[key][i]
-                translated_data.append(translated_example)
-
-        # Save translated data to JSON file
-        dataset_name = args.dataset_name.replace("/", "_")
-        output_file = os.path.join(output_dir, f"{dataset_name}_{subset}_{args.target_language}_translated.json")
-        with open(output_file, "w", encoding="utf-8") as f:
-            json.dump(translated_data, f, ensure_ascii=False, indent=4)
-
-        print(f"Translated data for subset '{subset}' saved to {output_file}")
+    # Filter for the "hep-" (code) examples and the non-"hep-" examples
+    code_dataset = dataset.filter(lambda x: x['subset'].startswith('hep-'))
+    non_code_dataset = dataset.filter(lambda x: not x['subset'].startswith('hep-'))
+
+    # Non-"hep-" examples: translate all requested columns
+    non_code_translated = translate_subset(non_code_dataset, columns_to_translate, target_language, client, batch_size=batch_size)
+
+    # "hep-" examples: translate only the 'prompt' column
+    code_translated = translate_subset(code_dataset, columns_to_translate, target_language, client, translate_prompt_only=True, batch_size=batch_size)
+
+    # Combine the translated data
+    combined_translated = code_translated + non_code_translated
+
+    # Save the translated data
+    dataset_name = args.dataset_name.replace("/", "_")
+    output_file = os.path.join(output_dir, f"{dataset_name}_{args.target_language}_translated.json")
+    with open(output_file, "w", encoding="utf-8") as f:
+        json.dump(combined_translated, f, ensure_ascii=False, indent=4)
+
+    print(f"Translated data saved to {output_file}")
 
 if __name__ == "__main__":
     # fmt: off
@@ -88,7 +124,7 @@ def translate_dataset(
parser.add_argument("--columns_to_translate", type=str, nargs="+", required=True, help="Columns to translate.")
parser.add_argument("--subset_size", type=int, help="Size of the random subset to translate.")
parser.add_argument("--output_dir", type=str, default="translations", help="Output directory to save translations.")
parser.add_argument("--batch_size", type=int, default=10, help="Number of texts to translate in each batch.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for Google Translate API.")
# fmt: on

args = parser.parse_args()
@@ -103,5 +139,5 @@ def translate_dataset(
         args.target_language,
         args.subset_size,
         args.output_dir,
-        args.batch_size
+        args.batch_size,
     )
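
A note on the `format_="text"` argument this PR adds to the translate call: the v2 API defaults to treating input as HTML, so plain-text prompts can come back with entities such as `&#39;` in place of apostrophes. A minimal sketch of the difference, assuming valid credentials are already configured:

    from google.cloud import translate_v2 as translate

    client = translate.Client()

    # Default (HTML) mode may escape characters: "it's" -> "it&#39;s".
    html_result = client.translate(["it's a test"], target_language="de")

    # Plain-text mode returns the translation unescaped, which matters when
    # the output is written back into prompts verbatim.
    text_result = client.translate(["it's a test"], target_language="de", format_="text")

    # Each element is a dict with 'translatedText', 'detectedSourceLanguage', 'input'.
    print(html_result[0]["translatedText"])
    print(text_result[0]["translatedText"])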
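
To see what the `translate_prompt_only` split does without calling the API, here is a small sketch with a stub client. It assumes the script is importable (e.g. `scripts/` on `PYTHONPATH`); the toy rows and column names are made up for illustration:

    from translate_preference_pairs_gtranslate import translate_subset

    class StubClient:
        """Stand-in for translate.Client: tags each text instead of translating it."""
        def translate(self, texts, target_language=None, format_=None):
            return [{"translatedText": f"[{target_language}] {t}"} for t in texts]

    rows = [
        {"subset": "hep-python", "prompt": "Write a sort.", "chosen": "def sort(): ..."},
        {"subset": "chat", "prompt": "Say hi.", "chosen": "Hello!"},
    ]
    client = StubClient()
    cols = ["prompt", "chosen"]

    # hep- rows: only 'prompt' is translated; 'chosen' (the code) is untouched.
    print(translate_subset([rows[0]], cols, "de", client, translate_prompt_only=True))

    # Non-code rows: every requested column is translated.
    print(translate_subset([rows[1]], cols, "de", client))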
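
One more usage note: `translate_dataset` reads `args.dataset_name` and `args.target_language` from the module-global `args` when building the output filename, so any caller that bypasses the CLI has to supply that namespace first. A hypothetical direct-call driver (the dataset name and column names are placeholders):

    import argparse
    from datasets import load_dataset

    import translate_preference_pairs_gtranslate as tp

    # The function reads args.dataset_name / args.target_language globally,
    # so set them before calling it directly.
    tp.args = argparse.Namespace(dataset_name="org/preference-pairs", target_language="de")

    # The script expects a DatasetDict with a "filtered" split.
    dataset = load_dataset("org/preference-pairs")

    tp.translate_dataset(
        dataset,
        columns_to_translate=["prompt", "chosen", "rejected"],
        target_language="de",
        output_dir="translations",
        batch_size=32,
    )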