# C2 Group AI Assistant — Streamlit app (Hugging Face Space).
import asyncio
import html
import os
import re
import time
import uuid

import edge_tts
import firebase_admin
import streamlit as st
from firebase_admin import credentials, firestore
from openai import OpenAI
# ---- Firebase setup ----
# Streamlit re-executes this script on every interaction, so guard on the
# SDK's private app registry to avoid "app already exists" errors on rerun.
if not firebase_admin._apps:
    cred = credentials.Certificate("firebase-service-account.json")
    firebase_admin.initialize_app(cred)
db = firestore.client()  # shared Firestore client for the whole script
# ---- OpenAI setup ----
# NOTE(review): os.getenv returns None when a variable is unset; the OpenAI
# client is still constructed and only fails on the first API call — confirm
# both env vars are always provided in the hosting environment.
openai_key = os.getenv("openai_key")
assistant_id = os.getenv("assistant_id")
client = OpenAI(api_key=openai_key)
# Human-readable voice labels (shown in the sidebar selectbox) mapped to
# Microsoft Edge TTS neural voice identifiers.
VOICE_OPTIONS = {
    "Jenny (US, Female)": "en-US-JennyNeural",
    "Aria (US, Female)": "en-US-AriaNeural",
    "Ryan (UK, Male)": "en-GB-RyanNeural",
    "Natasha (AU, Female)": "en-AU-NatashaNeural",
    "William (AU, Male)": "en-AU-WilliamNeural",
    "Libby (UK, Female)": "en-GB-LibbyNeural",
    "Leah (SA, Female)": "en-ZA-LeahNeural",
    "Luke (SA, Male)": "en-ZA-LukeNeural"
}
st.set_page_config(page_title="C2 Group AI Assistant", layout="wide")

# --- State setup
# Seed every session-state key before any widget reads it; setdefault only
# writes on the first run of a session, so existing values survive reruns.
st.session_state.setdefault("user_id", str(uuid.uuid4()))
user_id = st.session_state["user_id"]
st.session_state.setdefault("mute_voice", False)
st.session_state.setdefault("last_tts_text", "")
st.session_state.setdefault("last_audio_path", "")
st.session_state.setdefault("selected_voice", "Jenny (US, Female)")
# --- CSS ---
# Global style overrides: hide the default Streamlit header, style the chat
# bubbles, reserve space for (and pin) the bottom input bar, and style the
# top-right clear button. Injected raw, hence unsafe_allow_html.
st.markdown("""
<style>
.block-container {padding-top: 1rem;}
header {visibility: hidden;}
.logo-mini {
width: 75px !important;
margin: 0 auto 0.25em auto;
display: block;
}
.lor-brand-bar {
width: 100vw;
text-align: center;
background: none;
margin-bottom: 0.5em;
margin-top: 0.1em;
position: relative;
}
.clear-chat-btn-top {
position: absolute;
top: 10px;
right: 50px;
font-size: 1.4em;
color: #ccc;
background: none;
border: none;
cursor: pointer;
z-index: 1000;
transition: color 0.2s ease;
}
.clear-chat-btn-top:hover {
color: #fff;
}
.stChatMessage { max-width: 85%; border-radius: 12px; padding: 8px; margin-bottom: 10px; }
.stChatMessage[data-testid="stChatMessage-user"] { background: #f0f0f0; color: #000000; }
.stChatMessage[data-testid="stChatMessage-assistant"] { background: #e3f2fd; color: #000000; }
.chat-history-wrapper {
margin-top: 0.5em;
padding-bottom: 9em;
min-height: 60vh;
}
.input-bottom-bar {
position: fixed;
bottom: 3.5em;
width: 100%;
background: #191b22;
padding: 0.5em 0.6em;
border-top: 1px solid #22232c;
z-index: 999;
}
</style>
""", unsafe_allow_html=True)
# --- Top Branding + clear button ---
# The trash button navigates to ?clear=1, which the handler near the bottom
# of the file turns into clear_chat_history().
# NOTE(review): st.markdown scripts/links run inside Streamlit's sandboxed
# iframe — confirm the onclick navigation actually reaches the parent page.
st.markdown("""
<div class="lor-brand-bar">
<img src="https://i0.wp.com/c2group.co.za/wp-content/uploads/2022/10/cropped-1_C2-Group-Technologies-Logo-2-1.png?w=1024&ssl=1" class="logo-mini" />
<div style="font-size: 13px; color: #888;">Powered by C2 Group</div>
<button class="clear-chat-btn-top" onclick="window.location.href='?clear=1'">ποΈ</button>
</div>
""", unsafe_allow_html=True)
# --- Sidebar: voice settings ---
with st.sidebar:
    st.markdown("### Voice Settings & Controls")

    # Voice picker, pre-selected to whatever the session last chose.
    voice_names = list(VOICE_OPTIONS.keys())
    choice = st.selectbox(
        "Select assistant voice",
        voice_names,
        index=voice_names.index(st.session_state["selected_voice"]),
    )
    st.session_state["selected_voice"] = choice

    audio_file = st.session_state.get("last_audio_path")
    muted = st.session_state.get("mute_voice", False)

    # Player for the most recent synthesis, plus a manual replay button.
    if audio_file and os.path.exists(audio_file):
        st.audio(audio_file, format="audio/mp3", autoplay=not muted)
        if st.button("π Replay Voice"):
            st.audio(audio_file, format="audio/mp3", autoplay=True)

    # One mute/unmute toggle: render whichever action applies right now.
    if muted:
        if st.button("π Unmute Voice"):
            st.session_state["mute_voice"] = False
            st.rerun()
    else:
        if st.button("π Mute Voice"):
            st.session_state["mute_voice"] = True
            st.rerun()
