# Hugging Face Spaces status: Runtime error
from collections import Counter

import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
# Load the chat model once and reuse it for every ensemble member.
# The original constructed three pipelines for the *same* checkpoint, which
# tripled memory use and startup time without adding diversity: the weights
# are identical, so the only variation between members comes from sampling
# (do_sample=True in generate_text). Repeating one pipeline object keeps the
# `models` list, and its length, that generate_text iterates over.
NUM_ENSEMBLE_SAMPLES = 3
_generator = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1)
models = [_generator] * NUM_ENSEMBLE_SAMPLES
tokenizer = _generator.tokenizer
# Function for generating text using the ensemble of models
def generate_text(prompt):
    """Generate a pirate-style reply by majority-voting several sampled replies.

    The user's prompt is wrapped in the tokenizer's chat template together
    with a fixed pirate-persona system message. Each pipeline in ``models``
    then samples one reply. The replies are merged by a per-position majority
    vote over *characters* (not model tokens), truncated to the length of the
    shortest reply.

    Args:
        prompt: The user's message.

    Returns:
        The character-wise majority-vote string ("" if there are no replies).
    """
    messages = [
        {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate. Use pirate vocabulary and mannerisms in your replies."},
        {"role": "user", "content": prompt},
    ]
    chat_prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    responses = []
    for model in models:
        outputs = model(
            chat_prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            # Without this the pipeline returns prompt + completion, so the
            # whole chat template (system prompt included) was echoed back
            # and voted on along with the reply.
            return_full_text=False,
        )
        # Trim surrounding whitespace so samples that start with a stray
        # newline still line up position-by-position in the vote below.
        responses.append(outputs[0]["generated_text"].strip())

    # Per-position majority vote across samples. zip(*responses) stops at the
    # shortest reply (same truncation as the original min-length loop) and
    # yields nothing when there are no replies. Counter.most_common breaks
    # ties in first-seen order, i.e. in favour of the earliest sample.
    return "".join(
        Counter(chars).most_common(1)[0][0] for chars in zip(*responses)
    )
# Define the Gradio block for the application.
# Title and custom CSS are passed to the Blocks constructor. The original
# `gr.Interface.load("app::block", ...)` call is not a valid way to wrap a
# Blocks app defined in this file.
with gr.Blocks(
    title="Pirate Chatbot",
    css="#gradio-container {font-family: 'Courier New', monospace;}",
) as block:
    input_text = gr.Textbox(lines=2, label="Enter your prompt")
    output_text = gr.Textbox(label="Generated Text")
    generate_button = gr.Button("Generate")

    def update_output_text(prompt_value):
        """Return the reply for *prompt_value*; Gradio writes it into output_text.

        Gradio fills the `outputs` components from the callback's return
        value. The original assigned `output_text.value` and returned None,
        so the output box never changed.
        """
        if not (prompt_value or "").strip():
            return ""
        return generate_text(prompt_value)

    # Trigger on explicit submit/click rather than `.change`: each call runs
    # one full generation per entry in `models`, and `.change` fires on every
    # keystroke.
    input_text.submit(update_output_text, inputs=[input_text], outputs=[output_text])
    generate_button.click(update_output_text, inputs=[input_text], outputs=[output_text])

# Kept as an alias for any code that referred to the old name.
iface = block

# Launch the interface when running app.py directly
if __name__ == "__main__":
    block.launch()