Shahabmoin commited on
Commit
55ed7de
·
verified ·
1 Parent(s): 205f4e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -32
app.py CHANGED
@@ -1,60 +1,65 @@
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
 
4
- # Load the pre-trained model and tokenizer
5
  @st.cache_resource
6
  def load_model():
7
- model_name = "microsoft/DialoGPT-medium" # Replace with your preferred model
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
 
 
 
10
  return model, tokenizer
11
 
12
  model, tokenizer = load_model()
13
 
14
- # Chat history
15
  if "messages" not in st.session_state:
16
  st.session_state["messages"] = []
17
 
18
  # Sidebar configuration
19
  st.sidebar.title("Chatbot Settings")
20
  st.sidebar.write("Customize your chatbot:")
21
- max_length = st.sidebar.slider("Max Response Length", 10, 200, 50)
22
  temperature = st.sidebar.slider("Response Creativity (Temperature)", 0.1, 1.0, 0.7)
23
 
24
  # App title
25
- st.title("🤖 Open Source Text-to-Text Chatbot")
26
 
27
- # Chat Interface
28
  st.write("### Chat with the bot:")
29
  user_input = st.text_input("You:", key="user_input", placeholder="Type your message here...")
30
 
31
  if user_input:
32
- # Encode the input and add chat history for context
33
- inputs = tokenizer.encode(
34
- " ".join(st.session_state["messages"] + [user_input]),
35
- return_tensors="pt",
36
- truncation=True
37
- )
38
-
39
  # Generate response
40
- response = model.generate(
41
- inputs,
42
- max_length=max_length,
43
- temperature=temperature,
44
- pad_token_id=tokenizer.eos_token_id,
45
- )
46
- bot_response = tokenizer.decode(response[:, inputs.shape[-1]:][0], skip_special_tokens=True)
47
-
48
- # Append to chat history
49
- st.session_state["messages"].append(f"You: {user_input}")
50
- st.session_state["messages"].append(f"Bot: {bot_response}")
51
-
52
- # Display the chat
53
- for message in st.session_state["messages"]:
54
- if message.startswith("You:"):
55
- st.markdown(f"**{message}**")
56
- else:
57
- st.markdown(f"> {message}")
58
 
59
  # Clear chat history button
60
  if st.button("Clear Chat"):
 
1
  import streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
 
5
+ # Load the model and tokenizer
6
  @st.cache_resource
7
  def load_model():
8
+ model_name = "tiiuae/falcon-7b-instruct" # Replace with the desired Falcon model
 
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_name,
12
+ device_map="auto", # Automatically assign model layers to available GPUs/CPUs
13
+ torch_dtype=torch.float16 # Use FP16 for faster inference
14
+ )
15
  return model, tokenizer
16
 
17
  model, tokenizer = load_model()
18
 
19
+ # Initialize chat history
20
  if "messages" not in st.session_state:
21
  st.session_state["messages"] = []
22
 
23
  # Sidebar configuration
24
  st.sidebar.title("Chatbot Settings")
25
  st.sidebar.write("Customize your chatbot:")
26
+ max_length = st.sidebar.slider("Max Response Length (Tokens)", 50, 500, 150)
27
  temperature = st.sidebar.slider("Response Creativity (Temperature)", 0.1, 1.0, 0.7)
28
 
29
  # App title
30
+ st.title("🤖 Falcon Chatbot")
31
 
32
+ # Chat interface
33
  st.write("### Chat with the bot:")
34
  user_input = st.text_input("You:", key="user_input", placeholder="Type your message here...")
35
 
36
  if user_input:
37
+ # Add user input to chat history
38
+ st.session_state["messages"].append(f"User: {user_input}")
39
+
40
+ # Prepare input for the model
41
+ prompt = "\n".join(st.session_state["messages"]) + f"\nAssistant:"
42
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
43
+
44
  # Generate response
45
+ with st.spinner("Thinking..."):
46
+ output = model.generate(
47
+ inputs.input_ids,
48
+ max_length=max_length,
49
+ temperature=temperature,
50
+ pad_token_id=tokenizer.eos_token_id,
51
+ )
52
+ bot_response = tokenizer.decode(output[0], skip_special_tokens=True).split("Assistant:")[-1].strip()
53
+
54
+ # Add bot response to chat history
55
+ st.session_state["messages"].append(f"Assistant: {bot_response}")
56
+
57
+ # Display chat history
58
+ for msg in st.session_state["messages"]:
59
+ if msg.startswith("User:"):
60
+ st.markdown(f"**{msg}**")
61
+ elif msg.startswith("Assistant:"):
62
+ st.markdown(f"> {msg}")
63
 
64
  # Clear chat history button
65
  if st.button("Clear Chat"):