# Hugging Face file-view header (page residue, not part of the program):
# IAMTFRMZA - "Update app.py" - commit 16745ff (verified) - 7.8 kB
import streamlit as st
import asyncio
import edge_tts
import time
import os
import re
import uuid
import firebase_admin
from firebase_admin import credentials, firestore
from openai import OpenAI
# ---- Firebase setup ----
# Initialise the default Firebase app only when the SDK has no app registered
# yet, then obtain the Firestore client used by the helpers below.
if not firebase_admin._apps:
    firebase_admin.initialize_app(
        credentials.Certificate("firebase-service-account.json")
    )
db = firestore.client()
# ---- OpenAI setup ----
# The API key and the assistant id are read from the process environment
# (note the lowercase variable names); either is None when unset.
openai_key = os.environ.get("openai_key")
assistant_id = os.environ.get("assistant_id")
client = OpenAI(api_key=openai_key)
# ---- Voice Settings ----
# Human-readable label for the voice; shown in the TTS spinner text.
FIXED_VOICE_NAME = "Jenny (US, Female)"
# Voice identifier passed through to edge_tts.Communicate.
FIXED_VOICE = "en-US-JennyNeural"
# --- State setup
# A random UUID is generated once per session; it is used as the Firestore
# document id under "users" and in the name of the generated audio file.
if "user_id" not in st.session_state:
    st.session_state["user_id"] = str(uuid.uuid4())
user_id = st.session_state["user_id"]