# --- Firestore helpers ---
def get_or_create_thread_id():
    """Return the OpenAI thread id for this session, creating it on first use.

    The user_id -> thread_id mapping is persisted in Firestore so the same
    browser session keeps one conversation thread across script reruns.
    """
    doc_ref = db.collection("users").document(user_id)
    snapshot = doc_ref.get()
    if not snapshot.exists:
        # First visit: open a fresh Assistants thread and remember it.
        new_thread = client.beta.threads.create()
        doc_ref.set({"thread_id": new_thread.id, "created_at": firestore.SERVER_TIMESTAMP})
        return new_thread.id
    return snapshot.to_dict()["thread_id"]
def save_message(role, content):
    """Append one chat message to this user's Firestore message log."""
    record = {
        "role": role,
        "content": content,
        "timestamp": firestore.SERVER_TIMESTAMP,
    }
    db.collection("users").document(user_id).collection("messages").add(record)
def clear_chat_history():
    """Delete this user's stored messages and root doc, then restart the session."""
    user_doc = db.collection("users").document(user_id)
    # Firestore does not cascade deletes: remove each message doc explicitly.
    for snapshot in user_doc.collection("messages").stream():
        snapshot.reference.delete()
    user_doc.delete()
    st.session_state.clear()  # drops user_id too, so a new session begins
    st.rerun()
def display_chat_history():
    """Render the stored conversation as styled HTML chat bubbles.

    Messages are fetched in timestamp order and rendered reversed
    (newest first), matching the top-anchor auto-scroll script below.
    """
    messages = db.collection("users").document(user_id).collection("messages").order_by("timestamp").stream()
    assistant_icon_html = "<img src='https://img.freepik.com/free-vector/graident-ai-robot-vectorart_78370-4114.jpg?semt=ais_hybrid&w=740' width='22' style='vertical-align:middle; border-radius:50%;'/>"
    chat_msgs = []
    for msg in list(messages)[::-1]:
        data = msg.to_dict()
        # Escape stored content before embedding it in raw HTML: without this,
        # any markup typed by the user (or emitted by the model) is injected
        # straight into the page (XSS / layout breakage).
        content = html.escape(data["content"])
        if data["role"] == "user":
            chat_msgs.append(
                f"<div class='stChatMessage' data-testid='stChatMessage-user'>π€ <strong>You:</strong> {content}</div>"
            )
        else:
            chat_msgs.append(
                f"<div class='stChatMessage' data-testid='stChatMessage-assistant'>{assistant_icon_html} <strong>C2 Assistant:</strong> {content}</div>"
            )
    st.markdown('<div class="chat-history-wrapper">' + "".join(chat_msgs) + '</div>', unsafe_allow_html=True)
    st.markdown('<div id="chat-top-anchor"></div>', unsafe_allow_html=True)
# --- TTS sanitization ---
def sanitize_for_tts(text):
    """Flatten markdown-ish assistant output into plain speakable text.

    Strips markup (links, emphasis, headings, list markers), collapses
    noisy punctuation, and normalizes whitespace so TTS reads cleanly.
    """
    text = html.unescape(text)
    # Ordered (pattern, replacement, flags) pipeline; order matters — e.g.
    # non-ASCII is dropped first, and the bullet glyph is inserted afterwards
    # so it survives that pass.
    steps = (
        (r'[^\x00-\x7F]+', ' ', 0),                # drop non-ASCII
        (r'\[([^\]]+)\]\([^\)]+\)', r'\1', 0),     # [label](url) -> label
        (r'(\*\*|__)(.*?)\1', r'\2', 0),           # **bold** / __bold__
        (r'(\*|_)(.*?)\1', r'\2', 0),              # *italic* / _italic_
        (r'^#{1,6}\s+', '', re.MULTILINE),         # heading markers
        (r'^\s*[-*+]\s+', ' β’ ', re.MULTILINE),   # bulleted list markers
        (r'^\s*\d+\.\s+', ' β’ ', re.MULTILINE),   # numbered list markers
        (r'[!?]{2,}', '.', 0),                     # repeated ! / ?
        (r'\.{3,}', '.', 0),                       # long ellipses
        (r'\n{2,}', '. ', 0),                      # paragraph break -> pause
    )
    for pattern, replacement, flags in steps:
        text = re.sub(pattern, replacement, text, flags=flags)
    return re.sub(r'\s+', ' ', text).strip()
