# Page header captured with this source: "Spaces: Paused"
import base64
import html

import streamlit as st
from transformers import pipeline
# Set page configuration. Keep this as the first Streamlit call in the script.
st.set_page_config(
    page_icon=":robot:",
    page_title="DLA GPT",
    initial_sidebar_state="collapsed",
    layout="wide",
)

# Custom CSS: black app background, logo/chat/Q&A container spacing,
# per-speaker message bubble colours, and a full-width text input.
_CUSTOM_CSS = """
<style>
.stApp {
    background-color: black;
    color: white;
}
.logo-container {
    margin-bottom: 20px;
}
.chat-container {
    border-radius: 10px;
    padding: 5px;
    margin-top: 20px;
    width: 100%;
}
.qa-container {
    border-top: 1px dashed silver;
    border-bottom: 1px dashed silver;
    padding: 10px;
    margin-top: 20px;
    width: 100%;
}
.message {
    margin-bottom: 15px;
    padding: 10px;
    border-radius: 5px;
}
.user-message {
    background-color: #373749;
    color: silver;
}
.bot-message {
    background-color: #333333;
    color: silver;
}
div[data-baseweb="input"] > div {
    width: 100% !important; /* Make the input box take full width */
}
input[type="text"] {
    height: auto;
    padding: 10px;
    white-space: normal;
    overflow-wrap: break-word;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
def get_image_base64(image_path):
    """Return the contents of the file at ``image_path`` as a base64 text string.

    The result is suitable for embedding in an HTML ``data:`` URI.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
# Inline the logo as a base64 data URI so the HTML below is self-contained.
# NOTE(review): "./DLA.png" is resolved against the process working directory,
# not this file's directory — confirm the app is always launched from here.
image_base64 = get_image_base64("./DLA.png")
# Add logo at the top: a 15%-wide logo column floated left, then the title.
# (Previously the style read "border: width: 15%;", which is invalid CSS, so the
# browser discarded the declaration and the width was never applied.)
st.markdown(
    f"""
<div class="logo-container">
<div style = "text-align:left; width: 15%; float:left"><img src="data:image/png;base64,{image_base64}" alt="Logo" style="width:80px;"></div>
<div style = "color:#FFD700; text-shadow: 5px 5px #99ccff; font-family:Cursive; font-size:48px; font-weight:bold; text-align:center; width: 80%; float:left">DLA GPT</div>
</div>
""",
    unsafe_allow_html=True,
)
# Chatbot interaction logic: the conversation history lives in session_state
# so it persists across Streamlit reruns. Each entry is a
# (user_text, bot_text) tuple, newest first.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
@st.cache_resource
def _load_text_generator():
    """Build the GPT-2 text-generation pipeline once and reuse it.

    Streamlit re-executes the whole script on every interaction; without the
    cache the model would be reloaded for every prompt.
    """
    return pipeline("text-generation", model="gpt2")


def get_bot_response(prompt, max_length=100):
    """Generate a GPT-2 continuation of ``prompt``.

    Args:
        prompt: the user's text to continue.
        max_length: forwarded to the pipeline as ``max_length`` (default 100,
            unchanged from before). In transformers this limit counts the
            prompt tokens as well as the generated ones.

    Returns:
        The ``generated_text`` of the first candidate. By default the
        text-generation pipeline includes the prompt at the start of this text.
    """
    generator = _load_text_generator()
    output = generator(prompt, max_length=max_length)
    return output[0]["generated_text"]
def _to_safe_html(text):
    """Escape ``text`` for HTML and turn its line breaks into ``<br>`` tags.

    Escaping stops user- or model-supplied markup from being rendered
    (the output is shown with ``unsafe_allow_html=True``). Removing raw
    newlines keeps blank lines from terminating the markdown HTML block.
    """
    return "<br>".join(html.escape(text).splitlines())


def display_history(max_items=10):
    """Render the latest question/answer pairs inside one chat container.

    ``st.session_state.messages`` holds ``(user_text, bot_text)`` tuples with the
    newest at index 0, so the first ``max_items`` entries are the latest.

    Args:
        max_items: how many pairs to show (default 10, as before).
    """
    rows = "".join(
        "<div class = 'qa-container'>"
        f"<div class='message user-message'><b>You</b>: {_to_safe_html(question)}</div>"
        f"<div class='message bot-message'><b>Bot</b>: {_to_safe_html(answer)}</div>"
        "</div>"
        for question, answer in st.session_state.messages[:max_items]
    )
    # Emit the container and its contents in a single call: an open <div> in
    # one st.markdown and its </div> in another never actually nest.
    # The heading styling is on its own div so the messages don't inherit
    # bold/italic from it.
    st.markdown(
        "<div class='chat-container'>"
        "<div style = 'color: #9999b2; font-weight:bold; font-style: italic;'>Here is the answer</div>"
        f"{rows}"
        "</div>",
        unsafe_allow_html=True,
    )
def main():
    """Read a prompt, answer it with the model, and show the chat history.

    Streamlit re-executes the script on every rerun, and ``st.text_input`` keeps
    its value between runs. The last answered prompt is therefore remembered in
    ``st.session_state``, so a rerun that did not change the input does not
    generate and append the same exchange again.
    """
    user_input = st.text_input(" ", "")
    if not user_input:
        # Input cleared: forget the last prompt so re-asking it later
        # produces a fresh answer.
        if "last_prompt" in st.session_state:
            del st.session_state["last_prompt"]
        return
    if st.session_state.get("last_prompt") != user_input:
        bot_response = get_bot_response(user_input)
        # Newest exchange first; display_history shows the head of the list.
        st.session_state.messages.insert(0, (user_input, bot_response))
        st.session_state["last_prompt"] = user_input
    display_history()
# Entry point: build the page when this file is run as the main script.
if __name__ == "__main__":
    main()