# Seed the remaining session keys only when they are not present yet.
for _state_key, _initial_value in (
    ("last_tts_text", ""),
    ("last_audio_path", ""),
    ("is_thinking", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# --- Page config ---
# Browser tab title and wide page layout.
st.set_page_config(
    page_title="LOR Technologies AI Assistant",
    layout="wide",
)
# --- CSS Styling ---
# Stylesheet for the page: hides the header, and styles the brand bar, the
# chat bubbles, the scrollable history wrapper and the bottom input bar.
_PAGE_CSS = """
<style>
.block-container {padding-top: 1rem;}
header {visibility: hidden;}
.logo-mini {
width: 75px !important;
margin: 0 auto 0.25em auto;
display: block;
}
.lor-brand-bar {
width: 100vw; text-align: center; background: none;
margin-bottom: 0.5em; margin-top: 0.1em;
}
.stChatMessage { max-width: 85%; border-radius: 12px; padding: 8px; margin-bottom: 10px; }
.stChatMessage[data-testid="stChatMessage-user"] { background: #f0f0f0; color: #000000; }
.stChatMessage[data-testid="stChatMessage-assistant"] { background: #e3f2fd; color: #000000; }
.chat-history-wrapper {
margin-top: 0.5em; margin-bottom: 5em; height: 65vh; overflow-y: auto; padding: 0 0.5em;
}
.chat-input-bar {
position: fixed; bottom: 0; width: 100%; z-index: 100;
background: #191b22; padding: 0.6em 1em; border-top: 1px solid #22232c;
display: flex; align-items: center; gap: 0.5em;
}
.chat-input-bar input { width: 100%; font-size: 1.1em; }
.clear-chat-btn { background: none; border: none; font-size: 1.4em; color: #999; cursor: pointer; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# --- Top Branding ---
# Centered logo plus a "Powered by" caption, rendered as raw HTML.
_BRAND_BAR_HTML = """
<div class="lor-brand-bar">
<img src="https://lortechnologies.com/wp-content/uploads/2023/03/LOR-Online-Logo.svg" class="logo-mini" />
<div style="font-size: 13px; color: #888;">Powered by LOR Technologies</div>
</div>
"""
st.markdown(_BRAND_BAR_HTML, unsafe_allow_html=True)
# --- Firestore helpers ---
def get_or_create_thread_id():
    """Return the OpenAI thread id for the current user, creating it if needed.

    Looks up the ``users/<user_id>`` document in Firestore. If it exists, its
    ``thread_id`` field is returned. Otherwise a new thread is created through
    the OpenAI client, stored in that document together with a server-side
    ``created_at`` timestamp, and the new id is returned.
    """
    user_ref = db.collection("users").document(user_id)
    snapshot = user_ref.get()
    if not snapshot.exists:
        new_thread = client.beta.threads.create()
        user_ref.set(
            {"thread_id": new_thread.id, "created_at": firestore.SERVER_TIMESTAMP}
        )
        return new_thread.id
    return snapshot.to_dict()["thread_id"]
def save_message(role, content):
    """Store one chat message under the current user's ``messages`` subcollection.

    Args:
        role: Speaker label saved with the message (callers in this file pass
            "user" or "assistant").
        content: Message text.
    """
    messages_ref = db.collection("users").document(user_id).collection("messages")
    record = {
        "role": role,
        "content": content,
        # Filled in by Firestore on write; the history view orders by it.
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    messages_ref.add(record)
def clear_chat_history():
    """Delete the current user's stored conversation and restart the script.

    Every document in the user's ``messages`` subcollection is deleted one by
    one, then the user document itself. Afterwards the whole Streamlit session
    state is cleared and a rerun is requested.
    """
    user_ref = db.collection("users").document(user_id)
    for snapshot in user_ref.collection("messages").stream():
        snapshot.reference.delete()
    user_ref.delete()
    # Wipes every session key, including "user_id".
    st.session_state.clear()
    st.rerun()
def display_chat_history():
    """Render the stored conversation as HTML chat bubbles.

    Messages are read from the current user's ``messages`` subcollection in
    ``timestamp`` order. Messages whose role is "user" are shown as "You";
    every other role is shown as "LORAIN" with the assistant avatar. All
    bubbles are emitted in a single ``chat-history-wrapper`` div.

    Message text is HTML-escaped before being interpolated: the markup is
    rendered with ``unsafe_allow_html=True``, so unescaped user or model text
    containing ``<``, ``>`` or ``&`` would be parsed as markup (breaking the
    layout or injecting arbitrary HTML).
    """
    # Function-scope import so this fix does not depend on the file's import block.
    from html import escape

    messages = db.collection("users").document(user_id).collection("messages").order_by("timestamp").stream()
    assistant_icon_html = "<img src='https://raw.githubusercontent.com/AndrewLORTech/lortechwebsite/main/lorain.jpg' width='22' style='vertical-align:middle; border-radius:50%;'/>"
    chat_msgs = []
    for msg in messages:
        data = msg.to_dict()
        # str() mirrors the implicit conversion the f-string used to perform.
        safe_content = escape(str(data["content"]))
        if data["role"] == "user":
            chat_msgs.append(
                f"<div class='stChatMessage' data-testid='stChatMessage-user'>👤 <strong>You:</strong> {safe_content}</div>"
            )
        else:
            chat_msgs.append(
                f"<div class='stChatMessage' data-testid='stChatMessage-assistant'>{assistant_icon_html} <strong>LORAIN:</strong> {safe_content}</div>"
            )
    st.markdown('<div class="chat-history-wrapper">' + "".join(chat_msgs) + '</div>', unsafe_allow_html=True)
# --- Edge TTS synth ---
def sanitize_tts_text(text):
    """Prepare *text* for speech synthesis.

    Removes every character that is not a word character, whitespace, or one
    of ``. , ! ? : ; ' "``, then spells out the ``.co.za`` domain suffix.

    The spelled-out suffix starts with a space so that "dot" is not fused to
    the preceding label (``example.co.za`` -> ``example dot coza`` rather
    than ``exampledot coza``).

    Args:
        text: Raw message text.

    Returns:
        The cleaned string.
    """
    cleaned = re.sub(r'[^\w\s\.\,\!\?\:\;\'\"]', '', text)  # keep basic punctuation
    return cleaned.replace('.co.za', ' dot coza')
async def edge_tts_synthesize(text, voice, user_id):
    """Synthesize *text* with Edge TTS and save it as ``output_<user_id>.mp3``.

    Args:
        text: Text to speak.
        voice: Voice identifier passed to ``edge_tts.Communicate``.
        user_id: Used only to build the output file name.

    Returns:
        The relative path of the written MP3 file.
    """
    target_path = f"output_{user_id}.mp3"
    await edge_tts.Communicate(text, voice).save(target_path)
    return target_path
def synthesize_voice(text, user_id):
    """Return the path of an MP3 that speaks *text*, generating it if needed.

    The text is sanitized first. Synthesis is skipped when the sanitized text
    equals the one recorded in ``st.session_state["last_tts_text"]`` and the
    user's output file already exists. On any synthesis error a Streamlit
    warning is shown and ``None`` is returned.

    Args:
        text: Raw text to speak.
        user_id: Determines the output file name ``output_<user_id>.mp3``.

    Returns:
        The audio file path, or ``None`` when synthesis failed.
    """
    sanitized = sanitize_tts_text(text)
    out_path = f"output_{user_id}.mp3"

    # Same text as last time and the file is still on disk: reuse it.
    if st.session_state["last_tts_text"] == sanitized and os.path.exists(out_path):
        return out_path

    try:
        with st.spinner(f"Generating voice ({FIXED_VOICE_NAME})..."):
            asyncio.run(edge_tts_synthesize(sanitized, FIXED_VOICE, user_id))
        st.session_state["last_tts_text"] = sanitized
        st.session_state["last_audio_path"] = out_path
    except Exception as e:
        st.warning(f"TTS Error: {e}")
        return None
    return out_path
# --- CHAT: display history ---
display_chat_history()

# --- LORAIN is thinking indicator ---
# Shown only while the "is_thinking" session flag is set.
_THINKING_HTML = """
<div style="
text-align:center; color:#ddd; font-size: 14px;
margin-top: -1em; margin-bottom: 0.5em;">
🤖 <em>LORAIN is thinking...</em>
</div>
"""
if st.session_state.get("is_thinking", False):
    st.markdown(_THINKING_HTML, unsafe_allow_html=True)
# --- INPUT BAR (floating at bottom) ---
# A 10:1 column split: chat input on the left, clear-history button on the right.
st.markdown('<div class="chat-input-bar">', unsafe_allow_html=True)
input_col, clear_col = st.columns([10, 1])
user_input = input_col.chat_input("Type your message here...")
clear_clicked = clear_col.button("🗑️", help="Clear Chat", key="clear-chat-bottom")
if clear_clicked:
    clear_chat_history()
st.markdown('</div>', unsafe_allow_html=True)
# --- PROCESS USER INPUT ---
# Sends the new message to the assistant thread, waits for the run to finish,
# stores and speaks the reply, then reruns the script so the history refreshes.
if user_input:
    st.session_state["is_thinking"] = True  # start thinking
    thread_id = get_or_create_thread_id()
    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
    save_message("user", user_input)

    # Run states from which "completed" can no longer be reached here.
    # "requires_action" is included because this app never submits tool
    # outputs, so such a run would only sit there until it expires.
    # Polling only for "completed" would otherwise loop forever on these.
    unrecoverable_statuses = {"failed", "cancelled", "expired", "incomplete", "requires_action"}

    assistant_message = None
    with st.spinner("Thinking and typing... 💭"):
        run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id)
        while True:
            run_status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
            if run_status.status == "completed" or run_status.status in unrecoverable_statuses:
                break
            time.sleep(1)

        if run_status.status == "completed":
            messages_response = client.beta.threads.messages.list(thread_id=thread_id)
            # Newest message by creation time is taken as the assistant's reply.
            latest_response = sorted(messages_response.data, key=lambda x: x.created_at)[-1]
            assistant_message = latest_response.content[0].text.value
            save_message("assistant", assistant_message)

    if assistant_message is None:
        st.session_state["is_thinking"] = False  # stop thinking
        # No rerun on this path, so the error stays visible to the user.
        st.error(
            f"LORAIN could not finish this reply (run status: {run_status.status}). "
            "Please try again."
        )
    else:
        audio_path = synthesize_voice(assistant_message, user_id)
        if audio_path and os.path.exists(audio_path):
            st.audio(audio_path, format="audio/mp3", autoplay=True)
        st.session_state["is_thinking"] = False  # stop thinking
        # NOTE(review): the rerun follows the audio element almost immediately;
        # confirm playback survives the rerun.
        time.sleep(0.2)
        st.rerun()
# --- Auto-scroll JS ---
# Script meant to scroll the chat history wrapper to its bottom, once on
# window load and once more after 300 ms.
# NOTE(review): confirm the browser actually executes a <script> tag injected
# through st.markdown; if it does not, this block has no effect.
_AUTOSCROLL_JS = """
<script>
window.onload = function() {
var chatWrapper = document.querySelector('.chat-history-wrapper');
if(chatWrapper){ chatWrapper.scrollTop = chatWrapper.scrollHeight; }
};
setTimeout(function(){
var chatWrapper = document.querySelector('.chat-history-wrapper');
if(chatWrapper){ chatWrapper.scrollTop = chatWrapper.scrollHeight; }
}, 300);
</script>
"""
st.markdown(_AUTOSCROLL_JS, unsafe_allow_html=True)