"""Gradio app that forecasts a daily sentiment series with Chronos-T5-Large."""

from typing import Any, Optional

import gradio as gr
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
from chronos import ChronosPipeline

# Load the Chronos-T5-Large pipeline once at import time so every request
# reuses the same weights. The model is placed on CUDA when a GPU is visible
# to torch and on the CPU otherwise; weights are requested as bfloat16.
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)


def _uploaded_file_path(upload: Any) -> str:
    """Return the filesystem path behind a value produced by ``gr.File``.

    Depending on the Gradio version/configuration the component hands the
    callback either a plain path string or a tempfile-like object exposing
    the path as ``.name``; both are accepted here.
    """
    if isinstance(upload, str):
        return upload
    return upload.name


def run_chronos_forecast(
    csv_file: Any,
    prediction_length: int = 30,
) -> tuple[pd.DataFrame, Optional[go.Figure], str]:
    """Forecast future sentiment values from an uploaded CSV.

    Args:
        csv_file: The value delivered by the ``gr.File`` input: ``None`` when
            nothing was uploaded, otherwise a path string or an object with a
            ``.name`` path attribute. The CSV must contain ``date`` and
            ``sentiment`` columns.
        prediction_length: Number of future daily periods to forecast. The
            value is coerced to ``int`` because sliders may deliver floats.

    Returns:
        A ``(forecast_df, figure, status)`` tuple. On success ``forecast_df``
        has the columns ``date``, ``low``, ``median`` and ``high`` (10%, 50%
        and 90% quantiles of the sampled forecasts) and ``figure`` plots
        them. On any failure ``forecast_df`` is empty, ``figure`` is ``None``
        and ``status`` carries the error message; no exception propagates.
    """
    if csv_file is None:
        return pd.DataFrame(), None, "Error: Please upload a CSV file."

    try:
        horizon = int(prediction_length)
        if horizon < 1:
            return pd.DataFrame(), None, "Error: Prediction length must be at least 1."

        df = pd.read_csv(_uploaded_file_path(csv_file))

        if "date" not in df.columns or "sentiment" not in df.columns:
            return pd.DataFrame(), None, "Error: CSV must contain 'date' and 'sentiment' columns."

        df["date"] = pd.to_datetime(df["date"])
        # Non-numeric sentiment values become NaN so they can be dropped below.
        df["sentiment"] = pd.to_numeric(df["sentiment"], errors="coerce")

        # Drop rows lacking either field. A missing date (NaT) would sort to
        # the end and turn the forecast start date into NaT, which makes
        # pd.date_range fail with an unhelpful message.
        df = df.dropna(subset=["date", "sentiment"])

        if df.empty:
            return pd.DataFrame(), None, "Error: No valid sentiment data found in the CSV."

        # The model consumes values in chronological order.
        df = df.sort_values(by="date").reset_index(drop=True)

        # Chronos takes the history as a 1D tensor of values.
        context = torch.tensor(df["sentiment"].to_numpy(), dtype=torch.float32)

        forecast_tensor = chronos_pipeline.predict(context, horizon)

        # Chronos documents the result as [num_series, num_samples,
        # prediction_length]; index 0 picks the single input series and the
        # quantiles are taken across its samples.
        low, median, high = np.quantile(
            forecast_tensor[0].numpy(), [0.1, 0.5, 0.9], axis=0
        )

        # Forecast dates start the day after the last historical observation.
        last_historical_date = df["date"].iloc[-1]
        forecast_dates = pd.date_range(
            start=last_historical_date + pd.Timedelta(days=1),
            periods=horizon,
            freq="D",
        )

        forecast_df = pd.DataFrame(
            {
                "date": forecast_dates,
                "low": low,
                "median": median,
                "high": high,
            }
        )

        fig = px.line(
            forecast_df,
            x="date",
            y=["median", "low", "high"],
            title="Sentiment Forecast",
        )
        fig.update_traces(line=dict(color="blue", width=3), selector=dict(name="median"))
        fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
        fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
        # Unified hover shows all three series for a date; title is centered.
        fig.update_layout(hovermode="x unified", title_x=0.5)

        return forecast_df, fig, "Forecast generated successfully!"

    except Exception as e:
        # UI boundary: surface any failure as a status message instead of
        # letting the exception escape into Gradio.
        return pd.DataFrame(), None, f"An error occurred: {str(e)}"


# Gradio interface definition
with gr.Blocks() as demo:
    gr.Markdown("# Chronos Time Series Forecasting")
    gr.Markdown("Upload a CSV file containing historical data with 'date' and 'sentiment' columns to get a sentiment forecast.")

    with gr.Row():
        csv_input = gr.File(label="Upload Historical Data (CSV)")
        prediction_length_slider = gr.Slider(
            1, 60, value=30, step=1, label="Prediction Length (days)"
        )

    run_button = gr.Button("Generate Forecast")

    with gr.Tab("Forecast Plot"):
        forecast_plot_output = gr.Plot(label="Sentiment Forecast Plot")
    with gr.Tab("Forecast Data"):
        # Rendered as a DataFrame table (the variable name still says "json").
        forecast_json_output = gr.DataFrame(label="Raw Forecast Data")

    status_message_output = gr.Textbox(label="Status", interactive=False)

    # Wire the button to the forecast function; the output order matches the
    # (DataFrame, figure, status) tuple the function returns.
    run_button.click(
        fn=run_chronos_forecast,
        inputs=[csv_input, prediction_length_slider],
        outputs=[forecast_json_output, forecast_plot_output, status_message_output],
    )

# Launch the Gradio application
demo.launch()