pvyas96 commited on
Commit
ef4096a
·
verified ·
1 Parent(s): 8f7ff34

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -52
app.py CHANGED
@@ -1,52 +1,52 @@
1
- import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
-
4
- # Create a class for the session state
5
- class SessionState:
6
- def __init__(self):
7
- self.conversation_history = []
8
-
9
- # Initialize the session state
10
- session_state = SessionState()
11
-
12
- # Sidebar for setting parameters
13
- st.sidebar.title("Model Parameters")
14
- # You can add more parameters here as needed
15
- max_length = st.sidebar.slider("Max Length", 10, 100, 50)
16
- temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.7)
17
-
18
- # Load the model and tokenizer with a loading message
19
- with st.spinner('Wait for it... the model is loading'):
20
- model_name = "facebook/blenderbot-400M-distill"
21
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
22
- tokenizer = AutoTokenizer.from_pretrained(model_name)
23
-
24
- # Create a chat input for the user
25
- input_text = st.chat_input("Enter your message:")
26
-
27
- # Check if the user has entered a message
28
- if input_text:
29
- # Add the user's message to the conversation history
30
- session_state.conversation_history.append(("User", input_text))
31
-
32
- # Create conversation history string
33
- history_string = "\n".join(message for role, message in session_state.conversation_history)
34
-
35
- # Tokenize the input text and history
36
- inputs = tokenizer.encode_plus(history_string, input_text, return_tensors="pt")
37
-
38
- # Generate the response from the model with additional parameters
39
- outputs = model.generate(**inputs, max_length=max_length, do_sample=True ,temperature=temperature)
40
-
41
- # Decode the response
42
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
43
-
44
- # Add the model's response to the conversation history
45
- session_state.conversation_history.append(("Assistant", response))
46
-
47
- # Display the conversation history using st.chat
48
- for role, message in session_state.conversation_history:
49
- if role == "User":
50
- st.chat(message, is_user=True)
51
- else:
52
- st.chat(message, is_user=False)
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+
4
+ # Create a class for the session state
5
+ class SessionState:
6
+ def __init__(self):
7
+ self.conversation_history = []
8
+
9
+ # Initialize the session state
10
+ session_state = SessionState()
11
+
12
+ # Sidebar for setting parameters
13
+ st.sidebar.title("Model Parameters")
14
+ # You can add more parameters here as needed
15
+ max_length = st.sidebar.slider("Max Length", 10, 100, 50)
16
+ temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.7)
17
+
18
+ # Load the model and tokenizer with a loading message
19
+ with st.spinner('Wait for it... the model is loading'):
20
+ model_name = "facebook/blenderbot-400M-distill"
21
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
22
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+
24
+ # Create a chat input for the user
25
+ input_text = st.chat_input("Enter your message:")
26
+
27
+ # Check if the user has entered a message
28
+ if input_text:
29
+ # Add the user's message to the conversation history
30
+ session_state.conversation_history.append(("User", input_text))
31
+
32
+ # Create conversation history string
33
+ history_string = "\n".join(message for role, message in session_state.conversation_history)
34
+
35
+ # Tokenize the input text and history
36
+ inputs = tokenizer.encode_plus(history_string, input_text, return_tensors="pt")
37
+
38
+ # Generate the response from the model with additional parameters
39
+ outputs = model.generate(**inputs, max_length=max_length, do_sample=True ,temperature=temperature)
40
+
41
+ # Decode the response
42
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
43
+
44
+ # Add the model's response to the conversation history
45
+ session_state.conversation_history.append(("Assistant", response))
46
+
47
+ # Display the conversation history using st.chat
48
+ for role, message in session_state.conversation_history:
49
+ if role == "User":
50
+ st.chat_message(message, is_user=True)
51
+ else:
52
+ st.chat_message(message, is_user=False)