# Streamlit chat front-end for the facebook/blenderbot-400M-distill model.
#
# Streamlit re-executes this script from top to bottom on every user
# interaction, so anything that must survive a rerun (the transcript, the
# loaded model) is kept in st.session_state / st.cache_resource rather than
# in plain module-level variables.

import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class SessionState:
    """Container for the chat transcript of one browser session."""

    def __init__(self):
        # List of (role, message) tuples; role is "User" or "Assistant".
        self.conversation_history = []


@st.cache_resource(show_spinner=False)
def load_model(model_name):
    """Load the seq2seq model and its tokenizer once per server process.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The result is cached by Streamlit, so
        later reruns reuse the same objects instead of reloading the weights.
    """
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def generate_reply(model, tokenizer, history, max_length, temperature):
    """Generate the assistant's next message from the conversation so far.

    Args:
        model: Seq2seq model exposing ``generate`` and ``config``.
        tokenizer: Tokenizer matching ``model``.
        history: Sequence of ``(role, message)`` tuples, oldest first. Only
            the message text is fed to the model.
        max_length: ``max_length`` forwarded to ``model.generate``.
        temperature: Sampling temperature. A value of 0 (or less) switches to
            deterministic decoding, because ``generate`` rejects a
            non-positive temperature when sampling is enabled.

    Returns:
        The decoded reply with special tokens removed and whitespace stripped.
    """
    history_string = "\n".join(message for _role, message in history)
    encoded = tokenizer(history_string, return_tensors="pt")

    # The encoder can only attend to a fixed number of positions. Keep the
    # most recent tokens (the tail) so the newest messages are never dropped.
    limit = model.config.max_position_embeddings
    input_ids = encoded["input_ids"][:, -limit:]
    attention_mask = encoded["attention_mask"][:, -limit:]

    if temperature > 0:
        sampling = {"do_sample": True, "temperature": temperature}
    else:
        sampling = {"do_sample": False}

    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            **sampling,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()


# Create the session state once and fetch it back on every later rerun, so
# the transcript is not reset each time the script re-executes.
if "chat_state" not in st.session_state:
    st.session_state["chat_state"] = SessionState()
session_state = st.session_state["chat_state"]

# Sidebar for setting parameters
st.sidebar.title("Model Parameters")

# You can add more parameters here as needed
max_length = st.sidebar.slider("Max Length", 10, 100, 50)
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.7)

# Load the model and tokenizer with a loading message
with st.spinner('Wait for it... the model is loading'):
    model_name = "facebook/blenderbot-400M-distill"
    model, tokenizer = load_model(model_name)

# Create a chat input for the user
input_text = st.chat_input("Enter your message:")

# Check if the user has entered a message
if input_text:
    # Add the user's message to the conversation history
    session_state.conversation_history.append(("User", input_text))

    # Generate the response from the full history and store it
    response = generate_reply(
        model,
        tokenizer,
        session_state.conversation_history,
        max_length,
        temperature,
    )
    session_state.conversation_history.append(("Assistant", response))

# Display the conversation history. This runs on every rerun (not only when a
# new message arrives) so the transcript stays visible after, e.g., a slider
# change. st.chat_message takes the author name and is used as a container.
for role, message in session_state.conversation_history:
    with st.chat_message("user" if role == "User" else "assistant"):
        st.write(message)