import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
# Keep the conversation history in Streamlit's built-in session state so it
# survives reruns (a plain Python object would be recreated, empty, on every
# interaction)
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []
# Sidebar for setting parameters
st.sidebar.title("Model Parameters")
# You can add more parameters here as needed
# max_new_tokens bounds only the generated continuation, not the prompt
max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 100, 50)
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.7)
# Load the model and tokenizer once and cache them across reruns; this is a
# decoder-only (causal) model, so AutoModelForCausalLM is the right class
@st.cache_resource
def load_model(name):
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModelForCausalLM.from_pretrained(name)
    return tok, mdl

with st.spinner("Loading model..."):
    tokenizer, model = load_model("llmware/bling-red-pajamas-3b-0.1")
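# Decoder-only tokenizers often ship without a pad token; falling back to EOS
# (an assumption about this checkpoint) avoids warnings from generate()
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token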
# Create a chat input for the user
input_text = st.chat_input("Enter your message:")
# Check if the user has entered a message
if input_text:
    # Add the user's message to the conversation history
    st.session_state.conversation_history.append(f"<human>: {input_text}")
    # Display the user's message
    st.write("**User:**", input_text)
    # Build the prompt from the history; the BLING model cards document a
    # "<human>: ... \n<bot>:" turn format
    prompt = "\n".join(st.session_state.conversation_history) + "\n<bot>:"
    # Tokenize the prompt (the history already contains the new message, so it
    # must not be passed a second time as a text pair)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Generate the response; sampling must be enabled for temperature to have
    # any effect, and temperature 0.0 falls back to greedy decoding
    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0.0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    outputs = model.generate(**inputs, **gen_kwargs)
    # Decode only the newly generated tokens; a causal model echoes the
    # prompt at the start of its output
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    # Add the model's response to the conversation history
    st.session_state.conversation_history.append(f"<bot>: {response}")
    # Display the model's response
    st.write("**Assistant:**", response)