File size: 5,504 Bytes
387baae fd00e59 574b1b5 fd00e59 387baae fd00e59 574b1b5 fd00e59 574b1b5 fd00e59 574b1b5 fd00e59 574b1b5 fd00e59 574b1b5 387baae fd00e59 574b1b5 fd00e59 574b1b5 fd00e59 387baae 574b1b5 fd00e59 387baae 574b1b5 fd00e59 387baae fd00e59 387baae 574b1b5 fd00e59 387baae fd00e59 387baae fd00e59 387baae fd00e59 387baae 574b1b5 387baae fd00e59 574b1b5 fd00e59 387baae 574b1b5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import gradio as gr
import pandas as pd
import numpy as np
import torch
from chronos import ChronosPipeline
import plotly.express as px
# Load Chronos-T5-Large a single time when the module is imported, so every
# forecast request served by the Gradio app reuses the same pipeline object.
# The target device is chosen explicitly: CUDA when torch reports it as
# available, otherwise the CPU.
_chronos_device = "cuda" if torch.cuda.is_available() else "cpu"

# The weights are requested in bfloat16 on either device.
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map=_chronos_device,
    torch_dtype=torch.bfloat16,
)
def run_chronos_forecast(
    csv_file: gr.File,
    prediction_length: int = 30
) -> tuple[pd.DataFrame, px.line, str]:
    """
    Forecast future sentiment values with the module-level Chronos pipeline.

    The uploaded CSV is cleaned (rows whose 'date' or 'sentiment' cannot be
    parsed are dropped), sorted by date, and the sentiment column is passed to
    ``chronos_pipeline.predict``. The 10%, 50% and 90% quantiles of the
    returned forecast are reported for consecutive daily dates starting the
    day after the last historical date.

    Args:
        csv_file: The upload handed over by the Gradio File component. Either
            an object exposing the path as ``.name`` or a plain path string is
            accepted. The CSV must have 'date' and 'sentiment' columns.
        prediction_length: The number of future periods (days) to forecast.
            Converted with ``int()`` (a slider may deliver a float) and must
            be at least 1.

    Returns:
        tuple: A tuple containing:
            - pd.DataFrame: The forecast (date, low, median, high); an empty
              DataFrame on failure.
            - The Plotly figure produced by ``px.line`` visualizing the
              forecast, or None on failure. (The return annotation names
              ``px.line`` itself; the value is the figure that call returns.)
            - str: A status message (success text or an error description).

    No exception escapes: any failure is reported through the status message.
    """
    if csv_file is None:
        return pd.DataFrame(), None, "Error: Please upload a CSV file."

    try:
        # A slider value may arrive as a float; date_range needs an integer count.
        horizon = int(prediction_length)
        if horizon < 1:
            return pd.DataFrame(), None, "Error: Prediction length must be at least 1."

        # Accept both a file-like upload (path in `.name`) and a bare path string.
        csv_path = getattr(csv_file, "name", csv_file)
        df = pd.read_csv(csv_path)

        # Validate required columns
        if "date" not in df.columns or "sentiment" not in df.columns:
            return pd.DataFrame(), None, "Error: CSV must contain 'date' and 'sentiment' columns."

        # Coerce both columns so a single malformed cell becomes NaT/NaN
        # instead of aborting the whole request.
        df["date"] = pd.to_datetime(df["date"], errors="coerce")
        df["sentiment"] = pd.to_numeric(df["sentiment"], errors="coerce")

        # Drop rows lacking either a usable date or a usable value; a NaT date
        # would otherwise sort last and break the forecast date range below.
        df = df.dropna(subset=["date", "sentiment"])
        if df.empty:
            return pd.DataFrame(), None, "Error: No valid sentiment data found in the CSV."

        # Sort data by date to ensure correct time series order
        df = df.sort_values(by="date").reset_index(drop=True)

        # The history is handed to Chronos as a 1D float32 tensor.
        context = torch.tensor(df["sentiment"].to_numpy(), dtype=torch.float32)
        forecast_tensor = chronos_pipeline.predict(context, horizon)

        # Take the first series' forecast and move it to CPU float32 so the
        # NumPy conversion works whatever device/dtype the pipeline returned.
        # Assumes that entry is shaped (num_samples, horizon), so axis 0 is the
        # sample axis -- TODO confirm against the installed chronos version.
        samples = forecast_tensor[0].detach().cpu().float().numpy()
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        # Forecast dates start the day after the last historical date.
        last_historical_date = df["date"].iloc[-1]
        forecast_dates = pd.date_range(
            start=last_historical_date + pd.Timedelta(days=1),
            periods=horizon,
            freq="D",
        )

        forecast_df = pd.DataFrame({
            "date": forecast_dates,
            "low": low,
            "median": median,
            "high": high,
        })

        # Median as a solid line, the 10%/90% bounds as dashed lines.
        fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
        fig.update_traces(line=dict(color="blue", width=3), selector=dict(name="median"))
        fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
        fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
        fig.update_layout(hovermode="x unified", title_x=0.5)  # unified hover, centered title

        return forecast_df, fig, "Forecast generated successfully!"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return pd.DataFrame(), None, f"An error occurred: {str(e)}"
# Gradio UI: inputs on top, results in two tabs, status line at the bottom.
with gr.Blocks() as demo:
    gr.Markdown("# Chronos Time Series Forecasting")
    gr.Markdown(
        "Upload a CSV file containing historical data with 'date' and 'sentiment' columns to get a sentiment forecast."
    )

    # Input controls: the history file and the forecast horizon.
    with gr.Row():
        history_upload = gr.File(label="Upload Historical Data (CSV)")
        horizon_slider = gr.Slider(
            minimum=1,
            maximum=60,
            value=30,
            step=1,
            label="Prediction Length (days)",
        )

    forecast_button = gr.Button("Generate Forecast")

    # Result views: the chart and the underlying numbers.
    with gr.Tab("Forecast Plot"):
        plot_view = gr.Plot(label="Sentiment Forecast Plot")
    with gr.Tab("Forecast Data"):
        table_view = gr.DataFrame(label="Raw Forecast Data")

    status_box = gr.Textbox(label="Status", interactive=False)

    # Output order matches the handler's return order: table, plot, status.
    forecast_button.click(
        fn=run_chronos_forecast,
        inputs=[history_upload, horizon_slider],
        outputs=[table_view, plot_view, status_box],
    )

# Start the Gradio application
demo.launch()
|