Spaces:
Sleeping
Sleeping
# app.py (Merged Version - Fixed Chat Input Clearing) | |
import streamlit as st | |
import asyncio | |
import websockets | |
import uuid | |
from datetime import datetime | |
import os | |
import random | |
import time | |
import hashlib | |
# from PIL import Image # Keep commented unless needed for image pasting->3D texture? | |
import glob | |
import base64 | |
import io | |
import streamlit.components.v1 as components | |
import edge_tts | |
# from audio_recorder_streamlit import audio_recorder # Keep commented unless re-adding audio input | |
import nest_asyncio | |
import re | |
import pytz | |
import shutil | |
# import anthropic # Keep commented unless integrating Claude | |
# import openai # Keep commented unless integrating OpenAI | |
from PyPDF2 import PdfReader | |
import threading | |
import json | |
import zipfile | |
# from gradio_client import Client # Keep commented unless integrating ArXiv/Gradio | |
from dotenv import load_dotenv | |
from streamlit_marquee import streamlit_marquee | |
from collections import defaultdict, Counter | |
import pandas as pd | |
from streamlit_js_eval import streamlit_js_eval # Still needed for some UI interactions | |
from PIL import Image # Needed for paste_image_component | |
# 🛠️ Patch asyncio for nesting | |
# Allow asyncio event loops to be re-entered (e.g. asyncio.run while a loop is already running).
nest_asyncio.apply()

# 🎨 Page Config (From New App): wide layout with the sidebar open on first load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_icon="🏗️",
    page_title="🤖🏗️ Shared World Builder 🏆",
)
# --- Constants (Combined & 3D Added) ---
# Chat/User Constants
icons = '🤖🏗️🗣️'  # Updated icons
Site_Name = '🤖🏗️ Shared World Builder 🗣️'
START_ROOM = "World Lobby 🌍"
# Display name -> Edge TTS voice used when that user speaks.
FUN_USERNAMES = {  # Simplified for clarity, can expand later
    "BuilderBot 🤖": "en-US-AriaNeural", "WorldWeaver 🕸️": "en-US-JennyNeural",
    "Terraformer 🌱": "en-GB-SoniaNeural", "SkyArchitect ☁️": "en-AU-NatashaNeural",
    "PixelPainter 🎨": "en-CA-ClaraNeural", "VoxelVortex 🌪️": "en-US-GuyNeural",
    "CosmicCrafter ✨": "en-GB-RyanNeural", "GeoGuru 🗺️": "en-AU-WilliamNeural",
    "BlockBard 🧱": "en-CA-LiamNeural", "SoundSculptor 🔊": "en-US-AnaNeural",
}
# Unique voices, sorted. list(set(...)) ordered strings by their (randomized
# per-process) hash, so any UI built from this list changed between runs.
EDGE_TTS_VOICES = sorted(set(FUN_USERNAMES.values()))
# File extension -> emoji shown next to download links.
FILE_EMOJIS = {"md": "📝", "mp3": "🎵", "png": "🖼️", "mp4": "🎥", "zip": "📦", "csv": "📄", "json": "📄"}
# 3D World Constants
SAVE_DIR = "saved_worlds"
PLOT_WIDTH = 50.0   # world-space width of one plot along x
PLOT_DEPTH = 50.0   # world-space depth of one plot along z
CSV_COLUMNS = ['obj_id', 'type', 'pos_x', 'pos_y', 'pos_z', 'rot_x', 'rot_y', 'rot_z', 'rot_order']
WORLD_STATE_FILE = "world_state.json"  # Using JSON for simpler in-memory<->disk state
# --- Directories (Combined) ---
CHAT_DIR = "chat_logs"
MEDIA_DIR = "."  # Where general files are saved/served from
AUDIO_CACHE_DIR = "audio_cache"
AUDIO_DIR = "audio_logs"
STATE_FILE = "user_state.txt"  # For remembering username
CHAT_FILE = os.path.join(CHAT_DIR, "global_chat.md")  # Used for initial load maybe?
# Create the working directories from the constants above so a renamed
# directory can never drift out of sync with what actually gets created.
for _dir in (CHAT_DIR, AUDIO_DIR, AUDIO_CACHE_DIR, SAVE_DIR):
    os.makedirs(_dir, exist_ok=True)
# Removed vote files for simplicity
# --- API Keys (Keep placeholder logic) ---
load_dotenv()
# anthropic_key = os.getenv('ANTHROPIC_API_KEY', st.secrets.get('ANTHROPIC_API_KEY', ""))
# openai_api_key = os.getenv('OPENAI_API_KEY', st.secrets.get('OPENAI_API_KEY', ""))
# openai_client = openai.OpenAI(api_key=openai_api_key)
# --- Helper Functions (Combined & Adapted) --- | |
def format_timestamp_prefix(username=""):
    """Return '<YYYYmmdd>_<HHMMSS>_<user>_<rand4>' for building file names.

    The timestamp is UTC, for consistent names across timezones.

    *username* is reduced to a filesystem-safe token:
      - runs of whitespace become '_';
      - everything except word characters and '-' is dropped, which removes
        path separators and '.';
      - surrounding underscores are trimmed;
      - the result is capped at 50 characters.

    Usernames can come straight from websocket clients, so a name like
    '../../x' must not be able to escape the target directory.

    The random 4-character suffix avoids collisions when several files are
    named within the same second.
    """
    from datetime import timezone  # stdlib UTC; no third-party tz needed here

    now = datetime.now(timezone.utc)
    safe_user = re.sub(r'[^\w\-]', '', re.sub(r'\s+', '_', str(username))).strip('_')[:50]
    rand_suffix = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=4))
    return f"{now.strftime('%Y%m%d_%H%M%S')}_{safe_user}_{rand_suffix}"
# --- Performance Timer (Optional, Keep if desired) --- | |
class PerformanceTimer:
    """Context manager that times its ``with`` body and records the result.

    On exit, the elapsed seconds are recorded in two places:
      - ``st.session_state['operation_timings'][name]`` holds the latest run;
      - ``st.session_state['performance_metrics'][name]`` accumulates history.

    Exceptions raised inside the body are not suppressed.
    """

    def __init__(self, name):
        self.name = name
        self.start = None

    def __enter__(self):
        # perf_counter is monotonic and high-resolution. time.time() can jump
        # (e.g. clock adjustments) and produce skewed or negative durations.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        duration = time.perf_counter() - self.start
        if 'operation_timings' not in st.session_state:
            st.session_state['operation_timings'] = {}
        if 'performance_metrics' not in st.session_state:
            st.session_state['performance_metrics'] = defaultdict(list)
        st.session_state['operation_timings'][self.name] = duration
        # setdefault also works if the metrics store is a plain dict rather
        # than a defaultdict.
        st.session_state['performance_metrics'].setdefault(self.name, []).append(duration)
# --- 3D World State Management (Adapted from original + WebSocket focus) ---
# In-memory world state: obj_id (str) -> object dict. Entries are filled from
# world_state.json / the plot CSVs and by clients' place_object messages.
world_objects = defaultdict(dict)
# Held around reads and writes of world_objects in this module. The websocket
# handlers run on a separate server thread (see start_websocket_server_thread),
# so access is not confined to the Streamlit script thread.
world_objects_lock = threading.Lock()
def load_world_state_from_disk():
    """Populate the global ``world_objects`` from disk and return a count.

    Loading order:
      1. WORLD_STATE_FILE (JSON) is tried first; keys are normalised to str.
      2. If the file is missing or unreadable, or world_objects is still
         empty afterwards, the state is rebuilt from the per-plot CSV files.
      3. The rebuilt state is immediately re-saved as JSON.

    Everything runs while holding ``world_objects_lock``.

    Returns the number of objects loaded by this call. It is 0 when nothing
    was loaded because world_objects was already populated.
    """
    global world_objects
    print(f"[{time.time():.2f}] Attempting to load world state...")
    count = 0
    with world_objects_lock:
        if os.path.exists(WORLD_STATE_FILE):
            try:
                with open(WORLD_STATE_FILE, 'r') as fh:
                    raw = json.load(fh)
                # Keys might have been persisted as ints; normalise to str.
                world_objects = defaultdict(dict, {str(key): value for key, value in raw.items()})
                count = len(world_objects)
                print(f"Loaded {count} objects from {WORLD_STATE_FILE}")
            except json.JSONDecodeError:
                print(f"Error reading {WORLD_STATE_FILE}. Falling back to CSVs.")
                world_objects = defaultdict(dict)
            except Exception as exc:
                print(f"Error loading from {WORLD_STATE_FILE}: {exc}. Falling back to CSVs.")
                world_objects = defaultdict(dict)

        if world_objects:
            return count

        print("Loading world state from CSV files...")
        for record in get_all_world_objects_from_csv():
            world_objects[record['obj_id']] = record
        count = len(world_objects)
        print(f"Loaded {count} objects from CSVs.")
        # The lock is already held, so use the non-locking save.
        save_world_state_to_disk_internal()
    return count
