# dlagpt / app.py
# Lin Chen
# GPT-2
# f1beaad
import base64
import html
from pathlib import Path

import streamlit as st
from transformers import pipeline
# Set page configuration.
# Keep this as the first Streamlit call: older Streamlit releases reject
# set_page_config once any other Streamlit command has run.
_PAGE_SETTINGS = {
    "page_title": "DLA GPT",
    "page_icon": ":robot:",
    "layout": "wide",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)

# Global stylesheet injected as raw HTML: black page background, styling for
# the logo/chat/Q&A containers and message bubbles, and a full-width text input.
_CUSTOM_CSS = """
<style>
.stApp {
background-color: black;
color: white;
}
.logo-container {
margin-bottom: 20px;
}
.chat-container {
border-radius: 10px;
padding: 5px;
margin-top: 20px;
width: 100%;
}
.qa-container {
border-top: 1px dashed silver;
border-bottom: 1px dashed silver;
padding: 10px;
margin-top: 20px;
width: 100%;
}
.message {
margin-bottom: 15px;
padding: 10px;
border-radius: 5px;
}
.user-message {
background-color: #373749;
color: silver;
}
.bot-message {
background-color: #333333;
color: silver;
}
div[data-baseweb="input"] > div {
width: 100% !important; /* Make the input box take full width */
}
input[type="text"] {
height: auto;
padding: 10px;
white-space: normal;
overflow-wrap: break-word;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def get_image_base64(image_path):
    """Read the file at *image_path* and return its bytes as Base64 text.

    Args:
        image_path: path of the file to read in binary mode.

    Returns:
        The Base64 encoding of the file contents, decoded to ``str`` so it can
        be embedded directly in a ``data:`` URI.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
# Resolve the logo next to this script instead of relative to the current
# working directory, so the app also starts when launched from another folder.
LOGO_PATH = Path(__file__).resolve().with_name("DLA.png")
image_base64 = get_image_base64(LOGO_PATH)
# Add logo at the top: the image (left, 15% wide) floated beside the title
# (80% wide). The logo column previously had "border: width: 15%;", which is
# malformed CSS that browsers discard, so its width was never applied.
st.markdown(
    f"""
<div class="logo-container">
<div style = "text-align:left; width: 15%; float:left"><img src="data:image/png;base64,{image_base64}" alt="Logo" style="width:80px;"></div>
<div style = "color:#FFD700; text-shadow: 5px 5px #99ccff; font-family:Cursive; font-size:48px; font-weight:bold; text-align:center; width: 80%; float:left">DLA GPT</div>
</div>
""",
    unsafe_allow_html=True,
)
# Chatbot interaction logic
# Session state outlives Streamlit reruns, so the chat log is created only
# when it does not exist yet. Entries are (user_text, bot_text) tuples,
# newest first.
st.session_state.setdefault("messages", [])
@st.cache_resource
def _load_text_generator():
    """Build the GPT-2 text-generation pipeline once and share it.

    ``st.cache_resource`` keeps the returned object across Streamlit reruns,
    so the model is loaded on first use instead of on every question.
    """
    return pipeline("text-generation", model="gpt2")


def get_bot_response(prompt, max_length=100):
    """Generate a GPT-2 continuation of *prompt*.

    Args:
        prompt: the user's text.
        max_length: forwarded to the pipeline call (default 100, unchanged).
            NOTE(review): in Transformers ``max_length`` includes the prompt
            tokens, so long prompts leave little or no room for new text;
            consider ``max_new_tokens`` instead.

    Returns:
        The ``generated_text`` of the first generated sequence.
        NOTE(review): with the pipeline's default settings this text is
        expected to start with the prompt itself; confirm whether
        ``return_full_text=False`` is wanted.
    """
    generator = _load_text_generator()
    output = generator(prompt, max_length=max_length)
    return output[0]["generated_text"]
def message_to_html(text):
    """Return *text* as markup that is safe to embed in a chat bubble.

    The text is HTML-escaped, so user input and model output cannot inject
    markup into a page rendered with ``unsafe_allow_html=True``. Newlines
    become ``<br>`` so that a blank line in the text does not end the
    Markdown HTML block early.
    """
    return html.escape(str(text)).replace("\n", "<br>")


def display_history(max_pairs=10):
    """Render the most recent question/answer pairs from session state.

    Args:
        max_pairs: how many ``(question, answer)`` pairs to show, taken from
            the front of ``st.session_state.messages`` (newest first).
            Defaults to 10, as before.
    """
    # Chat interface
    pairs_html = "".join(
        "<div class = 'qa-container'>"
        f"<div class='message user-message'><b>You</b>: {message_to_html(question)}</div>"
        f"<div class='message bot-message'><b>Bot</b>: {message_to_html(answer)}</div>"
        "</div>"
        for question, answer in st.session_state.messages[:max_pairs]
    )
    # Emit the container and its children in a single call. Each st.markdown
    # call renders as a separate element, so a <div> opened in one call is not
    # closed by a "</div>" sent in a later call. The heading keeps its own
    # styling, so the messages do not inherit the bold/italic heading font.
    st.markdown(
        "<div class='chat-container'>"
        "<div style = 'color: #9999b2; font-weight:bold; font-style: italic;'>"
        "Here is the answer</div>"
        f"{pairs_html}</div>",
        unsafe_allow_html=True,
    )
def main():
    """Answer the question typed in the text box and show the chat history.

    Streamlit re-executes the whole script on every rerun, and the text box
    keeps its value. The last processed prompt is therefore remembered in
    session state, so a rerun does not regenerate the answer or store a
    duplicate exchange.
    """
    user_input = st.text_input(" ", "")
    if not user_input:
        # The box was cleared, so the same question may be asked again later.
        st.session_state.last_prompt = None
    elif user_input != st.session_state.get("last_prompt"):
        bot_response = get_bot_response(user_input)
        message_pair = (user_input, bot_response)
        # Newest exchange first; display_history() shows the head of the list.
        st.session_state.messages.insert(0, message_pair)
        st.session_state.last_prompt = user_input
    # Show stored history even after the input box has been cleared.
    if st.session_state.messages:
        display_history()


if __name__ == '__main__':
    main()