import streamlit as st
import httpx
import os
import io
from dotenv import load_dotenv
import logging
import asyncio
from streamlit_mic_recorder import mic_recorder


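# Console logging for UI events and orchestrator calls.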
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
load_dotenv()

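# Base URL of the orchestrator service, read from the environment / .env file,
# e.g. ORCHESTRATOR_URL=http://localhost:8000 (example value, adjust to your deployment).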
ORCHESTRATOR_URL = os.getenv("ORCHESTRATOR_URL")


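# Initialise every session-state key up front so later reruns can rely on it existing.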
if "processing_state" not in st.session_state: |
|
st.session_state.processing_state = "initial" |
|
if "orchestrator_response" not in st.session_state: |
|
st.session_state.orchestrator_response = None |
|
if "audio_bytes_input" not in st.session_state: |
|
st.session_state.audio_bytes_input = None |
|
if "audio_filename" not in st.session_state: |
|
st.session_state.audio_filename = None |
|
if "audio_filetype" not in st.session_state: |
|
st.session_state.audio_filetype = None |
|
if "last_audio_source" not in st.session_state: |
|
st.session_state.last_audio_source = None |
|
if "current_recording_id" not in st.session_state: |
|
st.session_state.current_recording_id = None |
|
|
|
|
|
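# Posts the recorded/uploaded audio to the orchestrator's /market_brief endpoint.
# Network errors are caught explicitly; non-2xx responses raised by raise_for_status()
# fall through to the generic handler. Both paths return the same error-shaped dict.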
async def call_orchestrator(audio_bytes: bytes, filename: str, content_type: str):
    url = f"{ORCHESTRATOR_URL}/market_brief"
    files = {"audio": (filename, audio_bytes, content_type)}
    logger.info(
        f"Calling orchestrator at {url} with audio file: {filename} ({content_type})"
    )
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(url, files=files, timeout=180.0)
            response.raise_for_status()
            logger.info(f"Orchestrator returned status {response.status_code}.")
            return response.json()
    except httpx.RequestError as e:
        error_msg = f"HTTP Request failed: {e}"
        logger.error(error_msg)
        return {
            "status": "error",
            "message": "Error communicating with orchestrator.",
            "errors": [error_msg],
            "transcript": None,
            "brief": None,
            "audio": None,
        }
    except Exception as e:
        error_msg = f"An unexpected error occurred: {e}"
        logger.error(error_msg)
        return {
            "status": "error",
            "message": "An unexpected error occurred.",
            "errors": [error_msg],
            "transcript": None,
            "brief": None,
            "audio": None,
        }


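# Page chrome (wide layout, title, usage hint) followed by the input-method selector.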
st.set_page_config(layout="wide") |
|
st.title("📈 AI Financial Assistant - Morning Market Brief") |
|
st.markdown( |
|
"Ask your query verbally (e.g., 'What's our risk exposure in Asia tech stocks today, and highlight any earnings surprises?') " |
|
"or upload an audio file." |
|
) |
|
|
|
input_method = st.radio( |
|
"Choose input method:", |
|
("Record Audio", "Upload File"), |
|
horizontal=True, |
|
index=0, |
|
key="input_method_radio", |
|
) |
|
|
|
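# Track whether there is audio ready to submit; bytes kept in session state from an
# earlier rerun still count as ready input.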
audio_data_ready = False

if st.session_state.audio_bytes_input is not None:
    audio_data_ready = True


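# "Record Audio": capture the query with the browser microphone via streamlit_mic_recorder.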
if input_method == "Record Audio": |
|
st.subheader("Record Your Query") |
|
|
|
if st.session_state.last_audio_source == "uploader": |
|
st.session_state.audio_bytes_input = None |
|
st.session_state.audio_filename = None |
|
st.session_state.audio_filetype = None |
|
st.session_state.last_audio_source = "recorder" |
|
audio_data_ready = False |
|
|
|
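    # The recorder widget returns a dict with the audio bytes and a recording id,
    # or None until a recording has been made.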
    audio_info = mic_recorder(
        start_prompt="⏺️ Start Recording",
        stop_prompt="⏹️ Stop Recording",
        just_once=False,
        use_container_width=True,
        format="wav",
        key="mic_recorder_widget",
    )

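    # A new recording id means fresh audio: store it in session state and reset any
    # previous results; the same id on a later rerun just replays the stored audio.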
if audio_info and audio_info.get("bytes"): |
|
|
|
if st.session_state.current_recording_id != audio_info.get("id"): |
|
st.session_state.current_recording_id = audio_info.get("id") |
|
st.success("Recording complete! Click 'Generate Market Brief' below.") |
|
st.session_state.audio_bytes_input = audio_info["bytes"] |
|
st.session_state.audio_filename = f"live_recording_{audio_info['id']}.wav" |
|
st.session_state.audio_filetype = "audio/wav" |
|
st.session_state.last_audio_source = "recorder" |
|
audio_data_ready = True |
|
st.session_state.processing_state = "initial" |
|
st.session_state.orchestrator_response = None |
|
st.audio(audio_info["bytes"]) |
|
|
|
elif st.session_state.audio_bytes_input: |
|
audio_data_ready = True |
|
st.audio(st.session_state.audio_bytes_input) |
|
|
|
elif ( |
|
st.session_state.last_audio_source == "recorder" |
|
and st.session_state.audio_bytes_input |
|
): |
|
st.markdown("Using last recording:") |
|
st.audio(st.session_state.audio_bytes_input) |
|
audio_data_ready = True |
|
|
|
|
|
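# "Upload File": accept a pre-recorded audio file instead of a live recording.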
elif input_method == "Upload File": |
|
st.subheader("Upload Audio File") |
|
|
|
if st.session_state.last_audio_source == "recorder": |
|
st.session_state.audio_bytes_input = None |
|
st.session_state.audio_filename = None |
|
st.session_state.audio_filetype = None |
|
st.session_state.last_audio_source = "uploader" |
|
st.session_state.current_recording_id = None |
|
audio_data_ready = False |
|
|
|
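    # Remember the last uploaded file object so a rerun does not re-process the same upload.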
if "uploaded_file_state" not in st.session_state: |
|
st.session_state.uploaded_file_state = None |
|
|
|
uploaded_file = st.file_uploader( |
|
"Select Audio File", |
|
type=["wav", "mp3", "m4a", "ogg", "flac"], |
|
key="file_uploader_key", |
|
) |
|
|
|
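    # A new upload replaces any stored audio and clears previous results; the same
    # upload on a later rerun just replays the stored bytes.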
    if uploaded_file is not None:
        if st.session_state.uploaded_file_state != uploaded_file:
            st.session_state.uploaded_file_state = uploaded_file
            st.session_state.audio_bytes_input = uploaded_file.getvalue()
            st.session_state.audio_filename = uploaded_file.name
            st.session_state.audio_filetype = uploaded_file.type
            st.session_state.last_audio_source = "uploader"
            audio_data_ready = True
            st.session_state.processing_state = "initial"
            st.session_state.orchestrator_response = None
            st.success(f"File '{uploaded_file.name}' ready.")
            st.audio(
                st.session_state.audio_bytes_input,
                format=st.session_state.audio_filetype,
            )
        elif st.session_state.audio_bytes_input:
            audio_data_ready = True
            st.audio(
                st.session_state.audio_bytes_input,
                format=st.session_state.audio_filetype,
            )

    elif (
        st.session_state.last_audio_source == "uploader"
        and st.session_state.audio_bytes_input
    ):
        st.markdown("Using last uploaded file:")
        st.audio(
            st.session_state.audio_bytes_input, format=st.session_state.audio_filetype
        )
        audio_data_ready = True


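# Submission: the button stays disabled until audio is ready and while a request is in flight.
# Clicking it switches processing_state to "processing" and reruns; the next run makes the call.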
st.divider()
button_disabled = (
    not audio_data_ready or st.session_state.processing_state == "processing"
)

if st.button(
    "Generate Market Brief",
    disabled=button_disabled,
    type="primary",
    use_container_width=True,
    key="generate_button",
):
    if st.session_state.audio_bytes_input:
        st.session_state.processing_state = "processing"
        st.session_state.orchestrator_response = None
        logger.info(
            f"Generate Market Brief button clicked. Source: {st.session_state.last_audio_source}, Filename: {st.session_state.audio_filename}"
        )
        st.rerun()
    else:
        st.warning("Please record or upload an audio query first.")


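# Processing pass: send the stored audio to the orchestrator and record the outcome.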
if st.session_state.processing_state == "processing": |
|
if ( |
|
st.session_state.audio_bytes_input |
|
and st.session_state.audio_filename |
|
and st.session_state.audio_filetype |
|
): |
|
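        # The async HTTP call is driven from Streamlit's synchronous script via asyncio.run().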
with st.spinner("Processing your request... This may take a moment. 🤖"): |
|
|
|
logger.info( |
|
f"Calling orchestrator with filename: {st.session_state.audio_filename}, type: {st.session_state.audio_filetype}, bytes: {len(st.session_state.audio_bytes_input)}" |
|
) |
|
try: |
|
response = asyncio.run( |
|
call_orchestrator( |
|
st.session_state.audio_bytes_input, |
|
st.session_state.audio_filename, |
|
st.session_state.audio_filetype, |
|
) |
|
) |
|
st.session_state.orchestrator_response = response |
|
|
|
is_successful_response = True |
|
if not response: |
|
is_successful_response = False |
|
elif ( |
|
response.get("status") == "error" |
|
or response.get("status") == "failed" |
|
): |
|
is_successful_response = False |
|
elif response.get("errors") and len(response.get("errors")) > 0: |
|
is_successful_response = False |
|
|
|
st.session_state.processing_state = ( |
|
"completed" if is_successful_response else "error" |
|
) |
|
|
|
except Exception as e: |
|
logger.error( |
|
f"Error during orchestrator call in Streamlit: {e}", exc_info=True |
|
) |
|
st.session_state.orchestrator_response = { |
|
"status": "error", |
|
"message": f"Streamlit failed to call orchestrator: {str(e)}", |
|
"errors": [str(e)], |
|
"transcript": None, |
|
"brief": None, |
|
"audio": None, |
|
} |
|
st.session_state.processing_state = "error" |
|
st.rerun() |
|
else: |
|
st.error("Audio data is missing for processing. Please record or upload again.") |
|
st.session_state.processing_state = "initial" |
|
|
|
|
|
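# Results pass: render whatever the orchestrator returned (transcript, brief, audio, errors).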
if st.session_state.processing_state in ["completed", "error"]: |
|
|
|
response = st.session_state.orchestrator_response |
|
st.subheader("📝 Results") |
|
|
|
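    # Three outcomes: no response at all, a failed/errored workflow, or a successful brief.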
    if response is None:
        st.error("No response received from the orchestrator.")

    elif (
        response.get("status") == "failed"
        or response.get("status") == "error"
        or (response.get("errors") and len(response.get("errors")) > 0)
    ):
        st.error(
            f"Workflow {response.get('status', 'failed')}: {response.get('message', 'Check errors below.')}"
        )
        if response.get("errors"):
            st.warning("Details of Errors:")
            for i, err in enumerate(response["errors"]):
                st.markdown(f"`Error {i+1}`: {err}")
        if response.get("warnings"):
            st.warning("Details of Warnings:")
            for i, warn in enumerate(response["warnings"]):
                st.markdown(f"`Warning {i+1}`: {warn}")

        if response.get("transcript"):
            st.markdown("---")
            st.markdown("Transcript (despite errors):")
            st.caption(response.get("transcript"))
        if response.get("brief"):
            st.markdown("---")
            st.markdown("Generated Brief (despite errors):")
            st.caption(response.get("brief"))
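    # Successful run: show the transcript, the generated brief, and the spoken brief if present.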
    else:
        st.success(response.get("message", "Market brief generated successfully!"))
        if response.get("transcript"):
            st.markdown("---")
            st.markdown("Your Query (Transcript):")
            st.caption(response.get("transcript"))
        else:
            st.info("Transcript not available.")

        if response.get("brief"):
            st.markdown("---")
            st.markdown("Generated Brief:")
            st.write(response.get("brief"))
        else:
            st.info("Brief text not available.")

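        # The audio brief arrives as a hex string; validate it, decode to bytes, and play it as MPEG audio.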
audio_hex = response.get("audio") |
|
if audio_hex: |
|
st.markdown("---") |
|
st.markdown("Audio Brief:") |
|
try: |
|
if not isinstance(audio_hex, str) or not all( |
|
c in "0123456789abcdefABCDEF" for c in audio_hex |
|
): |
|
raise ValueError("Invalid hex string for audio.") |
|
audio_bytes_output = bytes.fromhex(audio_hex) |
|
st.audio(audio_bytes_output, format="audio/mpeg") |
|
except ValueError as ve: |
|
st.error(f"⚠️ Failed to decode audio data: {ve}") |
|
except Exception as e: |
|
st.error(f"⚠️ Failed to play audio: {e}") |
|
else: |
|
st.info("Audio brief not available.") |
|
|
|
if response.get("warnings"): |
|
st.markdown("---") |
|
st.warning("Process Warnings:") |
|
for i, warn in enumerate(response["warnings"]): |
|
st.markdown(f"`Warning {i+1}`: {warn}") |
|
|