def save_world_state_to_disk():
    """Persist the in-memory world to WORLD_STATE_FILE while holding world_objects_lock.

    Returns True on success and False if the write failed.
    """
    world_objects_lock.acquire()
    try:
        return save_world_state_to_disk_internal()
    finally:
        world_objects_lock.release()
def save_world_state_to_disk_internal():
    """Write ``world_objects`` to WORLD_STATE_FILE as JSON; the caller must hold the lock.

    The snapshot is written to a sibling ``.tmp`` file and then moved over
    the real file with os.replace. A crash or error mid-write therefore
    leaves the previous snapshot intact. A half-written JSON file would
    otherwise fail to parse on the next start and silently fall back to
    the CSV data.

    Returns True on success, False on failure. Errors are printed rather
    than shown with st.error because this may run off the script thread.
    """
    print(f"Saving {len(world_objects)} objects to {WORLD_STATE_FILE}...")
    tmp_path = f"{WORLD_STATE_FILE}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            # Convert defaultdict back to a regular dict for saving.
            json.dump(dict(world_objects), f, indent=2)
        os.replace(tmp_path, WORLD_STATE_FILE)
        print("World state saved successfully.")
        return True
    except Exception as e:
        print(f"Error saving world state to {WORLD_STATE_FILE}: {e}")
        # Best-effort removal of a partial temp file; the real file is untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False
# --- Functions to load from CSVs (kept for initial load/fallback) --- | |
def load_plot_metadata():
    """Describe every non-empty ``plot_X<gx>_Z<gz>[_name...].csv`` in SAVE_DIR.

    Each entry is a dict with these keys: id, filename, grid_x, grid_z,
    name, x_offset, z_offset. The offsets are grid coordinate times
    PLOT_WIDTH / PLOT_DEPTH.

    Entries are sorted by (grid_x, grid_z). Files whose names cannot be
    parsed are skipped with a warning. Any failure to list SAVE_DIR
    returns [].
    """
    print(f"[{time.time():.2f}] Loading plot metadata...")
    try:
        candidates = [name for name in os.listdir(SAVE_DIR)
                      if name.endswith(".csv") and name.startswith("plot_X")]
    except FileNotFoundError:
        print(f"Save directory '{SAVE_DIR}' not found during metadata load.")
        return []
    except Exception as e:
        print(f"Error listing save directory '{SAVE_DIR}': {e}")
        return []

    plots = []
    for filename in candidates:
        try:
            path = os.path.join(SAVE_DIR, filename)
            # Treat files of 2 bytes or fewer as empty.
            if not os.path.exists(path) or os.path.getsize(path) <= 2:
                continue
            stem = filename[:-4]  # drop ".csv"
            pieces = stem.split('_')
            gx = int(pieces[1][1:])  # "X<int>"
            gz = int(pieces[2][1:])  # "Z<int>"
            label = " ".join(pieces[3:]) if len(pieces) > 3 else f"Plot ({gx},{gz})"
            plots.append({
                'id': stem,
                'filename': filename,
                'grid_x': gx,
                'grid_z': gz,
                'name': label,
                'x_offset': gx * PLOT_WIDTH,
                'z_offset': gz * PLOT_DEPTH,
            })
        except Exception as e:
            print(f"Warning: Error parsing metadata from filename '{filename}': {e}. Skipping.")

    plots.sort(key=lambda p: (p['grid_x'], p['grid_z']))
    print(f"Found {len(plots)} valid plot files.")
    return plots
def load_single_plot_objects_relative(filename):
    """Load the objects stored in SAVE_DIR/<filename>, keeping plot-relative coordinates.

    Cleaning applied:
      * a missing ``obj_id`` column, or blank ids, get fresh uuid4 strings;
      * missing rotation columns are added; missing or non-numeric rotation
        values become 0.0, and a missing rot_order becomes 'XYZ';
      * rows whose pos_x/pos_y/pos_z are not numeric are dropped;
      * a missing ``type`` becomes 'Unknown'.

    Returns a list of dicts with exactly the CSV_COLUMNS keys. Returns []
    if the file is absent or empty, lacks an essential column, or fails to
    parse.
    """
    file_path = os.path.join(SAVE_DIR, filename)
    objects = []
    try:
        if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
            return []
        df = pd.read_csv(file_path)
        if df.empty:
            return []

        # Ids: create the column when absent, otherwise fill only blank cells.
        # Indexing df['obj_id'] before checking it exists raised KeyError and
        # discarded the whole file.
        if 'obj_id' not in df.columns:
            df['obj_id'] = None
        ids = df['obj_id'].astype(object)
        if ids.isnull().any():
            print(f"Warning: Generating missing obj_ids for {filename}")
            fresh_ids = pd.Series([str(uuid.uuid4()) for _ in range(len(df))], index=df.index)
            ids = ids.where(ids.notnull(), fresh_ids)
        df['obj_id'] = ids.astype(str)

        for col in ['type', 'pos_x', 'pos_y', 'pos_z']:
            if col not in df.columns:
                print(f"Warning: CSV '{filename}' missing essential column '{col}'. Skipping file.")
                return []
        for col, default in [('rot_x', 0.0), ('rot_y', 0.0), ('rot_z', 0.0), ('rot_order', 'XYZ')]:
            if col not in df.columns:
                df[col] = default
        for col in ['pos_x', 'pos_y', 'pos_z', 'rot_x', 'rot_y', 'rot_z']:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # Fill AFTER numeric coercion so unparseable rotations also default to
        # 0.0. Otherwise NaN reaches json.dump, which emits the non-standard
        # token NaN that browsers' JSON.parse rejects.
        df.fillna({'rot_x': 0.0, 'rot_y': 0.0, 'rot_z': 0.0, 'rot_order': 'XYZ'}, inplace=True)
        df.dropna(subset=['pos_x', 'pos_y', 'pos_z'], inplace=True)  # position must be valid
        # Fill BEFORE astype(str): astype(str) turns NaN into the literal 'nan'.
        df['type'] = df['type'].fillna('Unknown').astype(str)
        objects = df[CSV_COLUMNS].to_dict('records')
    except pd.errors.EmptyDataError:
        pass  # Normal for empty files
    except FileNotFoundError:
        pass  # Normal if file doesn't exist yet
    except Exception as e:
        print(f"Error loading objects from {filename}: {e}")
        return []
    return objects
def get_all_world_objects_from_csv():
    """Merge the objects of every plot CSV into one list in world coordinates.

    Each plot's x/z offset (from load_plot_metadata) is added to the stored
    relative position; y is kept as stored. Objects are de-duplicated by
    obj_id, and the plot processed last wins.

    Each returned dict has the shape
    {'obj_id', 'type', 'position': {'x', 'y', 'z'},
     'rotation': {'_x', '_y', '_z', '_order'}}.
    """
    print(f"[{time.time():.2f}] Reloading ALL world objects from CSV files...")
    by_id = {}
    for meta in load_plot_metadata():
        dx, dz = meta['x_offset'], meta['z_offset']
        for rel in load_single_plot_objects_relative(meta['filename']):
            oid = rel.get('obj_id')
            if not oid:
                continue  # id generation/loading failed for this row
            by_id[oid] = {
                'obj_id': oid,
                'type': rel.get('type', 'Unknown'),
                'position': {
                    'x': rel.get('pos_x', 0.0) + dx,
                    'y': rel.get('pos_y', 0.0),
                    'z': rel.get('pos_z', 0.0) + dz,
                },
                'rotation': {
                    '_x': rel.get('rot_x', 0.0),
                    '_y': rel.get('rot_y', 0.0),
                    '_z': rel.get('rot_z', 0.0),
                    '_order': rel.get('rot_order', 'XYZ'),
                },
            }
    merged = list(by_id.values())
    print(f"Loaded {len(merged)} total objects from CSVs.")
    return merged
