import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Use the GPU when one is available; CPU generation works but is slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print("Model loaded.")
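
# Optional sketch (assumes a CUDA GPU): load the weights in half precision
# to roughly halve memory use. A 1.1B-parameter model takes ~4.4 GB of
# memory in float32 but ~2.2 GB in float16.
#
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, torch_dtype=torch.float16
# ).to("cuda")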

def respond(message, history):
    """Generate one assistant reply for gr.ChatInterface."""
    # With type="messages", Gradio passes history as a list of
    # {"role": ..., "content": ...} dicts, so it can be handed straight
    # to the tokenizer's chat template together with the new user turn.
    messages = history + [{"role": "user", "content": message}]

    # TinyLlama-Chat is fine-tuned on a chat template, so build the prompt
    # with apply_chat_template rather than concatenating eos-separated
    # turns (a DialoGPT-style pattern this model was not trained on).
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=500,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.8,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    reply = tokenizer.decode(
        output_ids[0, input_ids.shape[-1]:],
        skip_special_tokens=True,
    )
    return reply

gr.ChatInterface(fn=respond, type="messages", title="🧠 TinyLlama Chatbot").launch(share=True)
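
# To run (a minimal sketch, assuming the packages are installed):
#   pip install gradio transformers torch
# then execute this file with Python. launch(share=True) serves the UI
# locally and also prints a temporary public *.gradio.live link; drop
# share=True to keep the app local-only.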