import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Persist the conversation history across Streamlit reruns
# (a plain class would be recreated, and the history lost, on every rerun)
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []
# Sidebar for setting parameters
st.sidebar.title("Model Parameters")
# You can add more parameters here as needed
max_length = st.sidebar.slider("Max Length", 10, 100, 50)
# Temperature must be strictly positive when sampling, so the slider starts at 0.1
temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
# Cache the model and tokenizer so they are downloaded and loaded only once
@st.cache_resource
def load_model(name):
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSeq2SeqLM.from_pretrained(name)
    return tokenizer, model

# Load the model and tokenizer with a loading message
with st.spinner("Wait for it... the model is loading"):
    tokenizer, model = load_model("facebook/blenderbot-400M-distill")
# Create a chat input for the user
input_text = st.chat_input("Enter your message:")

# Check if the user has entered a message
if input_text:
    # Add the user's message to the conversation history
    st.session_state.conversation_history.append(("User", input_text))
    # Flatten the history (which now includes the new message) into one prompt
    history_string = "\n".join(
        message for role, message in st.session_state.conversation_history
    )

    # Tokenize the prompt, truncating to the model's maximum input length
    inputs = tokenizer(history_string, return_tensors="pt", truncation=True)
    # Generate the response from the model with additional parameters
    outputs = model.generate(
        **inputs, max_length=max_length, do_sample=True, temperature=temperature
    )
    # Decode the response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Add the model's response to the conversation history
    st.session_state.conversation_history.append(("Assistant", response))
# Display the conversation history using st.chat_message
# (st.chat_message takes a role name such as "user" or "assistant",
# not the message text, and has no is_user parameter)
for role, message in st.session_state.conversation_history:
    with st.chat_message("user" if role == "User" else "assistant"):
        st.write(message)
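
# To try the app locally (assuming streamlit, transformers, and torch are
# installed), save this script as app.py and run: streamlit run app.py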