henok3878
add model training dataset details and usage notes
15b9e77
import torch
import gradio as gr
from tokenizers import Tokenizer
from transformer.config import load_config
from transformer.components.decoding import beam_search
from transformer.transformer import Transformer
# Filesystem locations of the artifacts this demo loads at import time.
CONFIG_PATH = "configs/config.yaml"
MODEL_PATH = "model_checkpoint.pt"
TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"

# Upper bound handed to beam search as ``max_len``.
MAX_LEN = 128

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Load order matters only for which error surfaces first when a file is
# missing: config, then tokenizer, then the model checkpoint.
config = load_config(CONFIG_PATH)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# Id of the padding token; used to build the source mask in ``translate``.
padding_idx = tokenizer.token_to_id("[PAD]")

# NOTE(review): confirm that ``load_from_checkpoint`` leaves the model in
# eval mode; nothing in this file switches it explicitly.
model = Transformer.load_from_checkpoint(
    config=config,
    checkpoint_path=MODEL_PATH,
    device=DEVICE,
)
def translate(text: str, beam_size: int = 4) -> str:
    """Translate ``text`` with the loaded model using beam search.

    Args:
        text: Source sentence to translate.
        beam_size: Number of beams used while decoding. The value is coerced
            to ``int`` and clamped to at least 1, since a UI slider may
            deliver it as a float.

    Returns:
        The first hypothesis returned by ``beam_search``, decoded with
        special tokens removed. An empty string is returned when ``text`` is
        empty, ``None`` or only whitespace.
    """
    # Nothing to translate: skip tokenization and decoding entirely instead
    # of feeding an empty source sequence to the model.
    if not text or not text.strip():
        return ""

    beam_size = max(1, int(beam_size))

    # Batch of one: ``src_ids`` has shape (1, src_len).
    src_ids = torch.tensor([tokenizer.encode(text).ids], device=DEVICE)
    # Boolean mask of non-padding positions, reshaped to (1, 1, 1, src_len);
    # presumably so it broadcasts inside attention -- confirm against model.
    src_mask = (src_ids != padding_idx).unsqueeze(1).unsqueeze(2)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        result_ids = beam_search(
            model,
            src_ids,
            src_mask,
            tokenizer,
            max_len=MAX_LEN,
            beam_size=beam_size,
        )[0]

    return tokenizer.decode(result_ids, skip_special_tokens=True)
# Title shared by the browser tab and the page heading.
_DEMO_TITLE = "Transformer From Scratch Translation Demo"

# Intro text rendered at the top of the page (Markdown).
_INTRO_MARKDOWN = (
    f"# {_DEMO_TITLE}\n"
    "Translate English to German using a custom Transformer model trained from scratch.\n\n"
    "**Note:** This model was trained on the WMT14 English-German news dataset. "
    "It works best on formal, news-style sentences and may not perform well on "
    "everyday informal or conversational text."
)

# Sample inputs offered below the button; each example is a one-item row
# because the examples feed a single input component.
_EXAMPLE_SENTENCES = (
    "Hello, how are you?",
    "The weather is nice today.",
    "I love machine learning.",
)

with gr.Blocks(title=_DEMO_TITLE) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Top row: source text and beam-size control on the left, read-only
    # translation output on the right.
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                lines=3,
                label="English Text",
                placeholder="Enter text to translate...",
            )
            beam_size = gr.Slider(
                label="Beam Size",
                value=4,
                minimum=1,
                maximum=8,
                step=1,
            )
        with gr.Column():
            output_text = gr.Textbox(
                lines=3,
                label="German Translation",
                show_label=True,
                show_copy_button=True,
                interactive=False,
            )

    # Bottom row: two empty scale-1 columns flank the scale-2 column that
    # holds the button and the examples.
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
            translate_btn = gr.Button("Translate")
            gr.Examples(
                examples=[[sentence] for sentence in _EXAMPLE_SENTENCES],
                inputs=[input_text],
            )
        with gr.Column(scale=1):
            pass

    # Clicking the button runs ``translate`` on the text and the selected
    # beam size, and writes the result into the output box.
    translate_btn.click(
        fn=translate,
        inputs=[input_text, beam_size],
        outputs=[output_text],
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()