File size: 5,504 Bytes
387baae
 
 
 
 
 
 
 
fd00e59
574b1b5
fd00e59
387baae
 
 
 
 
 
fd00e59
 
 
 
574b1b5
fd00e59
574b1b5
 
fd00e59
 
 
574b1b5
 
fd00e59
 
 
 
574b1b5
fd00e59
 
574b1b5
387baae
fd00e59
 
574b1b5
fd00e59
 
 
 
 
 
 
 
 
 
574b1b5
 
fd00e59
 
 
 
387baae
574b1b5
fd00e59
387baae
 
574b1b5
fd00e59
 
387baae
fd00e59
 
 
 
 
 
 
 
 
 
 
 
387baae
 
 
 
 
 
 
574b1b5
 
fd00e59
387baae
 
fd00e59
387baae
fd00e59
387baae
 
fd00e59
 
387baae
574b1b5
387baae
fd00e59
 
574b1b5
 
fd00e59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387baae
 
574b1b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from chronos import ChronosPipeline

# Initialize Chronos-T5-Large for forecasting.
# The model is loaded once at import time so every Gradio request reuses it
# instead of reloading weights per call.
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
_CHRONOS_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 halves the weight memory on the GPU. On many CPUs bfloat16 math is
# slow or only partially supported, so full float32 precision is used there.
_CHRONOS_DTYPE = torch.bfloat16 if _CHRONOS_DEVICE == "cuda" else torch.float32
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map=_CHRONOS_DEVICE,
    torch_dtype=_CHRONOS_DTYPE,
)

def run_chronos_forecast(
    csv_file: gr.File,
    prediction_length: int = 30
) -> tuple[pd.DataFrame, px.line, str]:
    """
    Run time series forecasting on uploaded sentiment data with Chronos-T5-Large.

    Args:
        csv_file (gr.File): The uploaded CSV file containing historical data.
            Must have 'date' and 'sentiment' columns. Accepted either as a
            filepath (str / os.PathLike, which recent Gradio versions pass by
            default) or as a tempfile-like object exposing the path via
            ``.name`` (older Gradio versions).
        prediction_length (int): The number of future periods (days) to
            forecast. Non-integer values (e.g. a float from a slider) are
            truncated to ``int``; values below 1 are rejected.

    Returns:
        tuple: A tuple containing:
            - pd.DataFrame: The forecast results (date, low, median, high),
              where low/median/high are the 10th/50th/90th percentiles of the
              sampled forecast paths. Empty on error.
            - plotly.graph_objects.Figure | None: A Plotly figure visualizing
              the forecast, or None on error.
            - str: A status message ("Forecast generated successfully!" or an
              error message).

    Failures are never raised to the caller; they are reported through the
    status message together with an empty DataFrame and no figure.
    """
    if csv_file is None:
        return pd.DataFrame(), None, "Error: Please upload a CSV file."

    try:
        # Recent Gradio versions hand over a filepath string; older ones pass a
        # tempfile wrapper whose .name holds the path. Support both.
        if isinstance(csv_file, (str, os.PathLike)):
            csv_path = csv_file
        else:
            csv_path = csv_file.name

        # Slider values may arrive as floats; Chronos and date_range need an int.
        horizon = int(prediction_length)
        if horizon < 1:
            return pd.DataFrame(), None, "Error: Prediction length must be at least 1."

        df = pd.read_csv(csv_path)

        # Validate required columns
        if "date" not in df.columns or "sentiment" not in df.columns:
            return pd.DataFrame(), None, "Error: CSV must contain 'date' and 'sentiment' columns."

        # Malformed date strings still raise (reported by the handler below).
        # Empty date cells silently become NaT, though. NaT would sort last and
        # become the "last historical date", so drop it together with any
        # sentiment value that could not be converted to a number.
        df["date"] = pd.to_datetime(df["date"])
        df["sentiment"] = pd.to_numeric(df["sentiment"], errors="coerce")
        df = df.dropna(subset=["date", "sentiment"])

        if df.empty:
            return pd.DataFrame(), None, "Error: No valid sentiment data found in the CSV."

        # Chronos consumes values positionally, so they must be in time order.
        df = df.sort_values(by="date").reset_index(drop=True)

        # Chronos takes a 1D tensor of the historical values as context.
        context = torch.tensor(df["sentiment"].to_numpy(), dtype=torch.float32)
        forecast_tensor = chronos_pipeline.predict(context, horizon)

        # Chronos documents the result as (num_series, num_samples,
        # prediction_length). [0] selects our single series, and the quantiles
        # are taken across the sample axis. .float().cpu() makes .numpy() safe
        # whatever dtype/device the pipeline returns.
        samples = forecast_tensor[0].float().cpu().numpy()
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        # Forecast dates are assumed daily, starting the day after the last
        # historical date.
        last_historical_date = df["date"].iloc[-1]
        forecast_dates = pd.date_range(
            start=last_historical_date + pd.Timedelta(days=1),
            periods=horizon,
            freq="D",
        )

        forecast_df = pd.DataFrame({
            "date": forecast_dates,
            "low": low,
            "median": median,
            "high": high,
        })

        fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
        fig.update_traces(line=dict(color="blue", width=3), selector=dict(name="median"))
        fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
        fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
        # Unified hover shows all three bands at once; title_x centers the title.
        fig.update_layout(hovermode="x unified", title_x=0.5)

        return forecast_df, fig, "Forecast generated successfully!"

    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        return pd.DataFrame(), None, f"An error occurred: {e}"

# Gradio UI: a CSV upload and a forecast-horizon slider feed the forecaster.
# The results appear as a plot, a data table and a status line.
with gr.Blocks() as demo:
    gr.Markdown("# Chronos Time Series Forecasting")
    gr.Markdown("Upload a CSV file containing historical data with 'date' and 'sentiment' columns to get a sentiment forecast.")

    with gr.Row():
        uploaded_csv = gr.File(label="Upload Historical Data (CSV)")
        horizon_slider = gr.Slider(
            minimum=1,
            maximum=60,
            value=30,
            step=1,
            label="Prediction Length (days)",
        )

    forecast_button = gr.Button("Generate Forecast")

    with gr.Tab("Forecast Plot"):
        plot_view = gr.Plot(label="Sentiment Forecast Plot")
    with gr.Tab("Forecast Data"):
        # Tabular view of the date/low/median/high forecast rows.
        table_view = gr.DataFrame(label="Raw Forecast Data")

    status_box = gr.Textbox(label="Status", interactive=False)

    # The output component order must mirror run_chronos_forecast's return
    # tuple: (forecast DataFrame, Plotly figure, status message).
    forecast_button.click(
        run_chronos_forecast,
        [uploaded_csv, horizon_slider],
        [table_view, plot_view, status_box],
    )

# Start the Gradio web server.
demo.launch()