# --- Edge TTS synth ---
async def edge_tts_synthesize(text, voice, user_id):
    """Synthesize `text` with the given Edge TTS voice into a per-user mp3.

    Returns the output path; one file per user, overwritten on each call.
    """
    target = f"output_{user_id}.mp3"
    await edge_tts.Communicate(text, voice).save(target)
    return target
def synthesize_voice(text, voice_key, user_id):
    """Return the mp3 path for `text`, regenerating audio only when needed.

    Synthesis is skipped when the text and voice both match the previous
    call and the cached file still exists on disk.
    """
    voice = VOICE_OPTIONS[voice_key]
    out_path = f"output_{user_id}.mp3"
    same_text = st.session_state["last_tts_text"] == text
    same_voice = st.session_state.get("last_voice") == voice
    cache_hit = same_text and same_voice and os.path.exists(out_path)
    if not cache_hit:
        with st.spinner(f"Generating voice ({voice_key})..."):
            asyncio.run(edge_tts_synthesize(text, voice, user_id))
        # Remember what was synthesized so the next call can reuse the file.
        st.session_state["last_tts_text"] = text
        st.session_state["last_audio_path"] = out_path
        st.session_state["last_voice"] = voice
    return out_path
# --- CHAT DISPLAY ---
display_chat_history()

# --- Bottom chat input ---
# The wrapper div picks up the .input-bottom-bar CSS above, which pins the
# input near the bottom of the viewport.
with st.container():
    st.markdown('<div class="input-bottom-bar">', unsafe_allow_html=True)
    user_input = st.chat_input("Type your message here...")
    st.markdown('</div>', unsafe_allow_html=True)
# --- JS auto-scroll ---
# Scroll the chat anchor into view on load (newest messages render at the
# top — see display_chat_history's reversal).
# NOTE(review): Streamlit runs st.markdown scripts inside a sandboxed
# iframe, so this may not reach the parent document — confirm it fires.
st.markdown("""
<script>
window.onload = function() {
var anchor = document.getElementById("chat-top-anchor");
if(anchor){ anchor.scrollIntoView({ behavior: "smooth", block: "start" }); }
};
</script>
""", unsafe_allow_html=True)
# --- Handle clear button ---
if st.query_params.get("clear") == "1":
    # Drop the query param BEFORE wiping: clear_chat_history() ends in
    # st.rerun(), and a surviving ?clear=1 would re-trigger the wipe on
    # every subsequent run (endless clear/rerun loop).
    st.query_params.clear()
    clear_chat_history()
# --- Handle user input ---
if user_input:
    thread_id = get_or_create_thread_id()
    client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_input)
    save_message("user", user_input)

    with st.spinner("Thinking and typing... π"):
        run = client.beta.threads.runs.create(thread_id=thread_id, assistant_id=assistant_id)
        # Poll until the run reaches ANY terminal state; waiting only for
        # "completed" spins forever when a run fails, is cancelled, or expires.
        while True:
            run_status = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
            if run_status.status in ("completed", "failed", "cancelled", "expired"):
                break
            time.sleep(1)

    if run_status.status == "completed":
        messages_response = client.beta.threads.messages.list(thread_id=thread_id)
        # Take the newest *assistant* message explicitly; the newest message
        # overall could be the user's own if the run produced no reply.
        assistant_msgs = [m for m in messages_response.data if m.role == "assistant"]
        latest_response = sorted(assistant_msgs, key=lambda m: m.created_at)[-1]
        assistant_message = latest_response.content[0].text.value
    else:
        assistant_message = "Sorry, I couldn't generate a response. Please try again."
    save_message("assistant", assistant_message)

    mute_voice = st.session_state.get("mute_voice", False)
    audio_path = None
    if not mute_voice and assistant_message.strip():
        clean_text = sanitize_for_tts(assistant_message)
        audio_path = synthesize_voice(clean_text, st.session_state["selected_voice"], user_id)
    # NOTE(review): when muted this stores None, discarding the previous
    # audio path (no replay of stale audio) — appears intentional; confirm.
    st.session_state["last_audio_path"] = audio_path

    time.sleep(0.2)  # brief pause so the spinner state settles before rerun
    st.rerun()