File size: 2,131 Bytes
e760098
6e99950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e760098
6e99950
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e760098
 
6e99950
e760098
 
6e99950
 
e760098
6e99950
e760098
6e99950
 
e760098
 
 
6e99950
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import torch
import torch.nn as nn
import sentencepiece as spm

# Set device
# Prefer GPU when available; the model and all generation tensors live here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizers
# Two SentencePiece models: one for the pseudocode (source) side and one for
# the generated C++ (target) side. Files are expected next to this script.
sp_pseudo = spm.SentencePieceProcessor(model_file="pseudo.model")
sp_code = spm.SentencePieceProcessor(model_file="code.model")

# Load the full saved model (architecture + weights)
# NOTE(review): torch.load on a fully pickled model executes arbitrary code
# from the checkpoint file — only load checkpoints from a trusted source.
model_path = "pseudo-to-cpp-model.pth"  # Adjust path as needed
model = torch.load(model_path, map_location=device)
model.eval()
model = model.to(device)


def generate_code(pseudocode, max_len, bos_id=2, eos_id=5):
    """Greedily generate C++ code from pseudocode, streaming partial output.

    Args:
        pseudocode: Source pseudocode string, encoded with the pseudocode
            SentencePiece tokenizer.
        max_len: Maximum number of tokens to generate (coerced to int so a
            float slider value is accepted).
        bos_id: Begin-of-sequence id used to seed the decoder (default 2,
            matching the trained code tokenizer).
        eos_id: End-of-sequence id that terminates generation (default 5).

    Yields:
        The decoded C++ string so far after each generated token, and the
        final string once generation stops.
    """
    # model.eval() already ran at load time; no need to repeat it per call.
    src = torch.tensor([sp_pseudo.encode_as_ids(pseudocode)],
                       dtype=torch.long, device=device)
    tgt = torch.tensor([[bos_id]], dtype=torch.long, device=device)

    generated_tokens = [bos_id]
    response = ""
    with torch.no_grad():
        for _ in range(int(max_len)):
            output = model(src, tgt)
            # Greedy decoding: pick the highest-scoring next token.
            next_token = output[:, -1, :].argmax(-1).item()
            # Stop BEFORE appending EOS so it never leaks into the decoded
            # output (the original appended it and decoded it first).
            if next_token == eos_id:
                break
            generated_tokens.append(next_token)
            tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
            response = sp_code.decode_ids(generated_tokens)
            yield response  # stream partial output to the UI
    yield response  # final output (empty string if EOS came first)

def respond(message, history, max_tokens):
    """Gradio ChatInterface callback: stream generated C++ for *message*.

    The chat *history* is intentionally ignored — each request is a one-shot
    pseudocode-to-code translation with a *max_tokens* generation budget.
    """
    yield from generate_code(message, max_tokens)

# Gradio interface
# Chat-style UI: the textbox takes pseudocode, the chatbot pane shows the
# streamed C++ output. The slider in additional_inputs is forwarded to
# respond() as its third argument (max_tokens).
demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(label="Pseudocode to C++ Generator"),
    textbox=gr.Textbox(placeholder="Enter pseudocode (e.g., 'for i from 1 to n, print i')", label="Pseudocode"),
    additional_inputs=[
        gr.Slider(minimum=10, maximum=1000, value=50, step=1, label="Max tokens"),
    ],
    title="Pseudocode to C++ Transformer",
    description="Convert pseudocode to C++ code using a custom transformer trained on the SPoC dataset.",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()