File size: 2,684 Bytes
713ed4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import gradio as gr
from tokenizers import Tokenizer

from transformer.config import load_config
from transformer.components.decoding import beam_search
from transformer.transformer import Transformer

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the artifacts loaded below (relative to the working directory).
CONFIG_PATH = "configs/config.yaml"
MODEL_PATH = "model_checkpoint.pt"
TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"
# Forwarded to beam_search as its ``max_len`` argument.
MAX_LEN = 128

config = load_config(CONFIG_PATH)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

# token_to_id returns None for tokens missing from the vocabulary; fail fast
# with a clear message instead of building a broken source mask later on.
padding_idx = tokenizer.token_to_id("[PAD]")
if padding_idx is None:
    raise ValueError(f"Tokenizer at {TOKENIZER_PATH!r} does not define a '[PAD]' token")

model = Transformer.load_from_checkpoint(checkpoint_path=MODEL_PATH, config=config, device=DEVICE)
# Switch to evaluation mode so train-only layers such as dropout are disabled
# while serving translations. eval() is idempotent, so this is harmless if
# load_from_checkpoint already does it.
# NOTE(review): assumes Transformer is a torch.nn.Module - confirm.
model.eval()

def translate(text: str, beam_size: int = 4) -> str:
    """Translate ``text`` with beam search and return the decoded string.

    Args:
        text: Source sentence to translate. ``None``, empty, or
            whitespace-only input short-circuits to an empty string without
            invoking the model.
        beam_size: Beam width forwarded to ``beam_search``. Coerced to ``int``
            because UI sliders may deliver the value as a float.

    Returns:
        The first element of the value returned by ``beam_search``, decoded
        by the tokenizer with special tokens skipped.
    """
    # Nothing to translate: avoid tokenizing and decoding blank input.
    if not text or not text.strip():
        return ""

    # Batch of one sequence; dtype is pinned so the ids are always integers.
    src_ids = torch.tensor(
        [tokenizer.encode(text).ids], dtype=torch.long, device=DEVICE
    )
    # True at every non-padding position, with two singleton dimensions
    # inserted after the batch dimension.
    src_mask = (src_ids != padding_idx).unsqueeze(1).unsqueeze(2)

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        result_ids = beam_search(
            model,
            src_ids,
            src_mask,
            tokenizer,
            max_len=MAX_LEN,
            beam_size=int(beam_size),
        )[0]
    return tokenizer.decode(result_ids, skip_special_tokens=True)

# Static UI content, kept apart from the layout code below.
_APP_TITLE = "Transformer From Scratch Translation Demo"
_INTRO_MARKDOWN = (
    f"# {_APP_TITLE}\n"
    "Translate English to German using a custom Transformer model trained from scratch."
)
_EXAMPLE_INPUTS = [
    ["Hello, how are you?"],
    ["The weather is nice today."],
    ["I love machine learning."],
]

with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # Top row: source text and beam-size control on the left, read-only
    # translation output on the right.
    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="English Text",
                placeholder="Enter text to translate...",
                lines=3,
            )
            beam_size = gr.Slider(
                minimum=1,
                maximum=8,
                step=1,
                value=4,
                label="Beam Size",
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="German Translation",
                lines=3,
                interactive=False,
                show_copy_button=True,
                show_label=True,
            )

    # Bottom row: the controls column (scale 2) sits between two empty
    # columns (scale 1 each).
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
            translate_btn = gr.Button("Translate")
            gr.Examples(examples=_EXAMPLE_INPUTS, inputs=[input_text])
        with gr.Column(scale=1):
            pass

    # Clicking the button runs translate(input_text, beam_size) and writes
    # the result into the output textbox.
    translate_btn.click(
        fn=translate,
        inputs=[input_text, beam_size],
        outputs=[output_text],
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()