pvyas96 committed on
Commit
e48510a
·
verified ·
1 Parent(s): 48e59ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -20
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import torch
4
 
5
  # Create a class for the session state
6
  class SessionState:
@@ -22,37 +21,32 @@ with st.spinner('Wait for it... the model is loading'):
22
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
23
  tokenizer = AutoTokenizer.from_pretrained(model_name)
24
 
25
- # Create a chat input for the user
26
- input_text = st.chat_input("Enter your message:")
27
 
28
  # Check if the user has entered a message
29
  if input_text:
30
  # Add the user's message to the conversation history
31
- session_state.conversation_history.append(("User", input_text))
32
 
33
  # Create conversation history string
34
- history_string = "\n".join(message for role, message in session_state.conversation_history)
35
 
36
  # Tokenize the input text and history
37
- inputs = tokenizer.encode_plus(history_string, return_tensors="pt")
38
- inputs["input_ids"] = torch.cat([inputs["input_ids"], torch.tensor([[tokenizer.sep_token_id]])], dim=-1)
39
- inputs["attention_mask"] = torch.cat([inputs["attention_mask"], torch.tensor([[1]])], dim=-1)
40
- inputs = tokenizer.encode_plus(input_text, return_tensors="pt", add_special_tokens=False)
41
- inputs["input_ids"] = torch.cat([inputs["input_ids"], inputs["input_ids"]], dim=-1)
42
- inputs["attention_mask"] = torch.cat([inputs["attention_mask"], inputs["attention_mask"]], dim=-1)
43
 
44
  # Generate the response from the model with additional parameters
45
- outputs = model.generate(**inputs, max_length=max_length, do_sample=True, temperature=temperature)
46
 
47
  # Decode the response
48
  response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
49
 
50
  # Add the model's response to the conversation history
51
- session_state.conversation_history.append(("Assistant", response))
52
-
53
- # Display the conversation history using st.chat
54
- for role, message in session_state.conversation_history:
55
- if role == "User":
56
- st.chat_message(message, is_user=True)
57
- else:
58
- st.chat_message(message, is_user=False)
 
1
  import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
  # Create a class for the session state
5
  class SessionState:
 
21
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
24
+ # Create a text input for the user
25
+ input_text = st.text_input("Enter your message:")
26
 
27
  # Check if the user has entered a message
28
  if input_text:
29
  # Add the user's message to the conversation history
30
+ session_state.conversation_history.append(input_text)
31
 
32
  # Create conversation history string
33
+ history_string = "\n".join(session_state.conversation_history)
34
 
35
  # Tokenize the input text and history
36
+ inputs = tokenizer.encode_plus(history_string, input_text, return_tensors="pt")
 
 
 
 
 
37
 
38
  # Generate the response from the model with additional parameters
39
+ outputs = model.generate(**inputs, max_length=max_length, temperature=temperature)
40
 
41
  # Decode the response
42
  response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
43
 
44
  # Add the model's response to the conversation history
45
+ session_state.conversation_history.append(response)
46
+
47
+ # Display the conversation history
48
+ st.write("Conversation History:")
49
+ for i in range(0, len(session_state.conversation_history), 2):
50
+ st.write("User: " + session_state.conversation_history[i])
51
+ if i+1 < len(session_state.conversation_history):
52
+ st.write("Assistant: " + session_state.conversation_history[i+1])