# --- Session State Init (Combined & Expanded) --- | |
def init_session_state():
    """Give st.session_state a default value for every key it does not have yet.

    Existing keys are never overwritten, so this is safe to call on every
    rerun. The defaults (including the mutable dicts, lists and
    defaultdicts) are rebuilt on each call, so no default object is reused
    between calls.
    """
    marquee_defaults = {
        "background": "#1E1E1E", "color": "#FFFFFF", "font-size": "14px",
        "animationDuration": "20s", "width": "100%", "lineHeight": "35px",
    }
    defaults = {
        # Chat / websocket server state
        'server_running_flag': False,
        'server_instance': None,
        'server_task': None,
        'active_connections': defaultdict(dict),  # websocket objects by client id
        'last_chat_update': 0,
        'message_input': "",
        'audio_cache': {},
        'tts_voice': "en-US-AriaNeural",
        'chat_history': [],
        'marquee_settings': marquee_defaults,
        'enable_audio': True,
        'download_link_cache': {},
        'username': None,
        'autosend': False,  # chat autosend is off by default
        'last_message': "",
        'timer_start': time.time(),
        'last_sent_transcript': "",
        'last_refresh': time.time(),
        'auto_refresh': False,
        'refresh_rate': 30,
        # 3D world state
        'selected_object': 'None',  # current building tool
        'initial_world_state_loaded': False,  # load world state only once
        # Diagnostics / misc
        'operation_timings': {},
        'performance_metrics': defaultdict(list),
        'paste_image_base64': "",  # paste component state tracking
    }
    missing = [key for key in defaults if key not in st.session_state]
    for key in missing:
        st.session_state[key] = defaults[key]
# --- Marquee Helpers (Keep from New App) --- | |
def update_marquee_settings_ui():
    """Placeholder for the marquee settings controls; currently renders nothing and returns None."""
def display_marquee(text, settings, key_suffix=""):
    """Placeholder for the scrolling marquee; ignores its arguments and returns None."""
# --- Text & File Helpers (Keep & Adapt from New App) --- | |
def clean_text_for_tts(text):
    """Reduce markdown-ish *text* to plain text for Edge TTS.

    Steps:
      1. Markdown links ``[label](url)`` collapse to their label.
      2. Runs of whitespace collapse to single spaces.
      3. The characters ``# * ! [ ]`` are removed.
      4. The result is capped at 250 characters.

    An empty result becomes the literal string "No text".
    """
    without_links = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text)
    single_spaced = ' '.join(without_links.split())
    stripped = re.sub(r'[#*!\[\]]+', '', single_spaced)
    return stripped[:250] or "No text"
def generate_filename(prompt, username, file_type="md", title=None):
    """Return '<prefix>_<slug>_<hash6>.<file_type>' for a new file.

    Parts of the name:
      - <prefix> is format_timestamp_prefix(username);
      - <slug> is clean_text_for_filename of *title* when truthy, otherwise
        of the first 30 characters of *prompt*;
      - <hash6> is the first 6 hex digits of the MD5 of the full *prompt*.
    """
    prefix = format_timestamp_prefix(username)
    slug = clean_text_for_filename(title or prompt[:30])
    digest = hashlib.md5(prompt.encode()).hexdigest()[:6]
    return "{}_{}_{}.{}".format(prefix, slug, digest, file_type)
def clean_text_for_filename(text):
    """Turn *text* into a filename slug of at most 50 characters.

    Whitespace runs become '_'. Any character other than a word character,
    '-' or '.' is dropped.
    """
    underscored = re.sub(r'\s+', '_', text)
    slug = re.sub(r'[^\w\-.]', '', underscored)
    return slug[:50]
def create_file(content, username, file_type="md", save_path=None):
    """Write *content* as UTF-8 text and return the path written.

    When *save_path* is falsy, a name is generated with generate_filename
    and placed under MEDIA_DIR. Missing parent directories are created.

    Returns None, after logging, if opening or writing the file fails.
    """
    target = save_path or os.path.join(MEDIA_DIR, generate_filename(content, username, file_type))
    parent = os.path.dirname(target)
    if parent:
        os.makedirs(parent, exist_ok=True)
    try:
        with open(target, 'w', encoding='utf-8') as out:
            out.write(content)
    except Exception as e:
        print(f"Error creating file {target}: {e}")
        return None
    print(f"Created file: {target}")
    return target
def get_download_link(file, file_type="mp3"):
    """Return HTML for an <a> tag that downloads *file* as a base64 data: URI.

    *file_type* selects the MIME type and emoji. Unknown types fall back to
    application/octet-stream and 📄.

    Links are cached in st.session_state.download_link_cache, keyed by path
    plus mtime, so an updated file gets a fresh link. Returns a plain
    "File not found" message when *file* is empty or missing.

    Callers render the result with unsafe_allow_html, and file names can
    contain client-supplied text (usernames reach generate_filename). The
    file name is therefore HTML-escaped everywhere it is interpolated.
    """
    from html import escape  # only this helper needs it

    if not file or not os.path.exists(file):
        return f"File not found: {escape(str(file))}"
    # Cache based on file path and modification time to handle updates.
    cache_key = f"dl_{file}_{os.path.getmtime(file)}"
    if cache_key not in st.session_state.get('download_link_cache', {}):
        with open(file, "rb") as f:
            b64 = base64.b64encode(f.read()).decode()
        mime_types = {"mp3": "audio/mpeg", "png": "image/png", "mp4": "video/mp4", "md": "text/markdown",
                      "zip": "application/zip", "csv": "text/csv", "json": "application/json"}
        mime = mime_types.get(file_type, "application/octet-stream")
        safe_name = escape(os.path.basename(file))
        link_html = (f'<a href="data:{mime};base64,{b64}" download="{safe_name}">'
                     f'{FILE_EMOJIS.get(file_type, "📄")} Download {safe_name}</a>')
        if 'download_link_cache' not in st.session_state:
            st.session_state.download_link_cache = {}
        st.session_state.download_link_cache[cache_key] = link_html
    return st.session_state.download_link_cache[cache_key]
def save_username(username):
    """Persist *username* to STATE_FILE so it survives restarts.

    The file is written as UTF-8 explicitly. Usernames contain emoji
    (see FUN_USERNAMES), which the platform default encoding (e.g. cp1252)
    may be unable to encode. Failures are logged, not raised.
    """
    try:
        with open(STATE_FILE, 'w', encoding='utf-8') as f:
            f.write(username)
    except Exception as e:
        print(f"Failed to save username: {e}")


def load_username():
    """Return the username stored in STATE_FILE (UTF-8).

    Returns None when the file is absent, empty or unreadable.
    """
    if not os.path.exists(STATE_FILE):
        return None
    try:
        with open(STATE_FILE, 'r', encoding='utf-8') as f:
            return f.read().strip() or None
    except Exception as e:
        print(f"Failed to load username: {e}")
    return None
# --- Audio Processing (Keep from New App) --- | |
async def async_edge_tts_generate(text, voice, username):
    """Synthesize *text* with Edge TTS *voice*; return the mp3 path or None.

    Results are memoised in st.session_state.audio_cache, keyed by an MD5 of
    the full text plus the voice. A cached path is reused only if the file
    still exists.

    Audio is written under AUDIO_DIR with a name built by generate_filename.

    Returns None when:
      - *text* is empty;
      - cleaning leaves nothing to speak;
      - Edge TTS returns no audio;
      - the output file is missing or empty.
    """
    if not text:
        return None
    # Hash the FULL text: keying on a 150-char prefix made different messages
    # that share their opening reuse each other's audio.
    cache_key = hashlib.md5(f"{text}_{voice}".encode()).hexdigest()
    if 'audio_cache' not in st.session_state:
        st.session_state.audio_cache = {}
    cached_path = st.session_state.audio_cache.get(cache_key)
    if cached_path and os.path.exists(cached_path):
        return cached_path

    text_cleaned = clean_text_for_tts(text)
    if not text_cleaned or text_cleaned == "No text":
        print("Skipping TTS for empty/cleaned text.")
        return None

    filename_base = generate_filename(text_cleaned, username, "mp3")
    save_path = os.path.join(AUDIO_DIR, filename_base)
    print(f"Generating TTS audio for '{text_cleaned[:30]}...' to {save_path}")
    try:
        communicate = edge_tts.Communicate(text_cleaned, voice)
        await communicate.save(save_path)
        if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
            st.session_state.audio_cache[cache_key] = save_path
            return save_path
        print(f"Audio file {save_path} failed generation or is empty.")
        return None
    except edge_tts.exceptions.NoAudioReceived:
        print(f"Edge TTS returned no audio for voice {voice}.")
        return None
    except Exception as e:
        print(f"Error during Edge TTS generation: {e}")
        return None
