# Page-capture header residue (Hugging Face "Spaces" status: Running).
import json
import os
import time
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
import torch
import torch.nn as nn
import yfinance as yf
from keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
# --- Page Configuration ---
# Wide layout so the embedded dashboard can span the whole browser width.
st.set_page_config(layout="wide")

# --- Streamlit Session State Initialization ---
# Seed every key the rest of the script reads through attribute access
# (st.session_state.data, .error, ...) so the first run never hits a missing key.
# Existing values are left untouched on later reruns.
_SESSION_DEFAULTS = (
    ('run_button_clicked', False),
    ('loading', False),
    ('data', None),
    ('predictions', None),
    ('error', None),
    ('last_ticker', 'AMZN'),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- Custom CSS ---
# Style overrides for Streamlit's own widgets and sidebar. The dashboard itself is
# rendered (and styled) inside the components.html() iframe further down.
st.markdown("""
<style>
/* Hide Streamlit's default header, footer, and hamburger menu */
#MainMenu, header, footer { visibility: hidden; }

/* Remove padding from the main block container for a full-width feel */
.block-container {
    padding: 0 !important;
}

/* Pill-shaped buttons with bold, orange, serif text */
div.stButton > button {
    background: rgba(255, 255, 255, 0.2);
    color: orange !important;
    font-family: "Times New Roman", serif !important;
    font-size: 18px !important;
    font-weight: bold !important;
    padding: 10px 20px;
    border: none;
    border-radius: 35px;
    box-shadow: 0px 4px 10px rgba(0, 0, 0, 0.2);
    transition: all 0.3s ease-in-out;
    display: flex;
    align-items: center;
    justify-content: center;
    margin: 10px 0;
    width: 190px;
    height: 50px;
    margin-top: 5px;
}

div[data-testid="stSelectbox"] {
    background-color: white !important;
    position: relative;
    border-bottom: 1px solid #ccc;
    border-radius: 0px;
}

div[data-testid="stTextInput"] > div > div {
    background-color: rgba(255, 158, 87, 0.12) !important;
}

div[data-testid="stTextInputRootElement"] {
    border: 1px solid white !important;
}

/* Hover effect: slight zoom plus a soft glow */
div.stButton > button:hover {
    background: rgba(255, 255, 255, 0.2);
    transform: scale(1.1);
    box-shadow: 0px 4px 12px rgba(255, 255, 255, 0.4);
}

/* Translucent, blurred sidebar panel */
section[data-testid="stSidebar"] {
    backdrop-filter: blur(10px);
    background: rgba(255, 255, 255, 0.15);
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.01);
    height: 100px;
}

[data-testid="stSidebar"] h2 {
    color: #FFFFFF; /* White headers in the sidebar */
    font-family: "Times New Roman", serif !important;
}

[data-testid="stSidebar"] .st-emotion-cache-1629p8f a {
    color: #94A3B8; /* Lighter text color for links */
    font-family: "Times New Roman", serif !important;
}

[data-testid="stImageContainer"] > img {
    max-width: 70% !important;
    margin-top: -70px;
}

div[data-testid="stMarkdownContainer"] > p {
    font-family: "Times New Roman", serif !important;
}
</style>
""", unsafe_allow_html=True)
# --- Python Backend Functions ---
class GRUModel(nn.Module):
    """Stacked GRU regressor producing one prediction per input sequence.

    Input is batch-first, ``(batch, seq_len, input_dim)``. The output is ``fc``
    applied to the last time step's hidden features, shape ``(batch, output_dim)``.
    """

    def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
        super(GRUModel, self).__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Zero initial state sized from the configured GRU instead of hard-coded
        # (2, batch, 100), created on x's device and dtype.
        h0 = x.new_zeros(self.gru.num_layers, x.size(0), self.gru.hidden_size)
        out, _ = self.gru(x, h0)
        return self.fc(out[:, -1, :])
class BiLSTMModel(nn.Module):
    """Single-layer bidirectional LSTM regressor (hidden size 100 per direction).

    Input is batch-first ``(batch, seq_len, 1)``. The forward and backward features
    of the last time step (200 values) are projected to ``(batch, 1)``.
    """

    def __init__(self):
        super(BiLSTMModel, self).__init__()
        # No dropout: nn.LSTM applies it only between stacked layers, so with
        # num_layers=1 it has no effect (and PyTorch warns about a non-zero value).
        self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(200, 1)

    def forward(self, x):
        # Zero states of shape (num_layers * 2 directions, batch, hidden) on x's device/dtype.
        state_shape = (self.lstm.num_layers * 2, x.size(0), self.lstm.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out[:, -1, :])
def load_model_from_disk(path, model_type):
    """Build an eager model and restore its weights from a ``torch.load`` checkpoint.

    Args:
        path: Checkpoint file, mapped onto the CPU when loaded.
        model_type: "Bi-Directional LSTM" selects BiLSTMModel; any other value
            selects GRUModel.

    The checkpoint may be a bare state dict, or a dict that stores one under
    'model_state_dict'. Returns the model switched to eval mode.
    """
    use_bilstm = model_type == "Bi-Directional LSTM"
    net = BiLSTMModel() if use_bilstm else GRUModel()
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    if 'model_state_dict' in checkpoint:
        weights = checkpoint['model_state_dict']
    else:
        weights = checkpoint
    net.load_state_dict(weights)
    net.eval()
    return net
def load_scripted_model(path):
    """Deserialize a TorchScript archive onto the CPU and return it in inference mode."""
    scripted = torch.jit.load(path, map_location=torch.device("cpu"))
    scripted.train(False)  # identical to .eval()
    return scripted
@st.cache_resource
def preload_models():
    """Load every selectable model once and share it across reruns and sessions.

    Streamlit re-executes the whole script on every widget interaction. Without
    ``st.cache_resource`` both checkpoints would be read from disk each time.
    Keys match the options of the "Prediction Model" selectbox.
    """
    return {
        # File name kept exactly as shipped alongside the app.
        "Bi-Directional LSTM": load_scripted_model("bilstm_scriptes.pt"),
        "Gated Recurrent Unit": load_model_from_disk("best_gru_model.pth", model_type="GRU"),
    }


# Module-level handle used by the main logic; cheap after the first run thanks to the cache.
MODELS = preload_models()
def get_stock_data(ticker, cache_path="AMZN_data.csv"):
    """Download roughly 4 years of daily price history for *ticker* from Yahoo Finance.

    Args:
        ticker: Symbol passed to ``yf.download``.
        cache_path: Where a successful, non-empty download is written as CSV (with
            the date index as a 'Date' column). Pass None to skip writing.

    Returns:
        A DataFrame with a 'Date' column plus the yfinance price columns, or None
        when the download came back empty. An empty result is never written to
        *cache_path*, so a failed fetch cannot clobber a previously good cache.
    """
    end_date = datetime.now()
    start_date = end_date - timedelta(days=4 * 365)
    print(f"Fetching data for ticker: {ticker} from {start_date.date()} to {end_date.date()}")
    data = yf.download(ticker, period="4y", multi_level_index=False)
    if data is None or data.empty:
        print(f"No data found for ticker: {ticker}")
        return None
    if cache_path:
        data.to_csv(cache_path)
    data.reset_index(inplace=True)
    print(f"Successfully fetched {len(data)} rows for {ticker}")
    return data
def predict_with_model(data, n_days, model_path, model_type, model=None, sequence_length=90) -> pd.DataFrame:
    """Forecast the next *n_days* closing prices autoregressively.

    The whole 'Close' history is min-max scaled to [0, 1], and the last
    ``sequence_length`` values form the model input. Each one-step prediction is
    appended to the window for the next step, then all predictions are mapped
    back to prices.

    Args:
        data: DataFrame with 'Date' and 'Close' columns. The last row is treated
            as the most recent observation.
        n_days: Number of future steps to predict; must be >= 1.
        model_path: Checkpoint path, used only when *model* is None.
        model_type: Architecture name forwarded to ``load_model_from_disk``.
        model: Already-loaded model; takes precedence over *model_path*.
        sequence_length: Length of the input window (default 90).

    Returns:
        DataFrame with 'Date', 'Predicted Price', 'Upper CI' and 'Lower CI'. The
        CI is a heuristic band of +/-1.96 * sigma * sqrt(step), where sigma is the
        std of historical daily returns (0.01 if unavailable).

    Raises:
        ValueError: if *n_days* < 1 or fewer than *sequence_length* rows exist.
    """
    if n_days < 1:
        raise ValueError(f"n_days must be at least 1, got {n_days}.")
    if model is None:
        model = load_model_from_disk(model_path, model_type=model_type)

    close_prices = data['Close'].values.reshape(-1, 1)
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_prices = scaler.fit_transform(close_prices)
    if len(scaled_prices) < sequence_length:
        raise ValueError(f"Not enough historical data ({len(scaled_prices)} points) to create a sequence of {sequence_length} for prediction.")

    window = torch.tensor(scaled_prices[-sequence_length:].reshape(1, sequence_length, 1), dtype=torch.float32)
    predictions_scaled = []
    with torch.no_grad():
        for _ in range(n_days):
            pred = model(window)
            predictions_scaled.append(pred.item())
            # Slide the window: drop the oldest step, append this prediction as the newest.
            window = torch.cat((window[:, 1:, :], pred.view(1, 1, 1)), dim=1)
    predictions = scaler.inverse_transform(np.array(predictions_scaled).reshape(-1, 1)).flatten()
    print("predictions", predictions)

    last_date = pd.to_datetime(data['Date'].iloc[-1])
    # Each model step advances one row of daily market history, which normally has
    # no weekend rows, so label forecasts with business days rather than calendar days.
    future_dates = [last_date + pd.offsets.BDay(i) for i in range(1, n_days + 1)]
    prediction_df = pd.DataFrame({'Date': future_dates, 'Predicted Price': predictions})

    historical_returns = data['Close'].pct_change().dropna()
    volatility = historical_returns.std() if not historical_returns.empty else 0.01
    error_std_growth = volatility * np.sqrt(np.arange(1, n_days + 1))
    prediction_df['Upper CI'] = predictions * (1 + 1.96 * error_std_growth)
    prediction_df['Lower CI'] = predictions * (1 - 1.96 * error_std_growth)
    return prediction_df
# --- Streamlit Sidebar Controls ---
with st.sidebar:
    st.image("logo2.png", use_container_width=True)
    st.markdown("Dashboard Controls")

    # Rendered with disabled=True: it only displays the remembered ticker.
    remembered_ticker = st.session_state.get('last_ticker', "AMZN")
    ticker = st.text_input("Stock Ticker", remembered_ticker, disabled=True).upper()

    model_choices = ("Bi-Directional LSTM", "Gated Recurrent Unit")
    model_type = st.selectbox(
        "Prediction Model",
        model_choices,
        key="model_choice",
        help="Select the neural network architecture for prediction.",
    )

    remembered_horizon = st.session_state.get('last_prediction_days', 7)
    prediction_days = st.slider("Prediction Horizon (Days)", 7, 21, remembered_horizon)

    # Also rendered with disabled=True, so users cannot currently click it.
    generate_pressed = st.button("21 days ahead of the market", use_container_width=True, disabled=True)
    if generate_pressed:
        st.session_state.last_ticker = ticker
        st.session_state.last_prediction_days = prediction_days
        st.session_state.error = None
        st.session_state.run_button_clicked = True
        st.session_state.loading = True
        print("Generate Dashboard button clicked. Loading state set to True.")
        st.rerun()

# Any difference from the settings used last time schedules a new prediction pass.
settings_changed = any((
    ticker != st.session_state.get('last_ticker', ''),
    model_type != st.session_state.get('last_model_type', ''),
    prediction_days != st.session_state.get('last_prediction_days', 7),
))
if settings_changed:
    for _flag_key, _flag_value in (
        ('run_button_clicked', True),
        ('loading', True),
        ('last_ticker', ticker),
        ('last_model_type', model_type),
        ('last_prediction_days', prediction_days),
    ):
        st.session_state[_flag_key] = _flag_value
# --- Main Application Logic ---
# On-disk copy of the downloaded history, and how long it is trusted before a refresh is tried.
DATA_CACHE_PATH = "AMZN_data.csv"
DATA_CACHE_MAX_AGE_SECONDS = 24 * 60 * 60


def _load_price_history(symbol):
    """Return price history for *symbol*, preferring a fresh on-disk cache.

    The cache file is used directly when it is non-empty and younger than
    DATA_CACHE_MAX_AGE_SECONDS. Otherwise get_stock_data() is called. If that
    raises or returns None, the stale cache (read before the download attempt)
    is returned when one exists, so a failed refresh never blanks the dashboard.
    """
    cached = None
    if os.path.exists(DATA_CACHE_PATH):
        cached = pd.read_csv(DATA_CACHE_PATH)
        if cached.empty:
            cached = None
        elif time.time() - os.path.getmtime(DATA_CACHE_PATH) < DATA_CACHE_MAX_AGE_SECONDS:
            return cached
    try:
        fresh = get_stock_data(symbol)
    except Exception as exc:  # any download failure; fall back only if we have something
        if cached is None:
            raise
        print(f"Refreshing {symbol} failed ({exc}); using stale cached data.")
        return cached
    if fresh is None and cached is not None:
        print(f"No fresh data for {symbol}; using stale cached data.")
        return cached
    return fresh


if st.session_state.run_button_clicked:
    print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
    try:
        model = MODELS[model_type]
        st.session_state.data = _load_price_history(ticker)
        if st.session_state.data is None:
            st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
        else:
            st.session_state.predictions = predict_with_model(
                st.session_state.data, prediction_days, model_path=None, model_type=model_type, model=model
            )
            print("model", model)
            st.session_state.error = None
    except FileNotFoundError as e:
        st.session_state.error = str(e)
        print(f"Caught FileNotFoundError: {e}")
    except ValueError as e:
        st.session_state.error = str(e)
        print(f"Caught ValueError: {e}")
    except Exception as e:
        st.session_state.error = f"An unexpected error occurred: {str(e)}"
        print(f"Caught general Exception: {e}")
    st.session_state.loading = False
    st.session_state.run_button_clicked = False
    print(f"Processing complete. Loading state set to False. Error: {st.session_state.error}")
    # Kept outside the try block: st.rerun() halts this run, and the broad
    # except handlers above must not interfere with that.
    st.rerun()
# --- Data Preparation for Front-End ---
# Every *_js / *_json string below is pasted verbatim into the dashboard's <script>,
# so each one must already be a valid JavaScript literal.


def _js_literal(value):
    """Encode *value* as a JavaScript literal that is safe to inline in a <script> element.

    json.dumps escapes quotes, backslashes and newlines. Error messages such as
    "Could not fetch data for ticker '...'" contain single quotes and would
    otherwise break the script. Escaping "</" prevents the text from closing the
    surrounding <script> tag. None becomes the literal null.
    """
    return json.dumps(value).replace("</", "<\\/")


historical_data_json = 'null'
prediction_data_json = 'null'
is_loading_js = str(st.session_state.get('loading', False)).lower()
error_message_js = _js_literal(st.session_state.get('error') or None)
if st.session_state.data is not None and st.session_state.get('error') is None:
    historical_data_json = st.session_state.data.to_json(orient='split', date_format='iso')
    # Guard against a missing forecast; the page treats null as "no predictions".
    if st.session_state.predictions is not None:
        prediction_data_json = st.session_state.predictions.to_json(orient='split', date_format='iso')
# --- HTML Front-End ---
# Self-contained page rendered inside the components.html() iframe. Python values
# (ticker, model_type and the *_js / *_json literals) are spliced in by this
# f-string, so every literal CSS/JS brace below is doubled.
html_code = f"""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Stock Intelligence Dashboard</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <!-- Pinned versions; the UMD build registers all built-in controllers, scales and plugins. -->
    <script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.1/dist/chart.umd.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/chartjs-adapter-date-fns@3.0.0/dist/chartjs-adapter-date-fns.bundle.min.js"></script>
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&display=swap" rel="stylesheet">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/all.min.css">
    <style>
        body {{ font-family: 'Times New Roman', serif; background-color: #f1f5f9; scrollbar-width: thin; scrollbar-color: rgba(100, 100, 100, 0.4) transparent; }}
        .metric-card, .info-card {{ background-color: #ffffff; border-radius: 1rem; box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1); transition: all 0.3s ease-in-out; border: 1px solid #e2e8f0; }}
        .metric-card:hover {{ transform: translateY(-5px); box-shadow: 0 10px 15px -3px rgb(0 0 0 / 0.1), 0 4px 6px -4px rgb(0 0 0 / 0.1); }}
        .positive {{ color: #10B981; }}
        .negative {{ color: #EF4444; }}
        .neutral {{ color: #64748b; }}
        ::-webkit-scrollbar {{ width: 6px; }}
        ::-webkit-scrollbar-thumb {{ background-color: rgba(100, 100, 100, 0.4); border-radius: 3px; }}
        ::-webkit-scrollbar-track {{ background: transparent; }}
        #predictionTable table {{ width: 100%; border-collapse: collapse; }}
        #predictionTable th, #predictionTable td {{ padding: 0.75rem 1rem; text-align: left; border-bottom: 1px solid #e2e8f0; }}
        #predictionTable th {{ background-color: #f8fafc; font-weight: 600; font-size: 0.75rem; text-transform: uppercase; letter-spacing: 0.05em; color: #64748b; }}
        /* Responsive Chart.js sizes the canvas from a dedicated, relatively positioned parent. */
        .chart-box {{ position: relative; width: 100%; }}
        .hidden {{ display: none !important; }}
        .error-message {{ color: #EF4444; font-weight: 600; text-align: center; margin-top: 20px; padding: 15px; background-color: #fee2e2; border-radius: 0.5rem; border: 1px solid #ef4444; }}
    </style>
</head>
<body class="antialiased text-slate-800">
    <main id="content-wrapper">
        <header class="bg-white/80 backdrop-blur-lg sticky top-0 z-50 border-b border-slate-200">
            <div class="max-w-8xl mx-auto px-4 sm:px-6 lg:px-8">
                <div class="flex items-center justify-between h-16">
                    <div class="flex items-center">
                        <i class="fas fa-chart-line text-2xl text-orange-400"></i>
                        <h1 id="dashboard-title" class="text-xl font-bold text-slate-900 ml-3">{ticker} Intelligence Dashboard</h1>
                    </div>
                    <div class="text-sm text-slate-500 flex items-center">
                        <div id="status-message" class="text-center text-sm text-slate-500 mt-4 hidden">Loading updated data...</div>
                        <i class="fas fa-rocket mr-2 text-orange-400"></i> Powered by a <span class="font-semibold text-yellow-600 ml-1">{model_type}</span> model
                    </div>
                </div>
            </div>
        </header>
        <div class="p-4 sm:p-6 lg:p-8">
            <div class="max-w-8xl mx-auto">
                <div id="dashboard-error-message" class="hidden error-message"></div>
                <div class="grid grid-cols-1 sm:grid-cols-2 lg:grid-cols-4 gap-6 mb-8" id="metrics-grid"></div>
                <div class="grid grid-cols-1 lg:grid-cols-3 gap-8">
                    <div class="lg:col-span-2 space-y-8">
                        <div class="info-card p-4 sm:p-6">
                            <div class="chart-box" style="height: 350px;"><canvas id="priceChart"></canvas></div>
                        </div>
                        <div class="info-card p-4 sm:p-6">
                            <div class="chart-box" style="height: 200px;"><canvas id="volumeChart"></canvas></div>
                        </div>
                        <div id="predictionDetailsContainer" class="info-card p-4 sm:p-6 hidden">
                            <h3 class="text-lg font-semibold mb-4 text-slate-800">AI Prediction Details</h3>
                            <div class="overflow-x-auto" id="predictionTable"></div>
                        </div>
                    </div>
                    <div class="lg:col-span-1 space-y-8">
                        <div class="info-card p-6">
                            <h3 class="text-lg font-semibold mb-4 text-slate-800 flex items-center"><i class="fas fa-robot mr-3 text-orange-400"></i> AI Prediction Summary</h3>
                            <div id="predictionResult" class="mt-4 text-center"></div>
                        </div>
                        <div class="info-card p-6">
                            <h3 class="text-lg font-semibold mb-4 text-slate-800">Technical Summary</h3>
                            <div class="space-y-3" id="tech-summary"></div>
                        </div>
                    </div>
                </div>
            </div>
        </div>
    </main>
    <script>
    document.addEventListener('DOMContentLoaded', function () {{
        // Literals injected by Python.
        const historicalDataJson = {historical_data_json};
        const predictionDataJson = {prediction_data_json};
        const isLoading = {is_loading_js};
        const errorMessage = {error_message_js};
        console.log("JS: DOMContentLoaded. Initial isLoading:", isLoading, "Error:", errorMessage);

        // Label of the invisible lower edge of the confidence band; hidden from legend and tooltips.
        const CI_LOWER_LABEL = '95% Confidence (lower)';

        const metricsGridEl = document.getElementById('metrics-grid');
        const techSummaryEl = document.getElementById('tech-summary');
        const predictionResultEl = document.getElementById('predictionResult');
        const predictionDetailsContainerEl = document.getElementById('predictionDetailsContainer');
        const predictionTableEl = document.getElementById('predictionTable');
        const dashboardErrorMessageEl = document.getElementById('dashboard-error-message');
        const statusMessageEl = document.getElementById('status-message');
        let priceChart;
        let volumeChart;

        function destroyCharts() {{
            if (priceChart) {{ priceChart.destroy(); priceChart = undefined; }}
            if (volumeChart) {{ volumeChart.destroy(); volumeChart = undefined; }}
        }}

        // Decode pandas to_json(orient='split') output into parallel arrays.
        function parseData(jsonData) {{
            try {{
                if (!jsonData || !jsonData.columns) return null;
                const col = name => jsonData.columns.indexOf(name);
                const iDate = col('Date'), iClose = col('Close'), iVolume = col('Volume'), iHigh = col('High');
                return {{
                    dates: jsonData.data.map(row => new Date(row[iDate])),
                    prices: jsonData.data.map(row => row[iClose]),
                    volumes: jsonData.data.map(row => row[iVolume]),
                    highs: jsonData.data.map(row => row[iHigh]),
                }};
            }} catch (e) {{
                console.error("JS: Error parsing historical data:", e);
                return null;
            }}
        }}

        function parsePredictions(jsonData) {{
            try {{
                if (!jsonData || !jsonData.columns) return [];
                const col = name => jsonData.columns.indexOf(name);
                const iDate = col('Date'), iPred = col('Predicted Price'), iUpper = col('Upper CI'), iLower = col('Lower CI');
                return jsonData.data.map(row => ({{
                    x: new Date(row[iDate]),
                    y: row[iPred],
                    upperCI: row[iUpper],
                    lowerCI: row[iLower]
                }}));
            }} catch (e) {{
                console.error("JS: Error parsing prediction data:", e);
                return [];
            }}
        }}

        function displayMetric(elementId, value, prefix = '', suffix = '', decimals = 0) {{
            const el = document.getElementById(elementId);
            if (el) {{
                el.textContent = prefix + value.toLocaleString(undefined, {{ minimumFractionDigits: decimals, maximumFractionDigits: decimals }}) + suffix;
            }}
        }}

        function updateMetrics(data) {{
            if (!data || data.prices.length < 2) {{
                metricsGridEl.innerHTML = `<div class="col-span-full text-center text-slate-500 p-4">Not enough historical data to display metrics.</div>`;
                return;
            }}
            const currentPrice = data.prices[data.prices.length - 1];
            const prevPrice = data.prices[data.prices.length - 2];
            const change = currentPrice - prevPrice;
            const changePct = (change / prevPrice) * 100;
            const volume = data.volumes[data.volumes.length - 1];
            const sharesOutstanding = 10.33 * 1e9; // Example value
            const marketCap = currentPrice * sharesOutstanding;
            const metrics = [
                {{ id: 'price', title: 'Current Price', value: currentPrice, change: `${{change >= 0 ? '+' : ''}}${{change.toFixed(2)}} (${{changePct.toFixed(2)}}%)`, status: change >= 0 ? 'positive' : 'negative', icon: 'fa-dollar-sign', prefix: '$', decimals: 2 }},
                {{ id: 'market-cap', title: 'Market Cap', value: marketCap, change: 'USD', status: 'neutral', icon: 'fa-building', prefix: '$', suffix: '', decimals: 2, isCurrency: true }},
                {{ id: 'volume', title: 'Daily Volume', value: volume, change: 'Shares Traded', status: 'neutral', icon: 'fa-chart-bar', suffix: '', decimals: 0 }},
                {{ id: '52-week-high', title: '52-Week High', value: Math.max(...data.highs.slice(-252)), change: 'Annual Peak', status: 'neutral', icon: 'fa-arrow-trend-up', prefix: '$', decimals: 2 }},
            ];
            metricsGridEl.innerHTML = metrics.map(metric => `<div class="metric-card p-5"><div class="flex items-center justify-between"><p class="text-sm font-medium text-slate-500">${{metric.title}}</p><div class="text-2xl text-slate-300"><i class="fas ${{metric.icon}}"></i></div></div><p class="text-3xl font-bold text-slate-900 mt-2" id="${{metric.id}}">0</p><p class="text-xs ${{metric.status}} mt-1 font-semibold">${{metric.change}}</p></div>`).join('');
            metrics.forEach(metric => {{
                // Market cap abbreviates from millions upward, volume from thousands upward.
                const units = metric.isCurrency ? [[1e12, 'T'], [1e9, 'B'], [1e6, 'M']]
                    : (metric.id === 'volume' ? [[1e9, 'B'], [1e6, 'M'], [1e3, 'K']] : []);
                const unit = units.find(([size]) => metric.value >= size);
                let value = metric.value;
                let suffix = metric.suffix || '';
                let decimals = metric.decimals;
                if (unit) {{
                    value = metric.value / unit[0];
                    suffix = unit[1];
                    decimals = 2;
                }}
                displayMetric(metric.id, value, metric.prefix || '', suffix, decimals);
            }});
        }}

        function updateTechSummary(data) {{
            if (!data || data.prices.length < 50) {{ // Need enough data for the 50-day SMA
                techSummaryEl.innerHTML = '<p class="text-sm text-slate-500">Not enough data for full technical analysis (min 50 days required).</p>';
                return;
            }}
            const prices = data.prices;
            const lastPrice = prices[prices.length - 1];
            const mean = xs => xs.reduce((a, b) => a + b, 0) / xs.length;
            const sma20 = prices.length >= 20 ? mean(prices.slice(-20)) : NaN;
            const sma50 = prices.length >= 50 ? mean(prices.slice(-50)) : NaN;
            const gains = [];
            const losses = [];
            for (let i = 1; i < prices.length; i++) {{
                const diff = prices[i] - prices[i - 1];
                gains.push(diff > 0 ? diff : 0);
                losses.push(diff > 0 ? 0 : Math.abs(diff));
            }}
            // Simple (not Wilder-smoothed) averages over the last 14 changes, or all if fewer.
            const period = Math.min(14, gains.length);
            let avgGain = 0;
            let avgLoss = 0;
            if (period > 0) {{
                avgGain = mean(gains.slice(-period));
                avgLoss = mean(losses.slice(-period));
            }}
            const rs = (avgLoss === 0 || isNaN(avgLoss)) ? (avgGain > 0 ? Infinity : 0) : avgGain / avgLoss;
            let rsi = 100 - (100 / (1 + rs));
            if (isNaN(rsi)) rsi = 0;
            const rsiClass = rsi > 70 ? 'negative' : (rsi < 30 ? 'positive' : 'neutral');
            const summary = [
                {{ label: 'SMA (20 Day)', value: isNaN(sma20) ? 'N/A' : `$${{sma20.toFixed(2)}}`, status: lastPrice > sma20 ? 'positive' : (isNaN(sma20) ? 'neutral' : 'negative') }},
                {{ label: 'SMA (50 Day)', value: isNaN(sma50) ? 'N/A' : `$${{sma50.toFixed(2)}}`, status: lastPrice > sma50 ? 'positive' : (isNaN(sma50) ? 'neutral' : 'negative') }},
                {{ label: 'RSI (14 Day)', value: rsi.toFixed(1), status: rsiClass }}
            ];
            techSummaryEl.innerHTML = summary.map(item => `<div class="flex justify-between items-center text-sm"><span class="text-slate-600">${{item.label}}</span><span class="font-semibold ${{item.status}}">${{item.value}}</span></div>`).join('');
        }}

        function renderCharts(data, predictions) {{
            destroyCharts();

            // --- Price chart ---
            const priceDatasets = [
                {{
                    label: 'Historical Price',
                    data: data.dates.map((d, i) => ({{x: d, y: data.prices[i]}})),
                    borderColor: '#3b82f6',
                    backgroundColor: 'rgba(59, 130, 246, 0.1)',
                    borderWidth: 2,
                    pointRadius: 0,
                    fill: true,
                    tension: 0.3
                }}
            ];
            if (predictions.length > 0) {{
                priceDatasets.push({{
                    label: 'AI Prediction',
                    data: predictions,
                    borderColor: '#10b981',
                    borderWidth: 2,
                    pointRadius: 2,
                    borderDash: [5, 5],
                    fill: false,
                    tension: 0.3
                }});
                // Confidence band: the upper edge fills down to the next dataset (the lower edge).
                priceDatasets.push({{
                    label: '95% Confidence',
                    data: predictions.map(p => ({{x: p.x, y: p.upperCI}})),
                    fill: '+1',
                    backgroundColor: 'rgba(234, 179, 8, 0.2)',
                    borderColor: 'transparent',
                    pointRadius: 0
                }});
                priceDatasets.push({{
                    label: CI_LOWER_LABEL,
                    data: predictions.map(p => ({{x: p.x, y: p.lowerCI}})),
                    fill: false,
                    borderColor: 'transparent',
                    pointRadius: 0
                }});
            }}
            priceChart = new Chart(document.getElementById('priceChart').getContext('2d'), {{
                type: 'line',
                data: {{ datasets: priceDatasets }},
                options: {{
                    responsive: true,
                    maintainAspectRatio: false,
                    scales: {{
                        x: {{ type: 'time', time: {{ unit: 'month', tooltipFormat: 'MMM d, yyyy' }}, grid: {{ display: false }} }},
                        y: {{ title: {{ display: true, text: 'Price (USD)' }}, grid: {{ color: '#f1f5f9' }} }}
                    }},
                    plugins: {{
                        legend: {{
                            display: true,
                            position: 'top',
                            align: 'end',
                            labels: {{ filter: item => item.text !== CI_LOWER_LABEL }}
                        }},
                        tooltip: {{
                            // Nearest point along x (not array index: the datasets cover different dates).
                            mode: 'nearest',
                            axis: 'x',
                            intersect: false,
                            // The CI range is appended to the prediction row instead of separate rows.
                            filter: item => item.dataset.label === 'Historical Price' || item.dataset.label === 'AI Prediction',
                            callbacks: {{
                                title: function(context) {{
                                    return context.length ? context[0].label : '';
                                }},
                                label: function(context) {{
                                    let label = context.dataset.label || '';
                                    if (label) label += ': ';
                                    label += '$' + context.parsed.y.toFixed(2);
                                    if (context.dataset.label === 'AI Prediction' && predictions.length > 0) {{
                                        const predictionPoint = predictions.find(p => p.x.getTime() === context.parsed.x);
                                        if (predictionPoint) {{
                                            label += ` (CI: $${{predictionPoint.lowerCI.toFixed(2)}} - $${{predictionPoint.upperCI.toFixed(2)}})`;
                                        }}
                                    }}
                                    return label;
                                }}
                            }}
                        }}
                    }}
                }}
            }});

            // --- Volume chart ---
            volumeChart = new Chart(document.getElementById('volumeChart').getContext('2d'), {{
                type: 'bar',
                data: {{
                    datasets: [{{
                        label: 'Volume',
                        data: data.dates.map((d, i) => ({{x: d, y: data.volumes[i]}})),
                        backgroundColor: '#e2e8f0',
                        borderColor: '#cbd5e1',
                        borderWidth: 1
                    }}]
                }},
                options: {{
                    responsive: true,
                    maintainAspectRatio: false,
                    scales: {{
                        x: {{ type: 'time', time: {{ unit: 'month' }}, grid: {{ display: false }} }},
                        y: {{
                            title: {{ display: true, text: 'Volume' }},
                            grid: {{ color: '#f1f5f9' }},
                            ticks: {{
                                callback: function(value) {{
                                    if (value >= 1e9) return (value / 1e9).toFixed(0) + 'B';
                                    if (value >= 1e6) return (value / 1e6).toFixed(0) + 'M';
                                    if (value >= 1e3) return (value / 1e3).toFixed(0) + 'K';
                                    return value;
                                }}
                            }}
                        }}
                    }},
                    plugins: {{
                        legend: {{ display: false }},
                        tooltip: {{
                            callbacks: {{
                                label: function(context) {{
                                    let label = context.dataset.label || '';
                                    if (label) label += ': ';
                                    const value = context.parsed.y;
                                    if (value >= 1e9) label += (value / 1e9).toLocaleString(undefined, {{maximumFractionDigits: 1}}) + 'B';
                                    else if (value >= 1e6) label += (value / 1e6).toLocaleString(undefined, {{maximumFractionDigits: 1}}) + 'M';
                                    else if (value >= 1e3) label += (value / 1e3).toLocaleString(undefined, {{maximumFractionDigits: 1}}) + 'K';
                                    else label += value.toLocaleString();
                                    return label;
                                }}
                            }}
                        }}
                    }}
                }}
            }});
        }}

        function displayPredictions(data, predictions) {{
            if (!data || predictions.length === 0) {{
                predictionDetailsContainerEl.classList.add('hidden');
                predictionResultEl.innerHTML = '<p class="text-sm text-slate-500">No predictions available or not enough data for prediction.</p>';
                return;
            }}
            predictionDetailsContainerEl.classList.remove('hidden');
            const lastHistoricalPrice = data.prices[data.prices.length - 1];
            const finalPredictedPrice = predictions[predictions.length - 1].y;
            const changeOverall = finalPredictedPrice - lastHistoricalPrice;
            const changePctOverall = (changeOverall / lastHistoricalPrice) * 100;
            const statusClass = changeOverall >= 0 ? 'positive' : 'negative';
            predictionResultEl.innerHTML = `<p class="text-sm text-slate-500">Predicted price in ${{predictions.length}} days:</p><p class="text-3xl font-bold mt-1 ${{statusClass}}">$${{finalPredictedPrice.toFixed(2)}} <span class="text-base font-normal">(${{changeOverall >= 0 ? '+' : ''}}${{changeOverall.toFixed(2)}} / ${{changePctOverall.toFixed(2)}}%)</span></p>`;
            const tableRows = predictions.map(p => `
                <tr>
                    <td>${{new Date(p.x).toLocaleDateString()}}</td>
                    <td class="font-semibold">$${{p.y.toFixed(2)}}</td>
                    <td>$${{p.lowerCI.toFixed(2)}} - $${{p.upperCI.toFixed(2)}}</td>
                </tr>
            `).join('');
            predictionTableEl.innerHTML = `
                <table>
                    <thead>
                        <tr>
                            <th>Date</th>
                            <th>Predicted Price</th>
                            <th>95% Confidence Interval</th>
                        </tr>
                    </thead>
                    <tbody>${{tableRows}}</tbody>
                </table>
            `;
        }}

        function loadDashboard() {{
            console.log("JS: loadDashboard() called. Current isLoading:", isLoading, "Error:", errorMessage);
            // isLoading is a JS boolean literal, so test it directly.
            if (isLoading) {{
                statusMessageEl.classList.remove('hidden');
                dashboardErrorMessageEl.classList.add('hidden');
                return;
            }}
            statusMessageEl.classList.add('hidden');

            if (errorMessage && errorMessage !== 'null') {{
                // textContent (not innerHTML) so the message is never interpreted as markup.
                dashboardErrorMessageEl.textContent = "Error: " + errorMessage;
                dashboardErrorMessageEl.classList.remove('hidden');
                destroyCharts();
                metricsGridEl.innerHTML = `<div class="col-span-full text-center text-slate-500 p-8 info-card">An error occurred. Please check the ticker or model.</div>`;
                predictionDetailsContainerEl.classList.add('hidden');
                predictionResultEl.innerHTML = '<p class="text-sm text-slate-500">No results due to error.</p>';
                techSummaryEl.innerHTML = '<p class="text-sm text-slate-500">No technical summary due to error.</p>';
                return;
            }}
            dashboardErrorMessageEl.classList.add('hidden');

            const historicalData = parseData(historicalDataJson);
            const predictionData = parsePredictions(predictionDataJson);
            if (!historicalData) {{
                metricsGridEl.innerHTML = `<div class="col-span-full text-center text-slate-500 p-8 info-card">Click "Generate Dashboard" in the sidebar to load data.</div>`;
                predictionDetailsContainerEl.classList.add('hidden');
                predictionResultEl.innerHTML = '<p class="text-sm text-slate-500">No data loaded yet.</p>';
                techSummaryEl.innerHTML = '<p class="text-sm text-slate-500">No data for technical summary.</p>';
                destroyCharts();
                console.log("JS: No historical data available to render dashboard.");
                return;
            }}
            updateMetrics(historicalData);
            updateTechSummary(historicalData);
            renderCharts(historicalData, predictionData);
            displayPredictions(historicalData, predictionData);
            console.log("JS: Dashboard loaded successfully.");
        }}

        loadDashboard();
    }});
    </script>
</body>
</html>
"""
# --- Embed HTML Component in Streamlit ---
# The page draws its own error banner inside the iframe, so no st.error() is issued here.
components.html(
    html_code,
    height=1200,
    scrolling=True,
)