kevalfst committed
Commit 58d5252 · verified · 1 Parent(s): 80fe950

Update app.py

Files changed (1)
app.py +11 -23
app.py CHANGED
@@ -1,33 +1,21 @@
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-import torch
 import gradio as gr
+import torch
 
-model_id = "tiiuae/falcon-rw-1b"
+model_id = "tiiuae/falcon-rw-1b"  # small enough to run in Hugging Face Space
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+model = AutoModelForCausalLM.from_pretrained(model_id)
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
 
-# Chat memory
-chat_history = []
-
-# Main chat function
 def chat(user_input, history):
     prompt = ""
-    for i, (user, bot) in enumerate(history):
-        prompt += f"User: {user}\nAssistant: {bot}\n"
-    prompt += f"User: {user_input}\nAssistant:"
-
-    response = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7)[0]["generated_text"]
-    answer = response.split("Assistant:")[-1].strip()
-
-    history.append((user_input, answer))
+    for user, bot in history:
+        prompt += f"User: {user}\nBot: {bot}\n"
+    prompt += f"User: {user_input}\nBot:"
+
+    response = pipe(prompt, max_new_tokens=128, do_sample=True, temperature=0.7)[0]["generated_text"]
+    reply = response.split("Bot:")[-1].strip()
+    history.append((user_input, reply))
     return history, history
 
-# Interface
-chatbot = gr.Chatbot()
-demo = gr.Interface(fn=chat, inputs=["text", "state"], outputs=[chatbot, "state"],
-                    title="ChatBot",
-                    description="Normal chatbot like ChatGPT")
-
-if __name__ == "__main__":
-    demo.launch()
+gr.ChatInterface(chat, chatbot=gr.Chatbot(), title="Lightweight Chatbot").launch()
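
One thing worth noting about the new version: `chat` still returns `(history, history)` and appends to `history` itself, which matched the old `gr.Interface` state wiring, but `gr.ChatInterface` calls the function as `fn(message, history)` and expects only the reply string back. A minimal sketch of a compatible callback, assuming the tuple-style `(user, bot)` history format that Gradio passes by default (all names taken from the committed code):

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
import torch

model_id = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer,
                device=0 if torch.cuda.is_available() else -1)

def chat(user_input, history):
    # Rebuild the running transcript from the (user, bot) pairs Gradio passes in.
    prompt = ""
    for user, bot in history:
        prompt += f"User: {user}\nBot: {bot}\n"
    prompt += f"User: {user_input}\nBot:"

    response = pipe(prompt, max_new_tokens=128, do_sample=True,
                    temperature=0.7)[0]["generated_text"]
    # Return only the text after the last "Bot:" marker; ChatInterface
    # appends the reply to the visible history on its own.
    return response.split("Bot:")[-1].strip()

gr.ChatInterface(chat, chatbot=gr.Chatbot(), title="Lightweight Chatbot").launch()

Since `gr.ChatInterface` manages the transcript itself, the callback no longer needs to mutate or return `history`.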