import gradio as gr
import torch
import torch.nn as nn
import sentencepiece as spm
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load tokenizers
sp_pseudo = spm.SentencePieceProcessor(model_file="pseudo.model")
sp_code = spm.SentencePieceProcessor(model_file="code.model")
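# The two SentencePiece models define the source (pseudocode) and target (C++)
# vocabularies. Illustrative round trip (SentencePiece detokenization is
# designed to be lossless):
#   ids = sp_pseudo.encode_as_ids("for i from 1 to n, print i")
#   sp_pseudo.decode_ids(ids)  # -> "for i from 1 to n, print i"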
# Load the full saved model (architecture + weights)
model_path = "pseudo-to-cpp-model.pth"  # Adjust path as needed
# weights_only=False is required on PyTorch >= 2.6 (where the default changed)
# to unpickle a full nn.Module; the kwarg itself needs torch >= 1.13.
model = torch.load(model_path, map_location=device, weights_only=False)
model = model.to(device)
model.eval()
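# Assumption: the checkpoint was produced with torch.save(model, path), i.e. a
# pickled nn.Module. Had only a state_dict been saved, loading would instead
# require building the architecture first, roughly:
#   model = PseudoToCppTransformer(**hparams)  # hypothetical class and hparams
#   model.load_state_dict(torch.load(model_path, map_location=device))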

def generate_code(pseudocode, max_len):
    """Generate C++ code from pseudocode with streaming output."""
    src = torch.tensor([sp_pseudo.encode_as_ids(pseudocode)], dtype=torch.long, device=device)
    tgt = torch.tensor([[2]], dtype=torch.long, device=device)  # <bos_id> = 2
    generated_tokens = [2]
    response = ""
    with torch.no_grad():
        for _ in range(max_len):
            output = model(src, tgt)
            # Greedy decoding: take the highest-scoring token at the last position
            next_token = output[:, -1, :].argmax(-1).item()
            generated_tokens.append(next_token)
            tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
            response = sp_code.decode_ids(generated_tokens)
            yield response  # Yield partial output
            if next_token == 5:  # <END> = 5
                break
    yield response  # Final output
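
# Optional non-streaming wrapper (sketch): drain the generator when only the
# final translation is needed, e.g. for scripted evaluation outside the UI.
def generate_code_once(pseudocode, max_len=100):
    """Return only the final C++ string."""
    result = ""
    for result in generate_code(pseudocode, max_len):
        pass
    return result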

def respond(message, history, max_tokens):
    """Wrapper for the Gradio interface."""
    # Ignore history since it's one-shot generation
    for response in generate_code(message, max_tokens):
        yield response

# Gradio interface
demo = gr.ChatInterface(
    respond,
    chatbot=gr.Chatbot(label="Pseudocode to C++ Generator"),
    textbox=gr.Textbox(placeholder="Enter pseudocode (e.g., 'for i from 1 to n, print i')", label="Pseudocode"),
    additional_inputs=[
        gr.Slider(minimum=10, maximum=1000, value=50, step=1, label="Max tokens"),
    ],
    title="Pseudocode to C++ Transformer",
    description="Convert pseudocode to C++ code using a custom transformer trained on the SPoC dataset.",
)
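
# Because respond is a generator, ChatInterface streams each yielded string as a
# progressively updated response in the chat window.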

if __name__ == "__main__":
    demo.launch()