from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import gradio as gr

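# Falcon-RW-1B is a small causal LM; fp16 roughly halves GPU memory, with
# an fp32 fallback on CPU. device=0 pins the pipeline to the first GPU.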
model_id = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

# Conversation memory lives in Gradio's "state" component, so each
# browser session keeps its own history.

# Chat step: rebuild the prompt from the running history, generate a
# reply, and return the updated history for both the display and state.
def chat(user_input, history):
    history = history or []  # state arrives as None on the first call

    prompt = ""
    for user, bot in history:
        prompt += f"User: {user}\nAssistant: {bot}\n"
    prompt += f"User: {user_input}\nAssistant:"

    # return_full_text=False keeps only the continuation; pad_token_id
    # silences a warning since Falcon defines no pad token.
    response = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7,
                    return_full_text=False,
                    pad_token_id=tokenizer.eos_token_id)[0]["generated_text"]

    # Cut the reply at the first hallucinated "User:" turn if the model keeps going.
    answer = response.split("User:")[0].strip()

    history.append((user_input, answer))
    return history, history

# Interface: the Chatbot component renders the conversation; "state" carries history between calls.
chatbot = gr.Chatbot()
demo = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=[chatbot, "state"],
    title="ChatBot",
    description="A simple ChatGPT-style chatbot backed by Falcon-RW-1B.",
)

if __name__ == "__main__":
    demo.launch()  # serves locally at http://127.0.0.1:7860 by default