# Spaces: Running
import torch | |
import gradio as gr | |
from tokenizers import Tokenizer | |
from transformer.config import load_config | |
from transformer.components.decoding import beam_search | |
from transformer.transformer import Transformer | |
# ---------------------------------------------------------------------------
# Runtime configuration and one-time resource loading.
# Everything below runs at import time so the UI callback can reuse the
# already-loaded tokenizer and model on every request.
# ---------------------------------------------------------------------------

# Use CUDA when PyTorch reports it as available; otherwise fall back to CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Locations of the config, checkpoint and tokenizer files (relative paths).
CONFIG_PATH = "configs/config.yaml"
MODEL_PATH = "model_checkpoint.pt"
TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"

# Forwarded to beam_search as ``max_len`` inside translate().
MAX_LEN = 128

config = load_config(CONFIG_PATH)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Id of the "[PAD]" token; translate() compares source ids against it to
# build the source mask.
padding_idx = tokenizer.token_to_id("[PAD]")

model = Transformer.load_from_checkpoint(
    checkpoint_path=MODEL_PATH,
    config=config,
    device=DEVICE,
)
def translate(text: str, beam_size: int = 4) -> str:
    """Translate ``text`` with the loaded model using beam search.

    Args:
        text: Source text to translate. Empty or whitespace-only input
            short-circuits to ``""`` without invoking the tokenizer or model.
        beam_size: Beam width forwarded to ``beam_search``. Coerced to
            ``int`` because UI widgets may deliver the value as a float.

    Returns:
        The decoded text of the first sequence returned by ``beam_search``
        (presumably the top-scoring hypothesis), with special tokens
        skipped, or ``""`` when there is nothing to translate.
    """
    # Nothing to translate: skip tokenization and decoding entirely instead
    # of running beam search on an empty / near-empty source sequence.
    if not text or not text.strip():
        return ""

    # Wrap the ids in a list so the tensor has a leading batch dim of 1.
    src_ids = torch.tensor([tokenizer.encode(text).ids], device=DEVICE)

    # True where the token is not padding; the two unsqueeze calls insert
    # singleton dims at positions 1 and 2, giving shape (1, 1, 1, src_len).
    src_mask = (src_ids != padding_idx).unsqueeze(1).unsqueeze(2)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        result_ids = beam_search(
            model,
            src_ids,
            src_mask,
            tokenizer,
            max_len=MAX_LEN,
            beam_size=int(beam_size),
        )[0]

    return tokenizer.decode(result_ids, skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Gradio UI: an input/output row, a row holding the button and examples,
# and the click handler that calls translate().
# ---------------------------------------------------------------------------
with gr.Blocks(title="Transformer From Scratch Translation Demo") as demo:
    # Page header and usage note.
    gr.Markdown(
        value=(
            "# Transformer From Scratch Translation Demo\n"
            "Translate English to German using a custom Transformer model trained from scratch.\n\n"
            "**Note:** This model was trained on the WMT14 English-German news dataset. It works best on formal, news-style sentences and may not perform well on everyday informal or conversational text."
        )
    )

    # Top row: source text + beam-size control on the left, result on the right.
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="English Text",
                placeholder="Enter text to translate...",
                lines=3,
            )
            beam_size = gr.Slider(
                minimum=1,
                maximum=8,
                step=1,
                value=4,
                label="Beam Size",
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="German Translation",
                lines=3,
                interactive=False,
                show_copy_button=True,
                show_label=True,
            )

    # Second row: empty side columns (scale 1 : 2 : 1) around the controls,
    # presumably to center them (see elem_id "centered-controls").
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
            translate_btn = gr.Button(value="Translate")
            gr.Examples(
                examples=[
                    ["Hello, how are you?"],
                    ["The weather is nice today."],
                    ["I love machine learning."],
                ],
                inputs=[input_text],
            )
        with gr.Column(scale=1):
            pass

    # Wire the button: (text, beam size) -> translate() -> output textbox.
    # Registered inside the Blocks context alongside the components.
    translate_btn.click(
        fn=translate,
        inputs=[input_text, beam_size],
        outputs=[output_text],
    )

# Start the Gradio server only when executed as a script.
if __name__ == "__main__":
    demo.launch()