# 3-Tiny-Llamas / app.py
# Hugging Face Space file (revision 8f65806 "Update app.py", by TuringsSolutions)
import torch
from transformers import pipeline, AutoTokenizer
import gradio as gr
# Load models and tokenizer
# Three independent CPU (device=-1) pipeline instances of the same TinyLlama
# chat checkpoint; with sampling enabled downstream, each produces a different
# completion, forming a simple ensemble.
# NOTE(review): three copies of a 1.1B model triple the memory footprint --
# one pipeline sampled three times would be just as diverse; confirm intent.
models = [pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1) for _ in range(3)]
# All pipelines wrap the same checkpoint, so any of their tokenizers can
# render the chat template; use the first.
tokenizer = models[0].tokenizer
# Function for generating text using the ensemble of models
def generate_text(prompt):
    """Generate a pirate-style reply by majority-voting characters across models.

    The user *prompt* is wrapped in the model's chat template with a pirate
    system message, each pipeline in ``models`` samples a completion, and the
    final text is built position-by-position by taking the most frequent
    character across the responses (truncated to the shortest one).

    Args:
        prompt: Raw user message to answer.

    Returns:
        The character-wise majority-vote string over the sampled responses.
    """
    messages = [
        {"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate. Use pirate vocabulary and mannerisms in your replies."},
        {"role": "user", "content": prompt},
    ]
    # Render the chat-formatted prompt (kept separate from the raw argument).
    chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    responses = []
    for model in models:
        with torch.no_grad():  # inference only; no autograd bookkeeping
            outputs = model(chat_prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        # NOTE(review): 'generated_text' from the pipeline includes the prompt
        # text as well as the completion -- confirm that is intended here.
        responses.append(outputs[0]['generated_text'])

    # Character-level majority vote. zip(*responses) stops at the shortest
    # response, matching the original min-length truncation; max() returns the
    # first-inserted character on ties, matching the original stable sort.
    voted_chars = []
    for chars in zip(*responses):
        counts = {}
        for ch in chars:
            counts[ch] = counts.get(ch, 0) + 1
        voted_chars.append(max(counts.items(), key=lambda kv: kv[1])[0])
    # ''.join is O(n), avoiding the quadratic repeated string concatenation.
    return ''.join(voted_chars)
# Define the Gradio block for the application: one prompt textbox in, one
# generated-text textbox out, wired together by a change event.
block = gr.Blocks()
with block:
    input_text = gr.Textbox(lines=2, label="Enter your prompt")
    output_text = gr.Textbox(label="Generated Text")

    def update_output_text(prompt):
        # Gradio passes the input component's value as the argument and
        # expects the new output value to be *returned*; assigning to
        # output_text.value (as before) never updates the rendered UI.
        return generate_text(prompt)

    input_text.change(update_output_text, inputs=[input_text], outputs=[output_text])
# Launch the app. The previous gr.Interface.load("app::block", ...) call is
# invalid: Interface.load() fetches models/Spaces from the Hugging Face Hub
# and cannot wrap a local Blocks object, so it raised at import time.
# A Blocks app is launched directly instead.
if __name__ == "__main__":
    block.launch()