Spaces:
Running
Running
import streamlit as st | |
import asyncio | |
import edge_tts | |
import time | |
import os | |
import re | |
import uuid | |
import firebase_admin | |
from firebase_admin import credentials, firestore | |
from openai import OpenAI | |
# ---- Firebase setup ----
# Initialise the default Firebase app only when none is registered yet, then
# obtain the Firestore client used by the helpers below.
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase-service-account.json")
    firebase_admin.initialize_app(cred)

db = firestore.client()
# ---- OpenAI setup ----
# The API key and the assistant id are read from environment variables.
openai_key = os.environ.get("openai_key")
assistant_id = os.environ.get("assistant_id")
client = OpenAI(api_key=openai_key)

# ---- Voice Settings ----
FIXED_VOICE_NAME = "Jenny (US, Female)"  # label shown in the TTS spinner
FIXED_VOICE = "en-US-JennyNeural"  # voice identifier handed to edge-tts
# --- State setup
# Each browser session gets a random id that keys its Firestore documents.
if "user_id" not in st.session_state:
    st.session_state["user_id"] = str(uuid.uuid4())
user_id = st.session_state["user_id"]

# Seed the remaining session keys with their defaults when they are missing.
for _state_key, _state_default in (
    ("last_tts_text", ""),
    ("last_audio_path", ""),
    ("is_thinking", False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Page config ---
st.set_page_config(page_title="LOR Technologies AI Assistant", layout="wide")

# --- CSS Styling ---
# Stylesheet for the branding bar, chat bubbles, scrollable history pane and
# the bottom input bar; injected once as raw HTML.
_PAGE_CSS = """
<style>
.block-container {padding-top: 1rem;}
header {visibility: hidden;}
.logo-mini {
width: 75px !important;
margin: 0 auto 0.25em auto;
display: block;
}
.lor-brand-bar {
width: 100vw; text-align: center; background: none;
margin-bottom: 0.5em; margin-top: 0.1em;
}
.stChatMessage { max-width: 85%; border-radius: 12px; padding: 8px; margin-bottom: 10px; }
.stChatMessage[data-testid="stChatMessage-user"] { background: #f0f0f0; color: #000000; }
.stChatMessage[data-testid="stChatMessage-assistant"] { background: #e3f2fd; color: #000000; }
.chat-history-wrapper {
margin-top: 0.5em; margin-bottom: 5em; height: 65vh; overflow-y: auto; padding: 0 0.5em;
}
.chat-input-bar {
position: fixed; bottom: 0; width: 100%; z-index: 100;
background: #191b22; padding: 0.6em 1em; border-top: 1px solid #22232c;
display: flex; align-items: center; gap: 0.5em;
}
.chat-input-bar input { width: 100%; font-size: 1.1em; }
.clear-chat-btn { background: none; border: none; font-size: 1.4em; color: #999; cursor: pointer; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# --- Top Branding ---
# Centered logo with a "Powered by" caption at the top of the page.
_BRAND_BAR_HTML = """
<div class="lor-brand-bar">
<img src="https://lortechnologies.com/wp-content/uploads/2023/03/LOR-Online-Logo.svg" class="logo-mini" />
<div style="font-size: 13px; color: #888;">Powered by LOR Technologies</div>
</div>
"""
st.markdown(_BRAND_BAR_HTML, unsafe_allow_html=True)
# --- Firestore helpers --- | |
def get_or_create_thread_id():
    """Return the OpenAI thread id stored for the current user.

    Looks up the ``users/<user_id>`` document in Firestore. When it does not
    exist yet, a new thread is created through the OpenAI client and its id is
    stored together with a server-side creation timestamp.
    """
    user_doc = db.collection("users").document(user_id)
    snapshot = user_doc.get()
    if not snapshot.exists:
        new_thread = client.beta.threads.create()
        user_doc.set({"thread_id": new_thread.id, "created_at": firestore.SERVER_TIMESTAMP})
        return new_thread.id
    return snapshot.to_dict()["thread_id"]
def save_message(role, content):
    """Append one chat message to the current user's ``messages`` subcollection.

    Args:
        role: Speaker label stored with the message ("user" or "assistant"
            in this file's callers).
        content: Message text to store.
    """
    record = {
        "role": role,
        "content": content,
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    messages_ref = db.collection("users").document(user_id).collection("messages")
    messages_ref.add(record)
def clear_chat_history():
    """Delete the current user's stored chat and restart the session.

    Removes every document in the user's ``messages`` subcollection, then the
    user document itself, clears the Streamlit session state and triggers a
    rerun of the script.
    """
    user_ref = db.collection("users").document(user_id)
    stored_messages = user_ref.collection("messages").stream()
    for snapshot in stored_messages:
        snapshot.reference.delete()
    user_ref.delete()
    st.session_state.clear()
    st.rerun()
def _format_chat_message(role, content):
    """Return the HTML bubble for one stored message, with its text HTML-escaped."""
    import html

    # The bubble is rendered with unsafe_allow_html=True, so the stored text
    # must be escaped or it would be interpreted as markup.
    safe_text = html.escape(str(content), quote=False)
    if role == "user":
        return (
            "<div class='stChatMessage' data-testid='stChatMessage-user'>"
            f"👤 <strong>You:</strong> {safe_text}</div>"
        )
    assistant_icon_html = "<img src='https://raw.githubusercontent.com/AndrewLORTech/lortechwebsite/main/lorain.jpg' width='22' style='vertical-align:middle; border-radius:50%;'/>"
    return (
        "<div class='stChatMessage' data-testid='stChatMessage-assistant'>"
        f"{assistant_icon_html} <strong>LORAIN:</strong> {safe_text}</div>"
    )


def display_chat_history():
    """Render the current user's stored messages inside the scrollable history pane.

    Messages are read from the user's ``messages`` subcollection ordered by
    ``timestamp``. Messages with role "user" get the user bubble; every other
    role gets the assistant bubble. Message text is HTML-escaped before it is
    embedded, because the wrapper is emitted as raw HTML.
    """
    messages = (
        db.collection("users")
        .document(user_id)
        .collection("messages")
        .order_by("timestamp")
        .stream()
    )
    chat_msgs = []
    for msg in messages:
        data = msg.to_dict()
        chat_msgs.append(_format_chat_message(data["role"], data["content"]))
    st.markdown(
        '<div class="chat-history-wrapper">' + "".join(chat_msgs) + '</div>',
        unsafe_allow_html=True,
    )
# --- Edge TTS synth --- | |
def sanitize_tts_text(text):
    """Return *text* prepared for speech synthesis.

    Every character that is not a word character, whitespace or one of
    ``. , ! ? : ; ' "`` is removed, then each ``.co.za`` is replaced with
    ``dot coza``.
    """
    # keep basic punctuation
    stripped = re.sub(r'[^\w\s\.\,\!\?\:\;\'\"]', '', text)
    return stripped.replace('.co.za', 'dot coza')
async def edge_tts_synthesize(text, voice, user_id):
    """Synthesize *text* with edge-tts and save it as ``output_<user_id>.mp3``.

    Args:
        text: Text to speak.
        voice: Voice identifier passed to ``edge_tts.Communicate``.
        user_id: Used to build the output file name.

    Returns:
        The path of the written MP3 file.
    """
    target_file = f"output_{user_id}.mp3"
    await edge_tts.Communicate(text, voice).save(target_file)
    return target_file
def synthesize_voice(text, user_id):
    """Return the path of an MP3 speaking *text*, generating it when needed.

    The text is sanitized first. Synthesis is skipped when the sanitized text
    equals the one remembered in the session state and the user's output file
    already exists. On any synthesis error a warning is shown and ``None`` is
    returned.
    """
    spoken_text = sanitize_tts_text(text)
    mp3_path = f"output_{user_id}.mp3"

    already_cached = (
        st.session_state["last_tts_text"] == spoken_text and os.path.exists(mp3_path)
    )
    if already_cached:
        return mp3_path

    try:
        with st.spinner(f"Generating voice ({FIXED_VOICE_NAME})..."):
            asyncio.run(edge_tts_synthesize(spoken_text, FIXED_VOICE, user_id))
        # Remember what was synthesized so an identical request can reuse the file.
        st.session_state["last_tts_text"] = spoken_text
        st.session_state["last_audio_path"] = mp3_path
    except Exception as e:
        st.warning(f"TTS Error: {e}")
        return None
    return mp3_path
# --- CHAT: display history ---
display_chat_history()

# --- LORAIN is thinking indicator ---
# Shown only while the session's "is_thinking" flag is set.
_THINKING_HTML = """
<div style="
text-align:center; color:#ddd; font-size: 14px;
margin-top: -1em; margin-bottom: 0.5em;">
🤖 <em>LORAIN is thinking...</em>
</div>
"""
if st.session_state.get("is_thinking", False):
    st.markdown(_THINKING_HTML, unsafe_allow_html=True)
# --- INPUT BAR (floating at bottom) ---
# A wide column holds the chat input, a narrow one the clear-chat button.
st.markdown('<div class="chat-input-bar">', unsafe_allow_html=True)
input_col, clear_col = st.columns([10, 1])
user_input = input_col.chat_input("Type your message here...")
clear_clicked = clear_col.button("🗑️", help="Clear Chat", key="clear-chat-bottom")
if clear_clicked:
    clear_chat_history()
st.markdown('</div>', unsafe_allow_html=True)
# --- PROCESS USER INPUT ---
# Sends the new message to the assistant thread, waits for the run to finish,
# stores and speaks the reply, then reruns so the history shows both messages.
if user_input:
    st.session_state["is_thinking"] = True  # start thinking
    thread_id = get_or_create_thread_id()
    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
    save_message("user", user_input)

    # Statuses from which waiting longer will never produce "completed".
    # Polling only for "completed" would spin forever on any of these.
    unrecoverable_statuses = {"failed", "cancelled", "expired", "incomplete", "requires_action"}
    max_wait_seconds = 120  # upper bound on polling so a stuck run cannot hang the app

    with st.spinner("Thinking and typing... 💭"):
        run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id)
        deadline = time.monotonic() + max_wait_seconds
        while True:
            run_status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
            if run_status.status == "completed" or run_status.status in unrecoverable_statuses:
                break
            if time.monotonic() >= deadline:
                break
            time.sleep(1)

    assistant_message = ""
    if run_status.status == "completed":
        messages_response = client.beta.threads.messages.list(thread_id=thread_id)
        # Only consider assistant messages produced by this run: created_at has
        # one-second resolution, so picking "the newest message" could return
        # the user's own message on a tie.
        assistant_replies = [
            m for m in messages_response.data
            if m.role == "assistant" and m.run_id == run.id
        ]
        if assistant_replies:
            latest_response = max(assistant_replies, key=lambda m: m.created_at)
            # A message may contain non-text parts; keep the text ones only.
            assistant_message = "".join(
                part.text.value for part in latest_response.content if part.type == "text"
            )
    elif run_status.status in {"queued", "in_progress", "requires_action"}:
        # The run is still active (timeout or pending tool call); cancel it so
        # it does not keep the thread busy for the next message.
        client.beta.threads.runs.cancel(thread_id=thread_id, run_id=run.id)

    st.session_state["is_thinking"] = False  # stop thinking

    if assistant_message:
        save_message("assistant", assistant_message)
        audio_path = synthesize_voice(assistant_message, user_id)
        if audio_path and os.path.exists(audio_path):
            # NOTE(review): the rerun below follows almost immediately and may
            # remove this player before playback finishes - confirm in the UI.
            st.audio(audio_path, format="audio/mp3", autoplay=True)
        time.sleep(0.2)
        st.rerun()
    else:
        # No rerun here, so the error stays visible to the user.
        st.error(
            f"LORAIN could not produce a reply (run status: {run_status.status}). "
            "Please try again."
        )
# --- Auto-scroll JS ---
# A <script> tag passed to st.markdown is inserted as inert markup and is not
# executed, so the snippet is rendered through an HTML component instead. The
# component runs inside its own iframe, which is why the chat pane is looked up
# through window.parent.document.
import streamlit.components.v1 as components

components.html(
    """
<script>
(function () {
  function scrollChatToBottom() {
    var chatWrapper = window.parent.document.querySelector('.chat-history-wrapper');
    if (chatWrapper) { chatWrapper.scrollTop = chatWrapper.scrollHeight; }
  }
  scrollChatToBottom();
  // Second pass in case the history pane finishes rendering slightly later.
  setTimeout(scrollChatToBottom, 300);
})();
</script>
""",
    height=0,
)