# Streamlit chat application backed by facebook/blenderbot-400M-distill.
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Container for per-session chat state.
class SessionState:
    """Holds the running conversation as (role, message) tuples."""

    def __init__(self):
        self.conversation_history = []


# Streamlit re-executes this script on every user interaction, so a bare
# module-level instance would be recreated (and the history lost) each
# rerun.  Persist one instance in st.session_state instead.
if "state" not in st.session_state:
    st.session_state["state"] = SessionState()
session_state = st.session_state["state"]
# Sidebar controls for generation parameters.
st.sidebar.title("Model Parameters")
# Upper bound on generated sequence length (tokens).
max_length = st.sidebar.slider("Max Length", 10, 100, 50)
# Sampling temperature.  The lower bound is 0.1, not 0.0: generate() with
# do_sample=True raises a ValueError when temperature is exactly zero.
temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
# Cache the model/tokenizer as a shared resource: without caching they are
# re-instantiated (a multi-second download/deserialization) on every rerun
# of the script, i.e. on every chat turn.
@st.cache_resource(show_spinner=False)
def _load_model(name):
    """Load and return (model, tokenizer) for the given HF model id."""
    mdl = AutoModelForSeq2SeqLM.from_pretrained(name)
    tok = AutoTokenizer.from_pretrained(name)
    return mdl, tok


with st.spinner('Wait for it... the model is loading'):
    model_name = "facebook/blenderbot-400M-distill"
    model, tokenizer = _load_model(model_name)
# Chat input box; returns None until the user submits a message.
input_text = st.chat_input("Enter your message:")

if input_text:
    # Record the user's turn first so it becomes part of the model context.
    session_state.conversation_history.append(("User", input_text))

    # Flatten the dialogue into the plain-text context the model consumes.
    history_string = "\n".join(
        message for role, message in session_state.conversation_history
    )

    # The history string already ends with the new user message, so encode
    # it once.  (Passing input_text again as a sentence pair — as
    # encode_plus(history, input_text) would — duplicates the message.)
    # Truncate so long conversations do not exceed the model's max length.
    inputs = tokenizer(history_string, return_tensors="pt", truncation=True)

    # Sample a reply using the sidebar-controlled parameters.
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=True,
        temperature=temperature,
    )

    # Decode the generated ids back to text and record the model's turn.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    session_state.conversation_history.append(("Assistant", response))
# Render the full transcript.  st.chat_message takes a role name
# ("user"/"assistant") and is a context manager — it has no is_user
# parameter (that belongs to the third-party streamlit_chat package), so
# the call st.chat_message(message, is_user=...) would raise TypeError.
for role, message in session_state.conversation_history:
    with st.chat_message("user" if role == "User" else "assistant"):
        st.markdown(message)