def play_and_download_audio(file_path):
    """Render an audio player plus a download link for *file_path*.

    Shows st.warning when the path is empty or missing, and st.error if
    rendering raises.
    """
    if not (file_path and os.path.exists(file_path)):
        st.warning(f"Audio file not found for playback: {file_path}")
        return
    try:
        st.audio(file_path)
        extension = file_path.split('.')[-1]
        st.markdown(get_download_link(file_path, extension), unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Error displaying audio {file_path}: {e}")
# --- Chat Saving/Loading (Keep & Adapt from New App) --- | |
async def save_chat_entry(username, message, voice, is_markdown=False):
    """Persist one chat message and, if audio is enabled, synthesize it.

    Side effects:
      - writes the formatted entry (US/Central timestamp) to its own .md
        file in CHAT_DIR;
      - appends the entry to st.session_state.chat_history;
      - when st.session_state['enable_audio'] is truthy (default True),
        generates speech via async_edge_tts_generate.

    With *is_markdown* the message is wrapped in a ```markdown fence in the
    log entry.

    Returns (md_path, audio_path); either may be None. Returns (None, None)
    with no side effects for blank or non-string messages. The websocket
    handler passes client-supplied values straight through, so the type is
    not guaranteed.
    """
    if not isinstance(message, str) or not message.strip():
        return None, None
    print(f"Saving chat entry from {username}: {message[:50]}...")
    central = pytz.timezone('US/Central')  # Or use UTC
    timestamp = datetime.now(central).strftime("%Y-%m-%d %H:%M:%S")
    header = f"[{timestamp}] {username} ({voice}):"
    entry = f"{header}\n```markdown\n{message}\n```" if is_markdown else f"{header} {message}"

    # One file per entry in chat_logs.
    md_file_path = os.path.join(CHAT_DIR, generate_filename(message, username, "md"))
    md_file = create_file(entry, username, "md", save_path=md_file_path)

    # Append to session history for immediate display.
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    st.session_state.chat_history.append(entry)

    audio_file = None
    if st.session_state.get('enable_audio', True):
        # The raw message is spoken in both modes; async_edge_tts_generate runs
        # clean_text_for_tts, which strips link syntax and '#*![]'.
        audio_file = await async_edge_tts_generate(message, voice, username)
        if audio_file:
            print(f"Generated audio: {audio_file}")
        else:
            print("Failed to generate audio for chat message.")
    return md_file, audio_file
async def load_chat_history():
    """Return st.session_state.chat_history, filling it from CHAT_DIR/*.md when empty.

    Files are read oldest-first by modification time. Each file's stripped
    text becomes one entry. Unreadable files are logged and skipped.
    """
    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []
    if st.session_state.chat_history:
        return st.session_state.chat_history

    print("Loading chat history from files...")
    paths = sorted(glob.glob(os.path.join(CHAT_DIR, "*.md")), key=os.path.getmtime)
    loaded = 0
    for path in paths:
        try:
            with open(path, 'r', encoding='utf-8') as fh:
                st.session_state.chat_history.append(fh.read().strip())
        except Exception as e:
            print(f"Error reading chat file {path}: {e}")
        else:
            loaded += 1
    print(f"Loaded {loaded} chat entries from files.")
    return st.session_state.chat_history
# --- WebSocket Handling (Adapted for 3D State & Thread Safety) ---
# Registry of connected websocket clients, keyed by str(websocket.id).
# This lives at module level rather than in st.session_state. The coroutines
# below run on the thread started by start_websocket_server_thread, which has
# no Streamlit script-run context, so st.session_state there cannot be relied
# on to be the user's session.
connected_clients = set()  # Holds client_id strings
client_websockets = {}  # client_id -> websocket object
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so an unreferenced task may be garbage-collected mid-run.
_background_tasks = set()


async def register_client(websocket):
    """Start tracking *websocket* for broadcasts."""
    client_id = str(websocket.id)
    connected_clients.add(client_id)
    client_websockets[client_id] = websocket
    print(f"Client registered: {client_id}. Total: {len(connected_clients)}")


async def unregister_client(websocket):
    """Stop tracking *websocket*; a no-op for unknown clients."""
    client_id = str(websocket.id)
    connected_clients.discard(client_id)
    client_websockets.pop(client_id, None)
    print(f"Client unregistered: {client_id}. Remaining: {len(connected_clients)}")


async def _handle_client_message(client_id, username, message):
    """Apply one raw frame from *client_id* to server state and relay it to the others.

    Raises json.JSONDecodeError for non-JSON frames; the caller logs it.
    Frames that are not objects, or whose payload is not an object, are
    logged and ignored. Unknown message types are ignored.
    """
    data = json.loads(message)
    payload = data.get("payload", {}) if isinstance(data, dict) else None
    if not isinstance(payload, dict):
        print(f"Ignoring malformed message from {client_id}: {str(message)[:100]}...")
        return
    msg_type = data.get("type")
    # Clients name themselves in the payload; fall back to the connection's name.
    sender_username = payload.get("username", username)

    if msg_type == "chat_message":
        chat_text = str(payload.get('message') or '')
        print(f"Received chat from {sender_username} ({client_id}): {chat_text[:50]}...")
        voice = payload.get('voice', FUN_USERNAMES.get(sender_username, "en-US-AriaNeural"))
        # Save in the background, keeping a reference until the task finishes.
        task = asyncio.create_task(save_chat_entry(sender_username, chat_text, voice))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        await broadcast_message(message, exclude_id=client_id)  # forward the original frame

    elif msg_type == "place_object":
        obj_data = payload.get("object_data")
        if isinstance(obj_data, dict) and 'obj_id' in obj_data and 'type' in obj_data:
            print(f"Received place_object from {sender_username} ({client_id}): {obj_data.get('type')} ({obj_data['obj_id']})")
            with world_objects_lock:
                # str keys, matching how load_world_state_from_disk normalises them.
                world_objects[str(obj_data['obj_id'])] = obj_data
            await broadcast_message(json.dumps({
                "type": "object_placed",
                "payload": {"object_data": obj_data, "username": sender_username},
            }), exclude_id=client_id)
        else:
            print(f"Invalid place_object payload from {client_id}: {payload}")

    elif msg_type == "delete_object":
        obj_id = payload.get("obj_id")
        if obj_id:
            print(f"Received delete_object from {sender_username} ({client_id}): {obj_id}")
            with world_objects_lock:
                removed = world_objects.pop(str(obj_id), None) is not None
            if removed:
                await broadcast_message(json.dumps({
                    "type": "object_deleted",
                    "payload": {"obj_id": obj_id, "username": sender_username},
                }), exclude_id=client_id)
        else:
            print(f"Invalid delete_object payload from {client_id}: {payload}")

    elif msg_type == "player_position":
        # Relayed as-is; no server-side validation yet.
        pos_data = payload.get("position")
        if pos_data:
            await broadcast_message(json.dumps({
                "type": "player_moved",
                "payload": {"username": sender_username, "id": client_id, "position": pos_data},
            }), exclude_id=client_id)


async def websocket_handler(websocket, path=None):
    """Serve one websocket client for the lifetime of its connection.

    *path* is optional. Older ``websockets`` releases pass it, while newer
    ones call the handler with the connection only.

    On connect, the client receives ``initial_state`` (a snapshot of
    world_objects) and the other clients receive ``user_join``. Each
    incoming frame is handled by _handle_client_message. On disconnect, the
    others receive ``user_leave`` and the client is unregistered, even if
    that broadcast fails.
    """
    await register_client(websocket)
    client_id = str(websocket.id)
    fallback_name = f"User_{client_id[:4]}"
    try:
        # Off the script thread this may not be the user's session; it can
        # also hold None (the session default), hence the `or` fallback.
        username = st.session_state.get('username') or fallback_name
    except Exception:
        username = fallback_name

    try:
        with world_objects_lock:
            snapshot = dict(world_objects)
            initial_state_msg = json.dumps({"type": "initial_state", "payload": snapshot})
        await websocket.send(initial_state_msg)
        print(f"Sent initial state ({len(snapshot)} objects) to {client_id}")
        await broadcast_message(json.dumps({
            "type": "user_join",
            "payload": {"username": username, "id": client_id},
        }), exclude_id=client_id)
    except Exception as e:
        print(f"Error during initial phase for {client_id}: {e}")

    try:
        async for message in websocket:
            try:
                await _handle_client_message(client_id, username, message)
            except json.JSONDecodeError:
                print(f"Received invalid JSON from {client_id}: {message[:100]}...")
            except Exception as e:
                print(f"Error processing message from {client_id}: {e}")
    except websockets.ConnectionClosedOK:
        print(f"Client disconnected normally: {client_id} ({username})")
    except websockets.ConnectionClosedError as e:
        print(f"Client connection closed with error: {client_id} ({username}) - {e}")
    except Exception as e:
        print(f"Unexpected error in handler for {client_id}: {e}")
    finally:
        try:
            await broadcast_message(json.dumps({
                "type": "user_leave",
                "payload": {"username": username, "id": client_id},
            }), exclude_id=client_id)
        finally:
            await unregister_client(websocket)


async def broadcast_message(message, exclude_id=None):
    """Send *message* concurrently to every registered client except *exclude_id*.

    Each failed send is logged with the id of the client it belongs to.
    Removing dead clients is left to their own handler's ``finally`` block.
    """
    targets = []
    for client_id in list(connected_clients):  # snapshot: the set may change while we await
        if client_id == exclude_id:
            continue
        websocket = client_websockets.get(client_id)
        if websocket is None:
            print(f"Websocket object not found for client {client_id} during broadcast.")
            continue
        targets.append((client_id, websocket))
    if not targets:
        return
    results = await asyncio.gather(
        *(send_safely(ws, message, cid) for cid, ws in targets),
        return_exceptions=True,
    )
    # Pair results with targets directly, so the reported id is always the
    # client whose send actually failed.
    for (client_id, _), result in zip(targets, results):
        if isinstance(result, BaseException):
            print(f"Error sending message during broadcast to {client_id}: {result}")


async def send_safely(websocket, message, client_id):
    """Send *message* on *websocket*, logging and re-raising any failure.

    The error is re-raised so broadcast_message's gather sees it.
    """
    try:
        await websocket.send(message)
    except websockets.ConnectionClosed:
        print(f"Send failed: Connection closed for client {client_id}")
        raise  # unregistering is the handler's job
    except RuntimeError as e:  # e.g. event loop closed
        print(f"Send failed: RuntimeError for client {client_id}: {e}")
        raise
    except Exception as e:
        print(f"Send failed: Unexpected error for client {client_id}: {e}")
        raise
async def run_websocket_server(host="0.0.0.0", port=8765):
    """Serve websocket_handler on (*host*, *port*) until the stop event is set.

    Startup:
      - returns immediately if st.session_state['server_running_flag'] is set;
      - otherwise sets the flag and publishes the server and an asyncio.Event
        to st.session_state ('server_instance', 'websocket_stop_event').

    Shutdown: all three keys are reset when the coroutine exits. Startup
    failures are logged and reported via st.error.

    The default host 0.0.0.0 accepts connections on every interface,
    firewall permitting.
    """
    if st.session_state.get('server_running_flag', False):
        print("Server flag indicates already running or starting.")
        return
    st.session_state['server_running_flag'] = True
    print(f"Attempting to start WebSocket server on {host}:{port}...")
    stop_event = asyncio.Event()  # setting it triggers a graceful shutdown
    st.session_state['websocket_stop_event'] = stop_event
    server = None
    try:
        server = await websockets.serve(websocket_handler, host, port)
        st.session_state['server_instance'] = server
        print(f"WebSocket server started successfully on {server.sockets[0].getsockname()}.")
        await stop_event.wait()
    except OSError as e:
        print(f"### FAILED TO START WEBSOCKET SERVER: {e}")
        st.error(f"Failed start WebSocket: {e}. Port {port} busy?")
    except Exception as e:
        print(f"### UNEXPECTED ERROR IN WEBSOCKET SERVER: {e}")
        st.error(f"WebSocket server error: {e}")
    finally:
        print("WebSocket server task loop finished.")
        # Close through the local reference. When this runs on the server
        # thread, reading st.session_state back may not return the object
        # stored above, and the real server would be left open.
        if server is not None:
            server.close()
            await server.wait_closed()
            print("WebSocket server closed.")
        st.session_state['server_running_flag'] = False
        st.session_state['server_instance'] = None
        st.session_state['websocket_stop_event'] = None
def start_websocket_server_thread():
    """Launch run_websocket_server on a daemon thread with its own event loop.

    Does nothing if the thread stored in st.session_state.server_task is
    still alive, or if server_running_flag is set. After starting the
    thread, blocks for about 1 s so it can initialise.
    """
    existing = st.session_state.get('server_task')
    if existing and existing.is_alive():
        print("Server thread check: Already running.")
        return
    if st.session_state.get('server_running_flag', False):
        print("Server flag check: Already running.")
        return

    print("Creating and starting new server thread.")

    def _serve_on_private_loop():
        # A fresh loop for this thread; it is closed once the server coroutine ends.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(run_websocket_server())
        finally:
            loop.close()
            print("Server thread asyncio loop closed.")

    worker = threading.Thread(target=_serve_on_private_loop, daemon=True)
    st.session_state.server_task = worker
    worker.start()
    time.sleep(1)
    print(f"Server thread started. Alive: {st.session_state.server_task.is_alive()}")
# --- PDF to Audio (Keep if desired, maybe in a separate tab?) --- | |
class AudioProcessor:
    """Text-to-speech helper that caches edge-tts MP3 output on disk.

    Audio for a (text, voice) pair is written to ``<cache_dir>/<md5>.mp3``.
    ``<cache_dir>/metadata.json`` records, per cache key, the creation
    timestamp, cleaned text length and voice.  A cached file is reused only if
    it has a metadata entry and the MP3 still exists.
    """

    def __init__(self, cache_dir=None):
        """Prepare the cache directory and load any saved metadata.

        Args:
            cache_dir: Directory for MP3s and ``metadata.json``.  Defaults to
                the module-level ``AUDIO_CACHE_DIR``, which was the only
                location used before this parameter existed.
        """
        self.cache_dir = AUDIO_CACHE_DIR if cache_dir is None else cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)
        # process_pdf calls create_audio from several threads at once, so the
        # metadata dict and its JSON file are guarded by a lock.  It is
        # re-entrant because create_audio holds it while calling _save_metadata.
        self._metadata_lock = threading.RLock()
        self.metadata = self._load_metadata()

    def _metadata_path(self):
        """Path of the JSON file holding the cache metadata."""
        return f"{self.cache_dir}/metadata.json"

    def _load_metadata(self):
        """Return the saved metadata dict, or ``{}`` if it is missing, unreadable or corrupt."""
        try:
            with open(self._metadata_path(), 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError
            print(f"Failed metadata load: {e}")
            return {}
        return data if isinstance(data, dict) else {}

    def _save_metadata(self):
        """Write ``self.metadata`` to ``metadata.json``; failures are printed, not raised."""
        with self._metadata_lock:
            try:
                with open(self._metadata_path(), 'w', encoding='utf-8') as f:
                    json.dump(self.metadata, f, indent=2)
            except Exception as e:
                print(f"Failed metadata save: {e}")

    @staticmethod
    def _cache_key(text, voice):
        """Cache key for *text* spoken by *voice*.

        The full text is hashed.  Hashing only a prefix made different texts
        that share their first characters (e.g. PDF pages with identical
        headers) reuse each other's audio.
        """
        return hashlib.md5(f"{text}:{voice}".encode()).hexdigest()

    async def create_audio(self, text, voice='en-US-AriaNeural'):
        """Return the path of an MP3 of *text* spoken by *voice*, generating it if needed.

        Returns:
            The cached or newly written MP3 path.  Returns ``None`` when
            ``clean_text_for_tts`` leaves nothing to speak or when synthesis
            fails (the error is printed).
        """
        cache_key = self._cache_key(text, voice)
        cache_path = f"{self.cache_dir}/{cache_key}.mp3"
        if cache_key in self.metadata and os.path.exists(cache_path):
            return cache_path

        text_cleaned = clean_text_for_tts(text)
        if not text_cleaned:
            return None

        # Ensure dir exists before saving
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        try:
            communicate = edge_tts.Communicate(text_cleaned, voice)
            await communicate.save(cache_path)
        except Exception as e:
            print(f"TTS Create Audio Error: {e}")
            return None

        with self._metadata_lock:
            self.metadata[cache_key] = {
                'timestamp': datetime.now().isoformat(),
                'text_length': len(text_cleaned),
                'voice': voice,
            }
            self._save_metadata()
        return cache_path
def process_pdf(pdf_file, max_pages, voice, audio_processor, max_workers=4):
    """Extract text from the first pages of a PDF and synthesize audio for each page.

    Args:
        pdf_file: Anything ``PdfReader`` accepts (path or binary file object).
        max_pages: Upper bound on the number of pages examined.
        voice: Voice name forwarded to ``audio_processor.create_audio``.
        audio_processor: Object exposing ``async create_audio(text, voice)`` that
            returns a file path or ``None`` (e.g. ``AudioProcessor``).
        max_workers: Maximum number of pages synthesized concurrently.
            Previously one thread was started per page, without limit.

    Returns:
        ``(texts, audios, total_pages)``:

        - ``texts`` maps page index to the extracted text, or to
          ``"[No text extracted]"``.
        - ``audios`` maps page index to the audio path for pages where
          synthesis succeeded, in page order.
        - ``total_pages`` is the number of pages examined.

        If the PDF cannot be read, an error is shown in the UI and
        ``({}, {}, 0)`` is returned.
    """
    from concurrent.futures import ThreadPoolExecutor

    def synthesize_page(page_num, page_text):
        # Each call runs on a fresh event loop that is closed afterwards, the
        # same pattern run_server_loop uses.  This avoids depending on how the
        # nest_asyncio-patched asyncio.run picks (and keeps) a loop in non-main
        # threads.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(audio_processor.create_audio(page_text, voice))
        except Exception as page_e:
            print(f"Error processing page {page_num+1} audio: {page_e}")
            return None
        finally:
            loop.close()

    try:
        reader = PdfReader(pdf_file)
        total_pages = min(len(reader.pages), max_pages)
        texts, audios = {}, {}
        pending = {}
        # Leaving the `with` block waits for every submitted page.
        with ThreadPoolExecutor(max_workers=max(1, max_workers)) as pool:
            for i in range(total_pages):
                text = reader.pages[i].extract_text()
                if text:  # Only process pages with text
                    texts[i] = text
                    pending[i] = pool.submit(synthesize_page, i, text)
                else:
                    texts[i] = "[No text extracted]"
        for page_num, future in pending.items():
            audio_path = future.result()
            if audio_path:
                audios[page_num] = audio_path
        return texts, audios, total_pages
    except Exception as pdf_e:
        st.error(f"Error reading PDF: {pdf_e}")
        return {}, {}, 0
# --- ArXiv/AI Lookup (Commented out for focus) --- | |
# def parse_arxiv_refs(...): pass | |
# def generate_5min_feature_markdown(...): pass | |
# async def create_paper_audio_files(...): pass | |
# async def perform_ai_lookup(...): pass | |
# async def perform_claude_search(...): pass | |
# async def perform_arxiv_search(...): pass | |
# --- Image Handling (Keep basic save, comment out Claude processing) --- | |
async def save_pasted_image(image, username):  # Simplified
    """Write a pasted PIL image to MEDIA_DIR as a PNG.

    The file is named ``<format_timestamp_prefix(username)>_pasted_<hash>.png``.
    ``<hash>`` is the first 8 hex digits of the MD5 of the image's raw pixel
    bytes.  No duplicate detection is done here.

    Returns:
        The saved path, or ``None`` when saving fails (the error is printed).
    """
    pixel_digest = hashlib.md5(image.tobytes()).hexdigest()
    short_hash = pixel_digest[:8]
    stamp = format_timestamp_prefix(username)
    target_path = os.path.join(MEDIA_DIR, f"{stamp}_pasted_{short_hash}.png")  # Save in base dir

    try:
        image.save(target_path, "PNG")
    except Exception as save_err:
        print(f"Failed image save: {save_err}")
        return None

    print(f"Pasted image saved: {target_path}")
    return target_path
# --- Zip and Delete Files (Keep from New App) --- | |
def create_zip_of_files(files, prefix="Archive", query=""):  # Simplified args
    """Bundle *files* into a deflate-compressed ``<prefix>_<timestamp>.zip``.

    The archive is written to the current working directory.  Each file is
    stored under its base name; paths that do not exist are skipped with a
    console message.

    Args:
        files: Paths to include.
        prefix: Archive file-name prefix.
        query: Unused; kept so existing callers keep working.

    Returns:
        The archive file name, or ``None`` when *files* is empty (a warning is
        shown) or creation fails (an error is shown and any partially written
        archive is removed).
    """
    if not files:
        st.warning("No files selected to zip.")
        return None
    timestamp = format_timestamp_prefix("Zip")  # Generic timestamp
    zip_name = f"{prefix}_{timestamp}.zip"
    try:
        print(f"Creating zip: {zip_name} with {len(files)} files...")
        with zipfile.ZipFile(zip_name, 'w', zipfile.ZIP_DEFLATED) as z:  # Use compression
            for f in files:
                if os.path.exists(f):
                    z.write(f, os.path.basename(f))  # Use basename in archive
                else:
                    print(f"Skipping non-existent file for zipping: {f}")
    except Exception as e:
        print(f"Zip creation failed: {e}")
        # A half-written archive would otherwise be offered as a download
        # alongside the good ones, so delete it.
        try:
            if os.path.exists(zip_name):
                os.remove(zip_name)
        except OSError as cleanup_e:
            print(f"Could not remove partial zip {zip_name}: {cleanup_e}")
        st.error(f"Zip creation failed: {e}")
        return None
    print("Zip creation successful.")
    st.success(f"Created {zip_name}")
    return zip_name
def delete_files(file_patterns, exclude_files=None):  # Takes list of patterns
    """Delete regular files matching glob patterns joined onto MEDIA_DIR.

    Some files are never deleted: the state files, app.py, index.html,
    requirements.txt, README.md and anything listed in *exclude_files*.
    Protection is decided by base name, and every protected entry is reduced
    to its base name first.  This way an entry given as a path (or a state
    file constant holding a path) still protects the file.  Directories are
    skipped.

    Afterwards a success or warning summary is shown, and the download-link
    and audio caches in session state are reset.

    Args:
        file_patterns: Iterable of glob patterns, e.g. ``"*.zip"`` or
            ``os.path.join(CHAT_DIR, "*.md")``.
        exclude_files: Optional extra file names or paths to protect.
    """
    # Define core protected files, plus user-provided exclusions
    candidates = [STATE_FILE, WORLD_STATE_FILE, "app.py", "index.html", "requirements.txt", "README.md"]
    if exclude_files:
        candidates.extend(exclude_files)
    protected = {os.path.basename(p) for p in candidates if p}

    deleted_count = 0
    errors = 0
    for pattern in file_patterns:
        # Expand pattern relative to MEDIA_DIR; an absolute pattern overrides it.
        pattern_path = os.path.join(MEDIA_DIR, pattern)
        print(f"Attempting to delete files matching: {pattern_path}")
        try:
            files_to_delete = glob.glob(pattern_path)
        except Exception as glob_e:
            print(f"Error matching pattern {pattern}: {glob_e}")
            errors += 1
            continue
        if not files_to_delete:
            print(f"No files found for pattern: {pattern}")
            continue
        for f_path in files_to_delete:
            if os.path.isdir(f_path):
                print(f"Skipping directory: {f_path}")
                continue
            if os.path.basename(f_path) in protected or not os.path.isfile(f_path):
                continue  # protected, or not a regular file
            try:
                os.remove(f_path)
                print(f"Deleted: {f_path}")
                deleted_count += 1
            except Exception as e:
                print(f"Failed delete {f_path}: {e}")
                errors += 1

    msg = f"Deleted {deleted_count} files."
    if errors > 0:
        msg += f" Encountered {errors} errors."
        st.warning(msg)
    else:
        st.success(msg)
    # Clear relevant caches
    st.session_state['download_link_cache'] = {}
    st.session_state['audio_cache'] = {}  # Clear audio cache if MP3s deleted
# --- Custom Paste Component (Keep from New App) --- | |
def paste_image_component():  # Returns Image object, type string
    """Render a form for pasting an image as a ``data:image/<type>;base64,<payload>`` URI.

    On a valid submit the payload is decoded with Pillow and a 150px preview
    is shown.  The base64 payload is then stored in
    ``st.session_state.paste_image_base64``; that value is reset to ``""``
    on invalid input or decode errors.

    Returns:
        ``(image, img_type)``: a ``PIL.Image.Image`` and its MIME subtype
        (e.g. ``"png"``) when an image was decoded in this run, otherwise
        ``(None, None)``.
    """
    pasted_img = None
    img_type = None
    with st.form(key="paste_form"):
        paste_input = st.text_area("Paste Image Data Here (Ctrl+V)", key="paste_input_area", height=50)
        submit_button = st.form_submit_button("Paste Image 📋")

    # Clipboard pastes often carry leading/trailing whitespace or newlines.
    data_uri = (paste_input or "").strip()
    if submit_button and data_uri.startswith('data:image'):
        header, sep, base64_str = data_uri.partition(',')
        try:
            if not sep or not base64_str:
                raise ValueError("missing ',' separator or base64 payload")
            mime_type = header.split(';')[0].split(':', 1)[1]  # e.g. image/png
            img_bytes = base64.b64decode(base64_str)
            pasted_img = Image.open(io.BytesIO(img_bytes))
            img_type = mime_type.split('/', 1)[1]  # e.g., png, jpeg
            # Show preview immediately
            st.image(pasted_img, caption=f"Pasted Image ({img_type.upper()})", width=150)
            # Store base64 temporarily to avoid reprocessing on rerun if only text changed
            st.session_state.paste_image_base64 = base64_str
        except Exception as e:
            st.error(f"Image decode error: {e}")
            st.session_state.paste_image_base64 = ""  # Clear on error
            # Don't hand back a half-validated image after a failure.
            pasted_img, img_type = None, None
    elif submit_button:
        st.warning("No valid image data pasted.")
        st.session_state.paste_image_base64 = ""  # Clear if invalid submit
    return pasted_img, img_type
# --- Mapping Emojis to Primitive Types --- | |
# Ensure these types match the createPrimitiveMesh function keys in index.html | |
# Toolbar emoji (button label) -> primitive type name sent to the 3D scene.
# Only the values matter to index.html; the keys are shown on the sidebar
# buttons.  The "Sign Post" entry previously used the Hangul syllable "팻"
# (a mis-encoded character, not an emoji) as its label.
PRIMITIVE_MAP = {
    "🌳": "Tree", "🗿": "Rock", "🏛️": "Simple House", "🌲": "Pine Tree", "🧱": "Brick Wall",
    "🔵": "Sphere", "📦": "Cube", "🧴": "Cylinder", "🍦": "Cone", "🍩": "Torus",  # cylinder emoji changed
    "🍄": "Mushroom", "🌵": "Cactus", "🔥": "Campfire", "⭐": "Star", "💎": "Gem",
    "🗼": "Tower", "🚧": "Barrier", "⛲": "Fountain", "🏮": "Lantern", "🪧": "Sign Post",
    # Add more pairs up to ~20
}
# --- Main Streamlit Interface --- | |
def _run_coroutine(coro, error_prefix):
    """Run *coro* from synchronous Streamlit code and report failures in the UI.

    If an event loop is already running in this thread, the coroutine is
    scheduled on it (fire-and-forget).  Otherwise it is run to completion with
    ``asyncio.run``.  Errors raised by ``asyncio.run`` itself are caught too:
    the old inline ``except RuntimeError: asyncio.run(...)`` /
    ``except Exception`` pattern let them escape.  Errors are shown as
    ``"<error_prefix>: <error>"``.
    """
    # NOTE(review): broadcast_message presumably writes to websockets owned by
    # the server thread's event loop; running it on a different loop may not be
    # safe. TODO confirm / consider asyncio.run_coroutine_threadsafe.
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:  # no loop running in this thread
        loop = None
    try:
        if loop is not None:
            loop.create_task(coro)
        else:
            asyncio.run(coro)
    except Exception as e:
        st.error(f"{error_prefix}: {e}")


def _init_username():
    """Resolve this session's username and TTS voice.

    A saved username (``load_username``) is used when it is one of
    FUN_USERNAMES.  Otherwise, if the session has no username yet, a random
    one is picked and saved.
    """
    saved_username = load_username()
    if saved_username and saved_username in FUN_USERNAMES:
        st.session_state.username = saved_username
        st.session_state.tts_voice = FUN_USERNAMES[saved_username]  # Set voice too
    if not st.session_state.get('username'):
        st.session_state.username = random.choice(list(FUN_USERNAMES.keys()))
        st.session_state.tts_voice = FUN_USERNAMES[st.session_state.username]
        save_username(st.session_state.username)


def _render_world_tab():
    """Embed index.html (Three.js scene) with session values injected as ``window`` globals."""
    st.header("Shared 3D World")
    st.caption("Place objects using the sidebar tools. Changes are shared live!")
    html_file_path = 'index.html'
    try:
        with open(html_file_path, 'r', encoding='utf-8') as f:
            html_template = f.read()
        # Best-effort host detection through Streamlit server internals (these
        # may not exist in every Streamlit version); falls back to localhost.
        try:
            from streamlit.web.server.server import Server
            session_info = Server.get_current().get_session_info(st.runtime.scriptrunner.get_script_run_ctx().session_id)
            server_host = session_info.ws.stream.request.host.split(':')[0]  # Get host without port
            ws_url = f"ws://{server_host}:8765"
            print(f"Determined WS URL: {ws_url}")
        except Exception as e:
            print(f"Could not determine server host ({e}), defaulting WS URL to localhost.")
            ws_url = "ws://localhost:8765"
        js_injection_script = f"""
<script>
window.USERNAME = {json.dumps(st.session_state.username)};
window.WEBSOCKET_URL = {json.dumps(ws_url)};
window.SELECTED_OBJECT_TYPE = {json.dumps(st.session_state.selected_object)}; // Send current tool
window.PLOT_WIDTH = {json.dumps(PLOT_WIDTH)}; // Send constants needed by JS
window.PLOT_DEPTH = {json.dumps(PLOT_DEPTH)};
console.log("Streamlit State Injected:", {{
username: window.USERNAME,
websocketUrl: window.WEBSOCKET_URL,
selectedObject: window.SELECTED_OBJECT_TYPE
}});
</script>
"""
        html_content_with_state = html_template.replace('</head>', js_injection_script + '\n</head>', 1)
        components.html(html_content_with_state, height=700, scrolling=False)
    except FileNotFoundError:
        st.error(f"CRITICAL ERROR: Could not find '{html_file_path}'. Ensure it's in the same directory.")
    except Exception as e:
        st.error(f"Error loading 3D component: {e}")
        st.exception(e)  # Show traceback


def _render_chat_tab():
    """Show recent chat history and the message input; send on click or autosend."""
    st.header(f"{START_ROOM} Chat")
    chat_history = asyncio.run(load_chat_history())  # Load history at start of tab render
    chat_container = st.container(height=500)  # Scrollable chat area
    with chat_container:
        # Last 50 entries, reversed so the end of chat_history is shown first.
        # The blank lines around '---' make it a real Markdown rule.  A bare
        # "----\n" join either glues the dashes onto the previous entry's text
        # or, after a line break, turns that entry into a setext heading.
        st.markdown("\n\n---\n\n".join(reversed(chat_history[-50:])))

    # A widget's session-state value may only be changed BEFORE the widget is
    # created in a run.  Assigning it after st.text_input raises
    # StreamlitAPIException.  Clearing is therefore requested with a flag and
    # applied here on the next run.
    if st.session_state.get('_clear_message_input', False):
        st.session_state['_clear_message_input'] = False
        st.session_state.message_input = ""

    message_value = st.text_input(
        "Your Message:",
        key="message_input",  # Key links to st.session_state.message_input
        label_visibility="collapsed"
    )
    send_button_clicked = st.button("Send Chat 💬", key="send_chat_button")
    should_autosend = st.session_state.get('autosend', False) and message_value
    if not (send_button_clicked or should_autosend):
        return

    message_to_send = message_value  # Capture the value from this run
    if not message_to_send.strip() or message_to_send == st.session_state.get('last_message', ''):
        if send_button_clicked:
            st.toast("Message empty or same as last.")  # Give feedback
        return

    # Update last message tracker *before* sending/clearing
    st.session_state.last_message = message_to_send
    voice = FUN_USERNAMES.get(st.session_state.username, "en-US-AriaNeural")
    ws_message = json.dumps({
        "type": "chat_message",
        "payload": {"username": st.session_state.username, "message": message_to_send, "voice": voice}
    })
    _run_coroutine(broadcast_message(ws_message), "WebSocket broadcast error")
    _run_coroutine(save_chat_entry(st.session_state.username, message_to_send, voice), "Chat save error")

    st.session_state['_clear_message_input'] = True
    st.rerun()  # Clears the input (see above) and refreshes the chat display


def _render_files_tab():
    """Server status, world-state save/download, bulk deletion and zip downloads."""
    st.header("File Management & Settings")
    st.subheader("Server & World State")
    col_ws, col_save = st.columns(2)
    with col_ws:
        # Check thread status if task exists
        server_alive = st.session_state.get('server_task') and st.session_state.server_task.is_alive()
        ws_status = "Running" if server_alive else "Stopped"
        st.metric("WebSocket Server", ws_status)
        st.metric("Connected Clients", len(connected_clients))  # Use global set length
        if not server_alive and st.button("Restart Server Thread", key="restart_ws"):
            start_websocket_server_thread()
            st.rerun()
    with col_save:
        if st.button("💾 Save World State to Disk", key="save_world_disk", help="Saves the current live world state to world_state.json"):
            with st.spinner("Saving..."):
                if save_world_state_to_disk():
                    st.success("World state saved!")
                else:
                    st.error("Failed to save world state.")
        st.markdown(get_download_link(WORLD_STATE_FILE, "json"), unsafe_allow_html=True)

    # File deletion buttons
    st.subheader("Delete Files")
    st.warning("Deletion is permanent!", icon="⚠️")
    col_del1, col_del2, col_del3, col_del4 = st.columns(4)
    with col_del1:
        if st.button("🗑️ Chats (.md)", key="del_chat_md"):
            delete_files([os.path.join(CHAT_DIR, "*.md")])
            st.session_state.chat_history = []  # Clear session history too
            st.rerun()
    with col_del2:
        if st.button("🗑️ Audio (.mp3)", key="del_audio_mp3"):
            delete_files([os.path.join(AUDIO_DIR, "*.mp3"), os.path.join(AUDIO_CACHE_DIR, "*.mp3")])
            st.session_state.audio_cache = {}
            st.rerun()
    with col_del3:
        if st.button("🗑️ Zips (.zip)", key="del_zips"):
            delete_files(["*.zip"])
            st.rerun()
    with col_del4:
        if st.button("🗑️ All Generated", key="del_all_gen", help="Deletes Chats, Audio, Zips"):
            delete_files([os.path.join(CHAT_DIR, "*.md"),
                          os.path.join(AUDIO_DIR, "*.mp3"),
                          os.path.join(AUDIO_CACHE_DIR, "*.mp3"),
                          "*.zip"])
            st.session_state.chat_history = []
            st.session_state.audio_cache = {}
            st.rerun()

    # Display Zips, newest first
    st.subheader("📦 Download Archives")
    for zip_file in sorted(glob.glob("*.zip"), key=os.path.getmtime, reverse=True):
        st.markdown(get_download_link(zip_file, "zip"), unsafe_allow_html=True)


def _render_sidebar():
    """Build-tool buttons, tool clearing, username/voice selection and the TTS toggle."""
    st.header("🏗️ Build Tools")
    st.caption("Select an object to place.")
    cols = st.columns(5)  # 5 columns for buttons
    current_tool = st.session_state.get('selected_object', 'None')
    for idx, (emoji, name) in enumerate(PRIMITIVE_MAP.items()):
        # Use primary styling for the selected button
        button_type = "primary" if current_tool == name else "secondary"
        if cols[idx % 5].button(emoji, key=f"primitive_{name}", help=name, type=button_type, use_container_width=True):
            st.session_state.selected_object = name
            # Try to update the JS side without waiting for the iframe to re-render
            try:
                js_update_selection = f"updateSelectedObjectType({json.dumps(name)});"
                streamlit_js_eval(js_code=js_update_selection, key=f"update_tool_js_{name}")
            except Exception as e:
                print(f"Could not push tool update to JS: {e}")
            # Rerun so button styles reflect the new selection
            st.rerun()

    st.markdown("---")  # Separator
    if st.button("🚫 Clear Tool", key="clear_tool", use_container_width=True):
        if st.session_state.get('selected_object', 'None') != 'None':
            st.session_state.selected_object = 'None'
            try:  # Update JS too
                streamlit_js_eval(js_code="updateSelectedObjectType('None');", key="update_tool_js_none")
            except Exception:
                pass  # best-effort; the rerun re-injects SELECTED_OBJECT_TYPE anyway
            st.rerun()  # Rerun to update UI

    st.markdown("---")
    st.header("🗣️ Voice & User")
    username_options = list(FUN_USERNAMES.keys())
    current_username = st.session_state.get('username', username_options[0])
    try:
        current_index = username_options.index(current_username)
    except ValueError:
        current_index = 0  # Default to first if saved name invalid
    new_username = st.selectbox(
        "Change Name/Voice",
        options=username_options,
        index=current_index,
        key="username_select",
        format_func=lambda x: x.split(" ")[0]  # Show only name before emoji
    )
    if new_username != st.session_state.username:
        old_username = st.session_state.username
        # Announce name change via WebSocket
        change_msg = json.dumps({
            "type": "user_rename",
            "payload": {"old_username": old_username, "new_username": new_username}
        })
        _run_coroutine(broadcast_message(change_msg), "Rename broadcast error")
        st.session_state.username = new_username
        st.session_state.tts_voice = FUN_USERNAMES[new_username]
        save_username(st.session_state.username)  # Save new username
        st.rerun()

    # Enable/Disable Audio Toggle
    st.session_state['enable_audio'] = st.toggle("Enable TTS Audio", value=st.session_state.get('enable_audio', True))
    st.markdown("---")
    st.info("Chat and File management in main tabs.")


def main_interface():
    """Render the app: load world state once per session, then tabs and sidebar.

    Expects ``init_session_state()`` to have been called first.
    """
    # --- Load initial world state ONCE per session ---
    if not st.session_state.get('initial_world_state_loaded', False):
        with st.spinner("Loading initial world state..."):
            load_world_state_from_disk()
        st.session_state.initial_world_state_loaded = True

    _init_username()
    st.title(f"{Site_Name} - User: {st.session_state.username}")

    tab_world, tab_chat, tab_files = st.tabs(["🏗️ World Builder", "🗣️ Chat", "📂 Files & Settings"])
    with tab_world:
        _render_world_tab()
    with tab_chat:
        _render_chat_tab()
    with tab_files:
        _render_files_tab()
    with st.sidebar:
        _render_sidebar()
# --- Main Execution --- | |
if __name__ == "__main__":
    # Initialize session state variables first
    init_session_state()

    # Start the WebSocket server thread unless this session already holds a
    # live one.  The 'server_task' key may be present with a None value (other
    # code in this file guards for that too), so check before calling is_alive().
    server_task = st.session_state.get('server_task')
    if server_task is None or not server_task.is_alive():
        print("Main thread: Starting WebSocket server thread...")
        start_websocket_server_thread()
        # Give the server thread a moment to bind before the first render.
        time.sleep(1.5)
    else:
        print("Main thread: Server thread already exists.")

    # main_interface loads the world state from disk once per session (under a
    # spinner), so it is not duplicated here.
    main_interface()

    # TODO: optional periodic save, e.g. call save_world_state_to_disk() when
    # more than 300 s have passed since st.session_state['last_world_save_time'].