import gradio as gr
import torch
import yaml
from transformers import AutoTokenizer

from model import SmollM

device = "cuda" if torch.cuda.is_available() else "cpu"

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Speed up float32 matmuls by allowing reduced-precision internal kernels
# where the hardware supports them (see torch.set_float32_matmul_precision).
torch.set_float32_matmul_precision('high')

# Build the model architecture from the config.
model = SmollM(config['model']['model_config'])

# Load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(config['tokenizer']['tokenizer_name_or_path'])

# Load your custom model weights (adjust as necessary for your model's implementation).
model_path = "model.pt"  # Replace with the path to your model weights
# The checkpoint is mapped to CPU so it loads on machines without a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
# NOTE(review): assumes model.pt holds a bare state_dict -- TODO confirm.
model.load_state_dict(checkpoint)
# Move the parameters to the same device that generate_tokens puts the input
# ids on; without this, generation fails with a device mismatch on CUDA hosts.
model.to(device)
model.eval()  # Set the model to evaluation mode


def generate_tokens(model, tokenizer, prompt, max_length=50, device="cuda"):
    """Greedily generate a continuation of ``prompt`` and return the decoded text.

    Args:
        model: Causal LM called as ``model(ids)``. Its output is indexed as
            ``logits[:, -1, :]`` to score the next token.
        tokenizer: Tokenizer used to encode the prompt and to decode the result.
        prompt: Input text to continue.
        max_length: Maximum number of new tokens to generate (coerced to ``int``,
            so a float slider value such as ``50.0`` is accepted).
        device: Device the input ids are placed on. It must match the device
            holding the model's parameters.

    Returns:
        The prompt followed by the generated tokens, decoded with special tokens
        removed. Generation stops early when the EOS token is produced. If the
        prompt encodes to zero tokens, ``prompt`` is returned unchanged.
    """
    model.eval()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    if input_ids.shape[1] == 0:
        # Nothing to condition on (e.g. an empty prompt); avoid calling the
        # model on a zero-length sequence.
        return prompt

    outputs = input_ids
    with torch.no_grad():
        for _ in range(int(max_length)):
            # Feed the full sequence each step. The model is called without any
            # KV cache, so passing only the last token would drop the prompt and
            # all previously generated context.
            # NOTE(review): assumes SmollM keeps no internal KV cache and that
            # prompt + max_length fits its context window -- TODO confirm.
            logits = model(outputs)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_token], dim=1)
            # .item() requires a single sequence (batch size 1), as built above.
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Function to tokenize input and generate text
def generate_text(prompt, max_length=50):
    """Gradio callback: generate text for ``prompt`` with the module-level model and tokenizer."""
    return generate_tokens(model, tokenizer, prompt, max_length, device)


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# SmoLLM-135M Text Generation Demo")
    gr.Markdown("Provide an input text prompt, and the model will generate text based on it.")
    with gr.Row():
        input_text = gr.Textbox(label="Input Prompt", placeholder="Enter your text here...", lines=2)
        max_len = gr.Slider(label="Max Output Length", minimum=10, maximum=100, value=50, step=5)
    output_text = gr.Textbox(label="Generated Text", lines=5)
    generate_button = gr.Button("Generate")
    # Event listeners must be registered inside the Blocks context.
    generate_button.click(generate_text, inputs=[input_text, max_len], outputs=output_text)

# Run the app
if __name__ == "__main__":
    demo.launch()