from datasets import load_dataset
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
# Load the English-Hindi parallel dataset from a tab-separated file
dataset = load_dataset('csv', data_files='hindi_dataset.tsv', delimiter='\t')

# load_dataset with a single file yields only a 'train' split, but the trainer
# below needs a 'test' split for evaluation, so hold out 10% of the data
dataset = dataset['train'].train_test_split(test_size=0.1)
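# The TSV file is assumed to have a header row with the two columns referenced
# below, 'english' and 'hindi', one sentence pair per line, for example:
#   english<TAB>hindi
#   How are you?<TAB>आप कैसे हैं?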
# Load the MarianMT tokenizer for the English-to-Hindi translation task
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-hi')

# Tokenize the English source text and the Hindi target text in one pass.
# The Hindi token ids go into 'labels' so they do not overwrite the source
# inputs and the model receives proper decoder targets. Padding is left to the
# data collator, which also masks label padding with -100 so padded positions
# do not contribute to the loss.
def tokenize_function(examples):
    model_inputs = tokenizer(examples['english'], truncation=True, max_length=128)
    labels = tokenizer(text_target=examples['hindi'], truncation=True, max_length=128)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Load the MarianMT model for English-to-Hindi translation
model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-hi')

# Data collator dynamically pads inputs and labels; passing the model lets it
# also prepare decoder_input_ids from the labels
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# Define training arguments; predict_with_generate requires the
# sequence-to-sequence variants of the training arguments and trainer
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_total_limit=2,
    predict_with_generate=True,
)
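# Optional: a BLEU metric for the evaluation loop. This is a minimal sketch, not
# part of the original script; it assumes the 'evaluate' and 'sacrebleu' packages
# are installed. To use it, pass compute_metrics=compute_metrics to the trainer
# below; with predict_with_generate=True the predictions are generated token ids
# that can be decoded here.
import evaluate
import numpy as np

bleu = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # Replace -100 (ignored label positions) with the pad token id before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    score = bleu.compute(predictions=decoded_preds,
                         references=[[ref] for ref in decoded_labels])
    return {"bleu": score["score"]}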
# Initialize the sequence-to-sequence Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
)
# Start training
trainer.train()

# Save the model
trainer.save_model('./my_hindi_translation_model')

# Evaluate the model
results = trainer.evaluate()
print(results)
# Generate a prediction; move the inputs to the model's device so generation
# also works after training on a GPU
model.eval()
inputs = tokenizer("How are you?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
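# Optional: reload the fine-tuned model later from the saved directory. A minimal
# sketch; it assumes trainer.save_model above also wrote the tokenizer files (it
# typically does when a tokenizer is passed to the trainer), so both can be
# restored from './my_hindi_translation_model'.
reloaded_tokenizer = MarianTokenizer.from_pretrained('./my_hindi_translation_model')
reloaded_model = MarianMTModel.from_pretrained('./my_hindi_translation_model')
batch = reloaded_tokenizer("How are you?", return_tensors="pt")
generated = reloaded_model.generate(**batch, max_length=128)
print(reloaded_tokenizer.decode(generated[0], skip_special_tokens=True))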