json-spa / app.py
kevalfst's picture
Update app.py
80fe950 verified
raw
history blame
1.18 kB
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import gradio as gr
# Checkpoint served by this app.
model_id = "tiiuae/falcon-rw-1b"

# Decide once whether a GPU is usable; both the dtype and the pipeline
# device below depend on it.
_cuda_ready = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Half precision on GPU, full precision on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if _cuda_ready else torch.float32,
)

# device=0 selects the first CUDA device, -1 selects the CPU.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if _cuda_ready else -1,
)

# Chat memory
chat_history = []
# Main chat function
def chat(user_input, history):
    """Generate the assistant's reply to ``user_input`` and record the turn.

    Args:
        user_input: The text the user just submitted.
        history: List of ``(user, bot)`` pairs from earlier turns, or ``None``
            when no state has been stored yet (the first call of a session).

    Returns:
        A ``(history, history)`` pair: the updated list of turns, once for the
        chatbot display and once for the state that is fed back on the next
        call.
    """
    # A "state" input starts out as None, which cannot be iterated.
    if history is None:
        history = []

    # Replay earlier turns so the model sees the whole conversation, then
    # leave the prompt open on "Assistant:" for the model to complete.
    turns = [f"User: {user}\nAssistant: {bot}\n" for user, bot in history]
    turns.append(f"User: {user_input}\nAssistant:")
    prompt = "".join(turns)

    # return_full_text=False makes the pipeline return only the newly
    # generated continuation, so marker text that already sits in the prompt
    # (including anything the user typed) cannot confuse reply extraction.
    completion = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        return_full_text=False,
    )[0]["generated_text"]

    # No stop sequence is passed to the pipeline, so the model may keep
    # writing further "User:" / "Assistant:" turns. Keep only the text before
    # the first such marker, i.e. the reply to this turn.
    cut = len(completion)
    for marker in ("\nUser:", "\nAssistant:"):
        position = completion.find(marker)
        if position != -1:
            cut = min(cut, position)
    answer = completion[:cut].strip()

    history.append((user_input, answer))
    return history, history
# Interface
# The Chatbot component renders the list of turns returned by `chat`; the
# "state" input/output pair carries that list from one call to the next.
chatbot = gr.Chatbot()

demo = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=[chatbot, "state"],
    title="ChatBot",
    description="Normal chatbot like ChatGPT",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()