# Spaces:
# Sleeping
# Sleeping
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Container for the per-session chat state.
class SessionState:
    """Hold the chat transcript for one user session.

    ``conversation_history`` is a list of ``(role, message)`` tuples; the
    script appends ``("User", ...)`` and ``("Assistant", ...)`` entries.
    """

    def __init__(self):
        # A fresh list per instance, so two sessions never share a transcript.
        self.conversation_history = list()
# Initialize the session state once per browser session.
# Streamlit re-executes this whole script on every interaction, so a plain
# module-level ``SessionState()`` would be recreated -- wiping the
# conversation history -- on each rerun. Keep the instance in
# ``st.session_state``, which Streamlit preserves across reruns of a session.
if "chat_state" not in st.session_state:
    st.session_state["chat_state"] = SessionState()
session_state = st.session_state["chat_state"]
# Sidebar controls for the text-generation parameters.
st.sidebar.title("Model Parameters")
# Further generation parameters can be exposed here in the same way.
max_length = st.sidebar.slider(
    "Max Length",
    min_value=10,
    max_value=100,
    value=50,
)
temperature = st.sidebar.slider(
    "Temperature",
    min_value=0.0,
    max_value=1.0,
    value=0.7,
)
# Load the model and tokenizer, showing a message while loading.
model_name = "facebook/blenderbot-400M-distill"


@st.cache_resource
def _load_model_and_tokenizer(name):
    """Load the seq2seq model and its tokenizer for the checkpoint *name*.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns, instead of being reloaded every time the
    user interacts with the page.

    The tokenizer is created with ``truncation_side="left"``: when truncation
    is requested at encode time, the oldest text is dropped and the most
    recent part of the input is kept.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name, truncation_side="left")
    return loaded_model, loaded_tokenizer


with st.spinner('Wait for it... the model is loading'):
    model, tokenizer = _load_model_and_tokenizer(model_name)
# Create a chat input for the user.
input_text = st.chat_input("Enter your message:")
# ``st.chat_input`` returns None until the user submits a message.
if input_text:
    # Record the user's message first so it becomes part of the history.
    session_state.conversation_history.append(("User", input_text))

    # The history already ends with ``input_text``; encode it once only
    # (passing ``input_text`` again as a second sequence would duplicate it).
    history_string = "\n".join(
        message for _role, message in session_state.conversation_history
    )
    # ``truncation=True`` keeps the encoded history within the tokenizer's
    # maximum input length instead of overflowing the model's positions.
    inputs = tokenizer(history_string, return_tensors="pt", truncation=True)

    # Sampling with temperature 0 is rejected by ``generate``; the slider
    # allows 0.0, so treat it as "deterministic" and use greedy decoding.
    generate_kwargs = {"max_length": max_length}
    if temperature > 0:
        generate_kwargs["do_sample"] = True
        generate_kwargs["temperature"] = temperature
    else:
        generate_kwargs["do_sample"] = False
    outputs = model.generate(**inputs, **generate_kwargs)

    # Decode the first (only) generated sequence.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Add the model's response to the conversation history.
    session_state.conversation_history.append(("Assistant", response))
# Display the conversation history with Streamlit's chat elements.
# ``st.chat_message`` takes the author name ("user" / "assistant") and returns
# a container; the message text is written inside that container. It has no
# ``is_user`` keyword (that resembles the third-party ``streamlit_chat.message``
# API), so the original call raised instead of rendering anything.
for role, message in session_state.conversation_history:
    # Any role other than "User" is shown as the assistant, as before.
    author = "user" if role == "User" else "assistant"
    with st.chat_message(author):
        st.write(message)