pvyas96's picture
Update app.py
96f9572 verified
raw
history blame
2.48 kB
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Container for the chat transcript of one browser session.
class SessionState:
    """Hold the running conversation as a list of (role, message) tuples."""

    def __init__(self) -> None:
        # Every instance starts with its own fresh, unshared list.
        self.conversation_history = list()
# Initialize the session state.
# Streamlit re-executes this whole script on every interaction, so a plain
# module-level ``SessionState()`` would start empty on each rerun and the
# conversation would never accumulate. st.session_state survives reruns.
if "chat_state" not in st.session_state:
    st.session_state["chat_state"] = SessionState()
session_state = st.session_state["chat_state"]

# Sidebar for setting generation parameters.
st.sidebar.title("Model Parameters")
# You can add more parameters here as needed
max_length = st.sidebar.slider("Max Length", 10, 100, 50)
# Sampling (do_sample=True below) needs a strictly positive temperature;
# transformers rejects 0.0, so the slider starts at 0.1 instead.
temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)


@st.cache_resource
def _load_model(name):
    """Load the seq2seq model and tokenizer for *name*, cached across reruns.

    Without caching, the weights would be re-read on every script rerun,
    i.e. on every user interaction.
    """
    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    return loaded_model, loaded_tokenizer


# Load the model and tokenizer with a loading message
model_name = "facebook/blenderbot-400M-distill"
with st.spinner('Wait for it... the model is loading'):
    model, tokenizer = _load_model(model_name)
# Create a chat input for the user
input_text = st.chat_input("Enter your message:")
# Only run the model once the user has submitted a message.
if input_text:
    # Add the user's message to the conversation history
    session_state.conversation_history.append(("User", input_text))
    # Feed the whole transcript (both roles) to the model so it has context.
    history_string = "\n".join(
        message for _role, message in session_state.conversation_history
    )
    # Tokenize the transcript once; input_ids and attention_mask come from the
    # same call, so they always have matching shapes.
    inputs = tokenizer(history_string, return_tensors="pt")
    # Keep only the most recent tokens that fit the tokenizer's declared
    # maximum input length, so a long conversation drops its oldest turns
    # instead of exceeding what the model accepts.
    limit = tokenizer.model_max_length
    if inputs["input_ids"].shape[-1] > limit:
        inputs = {key: tensor[:, -limit:] for key, tensor in inputs.items()}
    # Generate the response from the model with the sidebar parameters
    outputs = model.generate(
        **inputs, max_length=max_length, do_sample=True, temperature=temperature
    )
    # Decode the response
    response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Add the model's response to the conversation history
    session_state.conversation_history.append(("Assistant", response))
# Display the conversation history as chat bubbles.
# st.chat_message(name) returns a container and has no ``is_user`` argument.
# The names "user"/"assistant" pick Streamlit's preset styling, and the
# message text is written inside the container.
for role, message in session_state.conversation_history:
    with st.chat_message("user" if role == "User" else "assistant"):
        st.write(message)