Abdullah Zaki commited on
Commit
fd00e59
·
1 Parent(s): 574b1b5

Add plotly t

Browse files
Files changed (1) hide show
  1. app.py +73 -145
app.py CHANGED
@@ -3,122 +3,80 @@ import pandas as pd
3
  import numpy as np
4
  import torch
5
  from chronos import ChronosPipeline
6
- from transformers import AutoTokenizer, AutoModelForCausalLM
7
- from supabase import create_client, Client
8
- import os
9
  import plotly.express as px
10
 
11
  # Initialize Chronos-T5-Large for forecasting
12
- # These models are loaded once at the start of the Gradio app for efficiency.
13
  # The device_map automatically handles CPU/GPU allocation.
 
14
  chronos_pipeline = ChronosPipeline.from_pretrained(
15
  "amazon/chronos-t5-large",
16
  device_map="cuda" if torch.cuda.is_available() else "cpu",
17
  torch_dtype=torch.bfloat16
18
  )
19
 
20
- # Initialize Prophet-Qwen3-4B-SFT for Arabic reports
21
- # These models are also loaded once at the start.
22
- qwen_tokenizer = AutoTokenizer.from_pretrained("radm/prophet-qwen3-4b-sft")
23
- qwen_model = AutoModelForCausalLM.from_pretrained(
24
- "radm/prophet-qwen3-4b-sft",
25
- device_map="cuda" if torch.cuda.is_available() else "cpu",
26
- torch_dtype=torch.bfloat16
27
- )
28
-
29
- def fetch_supabase_data(supabase_url: str, supabase_key: str, table_name: str = "sentiment_data") -> pd.DataFrame:
30
  """
31
- Fetches time series data from Supabase using the provided URL and API key.
32
 
33
  Args:
34
- supabase_url (str): The URL of your Supabase project.
35
- supabase_key (str): Your Supabase API key (anon key).
36
- table_name (str): The name of the table to fetch data from.
37
 
38
  Returns:
39
- pd.DataFrame: A DataFrame containing 'date' and 'sentiment' columns.
40
-
41
- Raises:
42
- Exception: If there's an error connecting to Supabase or no data is found.
43
  """
44
- if not supabase_url or not supabase_key:
45
- raise ValueError("Supabase URL and Key must be provided to fetch data from Supabase.")
46
 
47
  try:
48
- # Create a new Supabase client instance for each call, using the provided URL and key.
49
- # This allows the user to input different keys/URLs without restarting the app.
50
- supabase_client: Client = create_client(supabase_url, supabase_key)
51
- response = supabase_client.table(table_name).select("date, sentiment").order("date", desc=False).execute()
52
-
53
- if response.data:
54
- df = pd.DataFrame(response.data)
55
- # Ensure 'date' column is in datetime format
56
- df['date'] = pd.to_datetime(df['date'])
57
- # Ensure 'sentiment' column is numeric for forecasting
58
- df['sentiment'] = pd.to_numeric(df['sentiment'])
59
- return df
60
- else:
61
- raise ValueError(f"No data found in Supabase table '{table_name}'.")
62
- except Exception as e:
63
- raise Exception(f"Error fetching Supabase data: {str(e)}")
64
-
65
- def forecast_and_report(
66
- data_source: str,
67
- supabase_url: str, # New input for Supabase URL
68
- supabase_key: str, # New input for Supabase Key
69
- csv_file=None,
70
- prediction_length: int = 30,
71
- table_name: str = "sentiment_data"
72
- ):
73
- """
74
- Runs forecasting with Chronos-T5-Large and generates an Arabic report with Qwen3-4B-SFT.
75
 
76
- Args:
77
- data_source (str): Specifies whether to use "Supabase" or "CSV Upload".
78
- supabase_url (str): The Supabase project URL (used if data_source is "Supabase").
79
- supabase_key (str): The Supabase API key (used if data_source is "Supabase").
80
- csv_file: The uploaded CSV file (used if data_source is "CSV Upload").
81
- prediction_length (int): The number of days to forecast.
82
- table_name (str): The name of the Supabase table.
 
 
 
83
 
84
- Returns:
85
- tuple: A tuple containing:
86
- - dict: Forecast results as a dictionary.
87
- - plotly.graph_objects.Figure: A Plotly figure of the forecast.
88
- - str: The generated Arabic report.
89
- - str: An error message if an error occurs.
90
- """
91
- try:
92
- # Load data based on selected source
93
- df = pd.DataFrame() # Initialize df to avoid UnboundLocalError
94
- if data_source == "Supabase":
95
- df = fetch_supabase_data(supabase_url, supabase_key, table_name)
96
- else: # data_source == "CSV Upload"
97
- if csv_file is None:
98
- return {"error": "Please upload a CSV file when 'CSV Upload' is selected."}, None, None, "Error: CSV file not provided."
99
- df = pd.read_csv(csv_file.name) # Access the file path
100
- # Basic validation for required columns in CSV
101
- if "sentiment" not in df.columns or "date" not in df.columns:
102
- return {"error": "CSV must contain 'date' and 'sentiment' columns."}, None, None, "Error: Missing 'date' or 'sentiment' columns in CSV."
103
- df['date'] = pd.to_datetime(df['date'])
104
- df['sentiment'] = pd.to_numeric(df['sentiment'])
105
-
106
- # Ensure there's data to process
107
  if df.empty:
108
- return {"error": "No data available for forecasting or reporting."}, None, None, "Error: No data available."
 
 
 
109
 
110
  # Prepare time series data for Chronos
111
- # Ensure sentiment is float32 for the model
112
  context = torch.tensor(df["sentiment"].values, dtype=torch.float32)
113
 
114
  # Run forecast using Chronos-T5-Large pipeline
115
- forecast = chronos_pipeline.predict(context, prediction_length)
116
- # Calculate quantiles for low, median, and high predictions
117
- low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
118
 
119
- # Format forecast results into a DataFrame
120
- # Generate future dates starting from the day after the last historical date
121
- forecast_dates = pd.date_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=prediction_length, freq="D")
 
 
 
 
 
 
 
 
 
122
  forecast_df = pd.DataFrame({
123
  "date": forecast_dates,
124
  "low": low,
@@ -127,73 +85,43 @@ def forecast_and_report(
127
  })
128
 
129
  # Create forecast plot using Plotly
130
- # Combine historical data for plotting if desired, but here we plot only forecast
131
  fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
132
- fig.update_traces(line=dict(color="blue"), selector=dict(name="median"))
133
  fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
134
  fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
 
135
 
136
- # Generate Arabic report using Prophet-Qwen3-4B-SFT
137
- # Construct the prompt with relevant forecast snippets
138
- prompt = (
139
- "اكتب تقريراً رسمياً بالعربية يلخص توقعات المشاعر للأيام الثلاثين القادمة بناءً على البيانات التالية:\n"
140
- f"- متوسط التوقعات: {median[:5].tolist()} (أول 5 أيام)...\n"
141
- f"- الحد الأدنى (10%): {low[:5].tolist()}...\n"
142
- f"- الحد الأعلى (90%): {high[:5].tolist()}...\n"
143
- "التقرير يجب أن يكون موجزاً (200-300 كلمة)، يشرح الاتجاهات، ويستخدم لغة رسمية."
144
- )
145
- # Tokenize the prompt and move to the model's device (CPU/GPU)
146
- inputs = qwen_tokenizer(prompt, return_tensors="pt").to(qwen_model.device)
147
- # Generate the report text
148
- outputs = qwen_model.generate(
149
- inputs["input_ids"],
150
- max_new_tokens=500, # Max length for the generated report
151
- do_sample=True, # Enable sampling for more diverse text
152
- temperature=0.7, # Control randomness (lower for less random)
153
- top_p=0.9 # Nucleus sampling parameter
154
- )
155
- # Decode the generated tokens back to text, skipping special tokens
156
- report = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
157
 
158
- return forecast_df.to_dict(), fig, report, "Success" # Return success message
159
  except Exception as e:
160
- # Catch any exceptions and return an error message
161
- return {}, None, None, f"An error occurred: {str(e)}"
162
 
163
  # Gradio interface definition
164
  with gr.Blocks() as demo:
165
- gr.Markdown("# Sentiment Forecasting and Arabic Reporting")
 
166
 
167
- # Input components for Supabase credentials and data source selection
168
  with gr.Row():
169
- data_source = gr.Radio(["Supabase", "CSV Upload"], label="Data Source", value="Supabase")
170
- supabase_url = gr.Textbox(label="Supabase URL", placeholder="e.g., https://your-project-ref.supabase.co", interactive=True)
171
- supabase_key = gr.Textbox(label="Supabase Key", placeholder="Your Supabase anon key", type="password", interactive=True)
172
-
173
- csv_file = gr.File(label="Upload CSV (if CSV selected)")
174
- table_name = gr.Textbox(label="Supabase Table Name", value="sentiment_data")
175
- prediction_length = gr.Slider(1, 60, value=30, step=1, label="Prediction Length (days)")
176
-
177
- submit = gr.Button("Run Forecast and Generate Report")
178
-
179
- # Output components for results
180
- output = gr.JSON(label="Forecast Results")
181
- plot = gr.Plot(label="Forecast Plot")
182
- report = gr.Textbox(label="Arabic Report", lines=10, rtl=True, show_copy_button=True) # Added rtl=True for Arabic display
183
- status_message = gr.Textbox(label="Status", interactive=False) # For displaying success/error messages
184
-
185
- # Define the click event handler for the submit button
186
- submit.click(
187
- fn=forecast_and_report,
188
- inputs=[
189
- data_source,
190
- supabase_url,
191
- supabase_key,
192
- csv_file,
193
- prediction_length,
194
- table_name
195
- ],
196
- outputs=[output, plot, report, status_message]
197
  )
198
 
199
  # Launch the Gradio application
 
3
  import numpy as np
4
  import torch
5
  from chronos import ChronosPipeline
 
 
 
6
  import plotly.express as px
7
 
8
  # Initialize Chronos-T5-Large for forecasting
9
+ # This model is loaded once at the start of the Gradio app for efficiency.
10
  # The device_map automatically handles CPU/GPU allocation.
11
+ # torch_dtype=torch.bfloat16 is used for optimized performance if a compatible GPU is available.
12
  chronos_pipeline = ChronosPipeline.from_pretrained(
13
  "amazon/chronos-t5-large",
14
  device_map="cuda" if torch.cuda.is_available() else "cpu",
15
  torch_dtype=torch.bfloat16
16
  )
17
 
18
+ def run_chronos_forecast(
19
+ csv_file: gr.File,
20
+ prediction_length: int = 30
21
+ ) -> tuple[pd.DataFrame, px.line, str]:
 
 
 
 
 
 
22
  """
23
+ Runs time series forecasting using the Chronos-T5-Large model.
24
 
25
  Args:
26
+ csv_file (gr.File): The uploaded CSV file containing historical data.
27
+ Must have 'date' and 'sentiment' columns.
28
+ prediction_length (int): The number of future periods (days) to forecast.
29
 
30
  Returns:
31
+ tuple: A tuple containing:
32
+ - pd.DataFrame: A DataFrame of the forecast results (date, low, median, high).
33
+ - plotly.graph_objects.Figure: A Plotly figure visualizing the forecast.
34
+ - str: A status message (e.g., "Success" or an error message).
35
  """
36
+ if csv_file is None:
37
+ return pd.DataFrame(), None, "Error: Please upload a CSV file."
38
 
39
  try:
40
+ # Read the uploaded CSV file into a pandas DataFrame
41
+ df = pd.read_csv(csv_file.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Validate required columns
44
+ if "date" not in df.columns or "sentiment" not in df.columns:
45
+ return pd.DataFrame(), None, "Error: CSV must contain 'date' and 'sentiment' columns."
46
+
47
+ # Convert 'date' column to datetime objects
48
+ df['date'] = pd.to_datetime(df['date'])
49
+ # Convert 'sentiment' column to numeric, handling potential errors
50
+ df['sentiment'] = pd.to_numeric(df['sentiment'], errors='coerce')
51
+ # Drop rows where sentiment could not be converted (e.g., NaN values)
52
+ df.dropna(subset=['sentiment'], inplace=True)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  if df.empty:
55
+ return pd.DataFrame(), None, "Error: No valid sentiment data found in the CSV."
56
+
57
+ # Sort data by date to ensure correct time series order
58
+ df = df.sort_values(by='date').reset_index(drop=True)
59
 
60
  # Prepare time series data for Chronos
61
+ # Chronos expects a 1D tensor of the time series values
62
  context = torch.tensor(df["sentiment"].values, dtype=torch.float32)
63
 
64
  # Run forecast using Chronos-T5-Large pipeline
65
+ # The predict method returns a tensor of forecasts
66
+ forecast_tensor = chronos_pipeline.predict(context, prediction_length)
 
67
 
68
+ # Calculate quantiles (10%, 50% (median), 90%) for the forecast
69
+ # forecast_tensor[0] selects the first (and usually only) batch of predictions
70
+ low, median, high = np.quantile(forecast_tensor[0].numpy(), [0.1, 0.5, 0.9], axis=0)
71
+
72
+ # Generate future dates for the forecast results
73
+ # Start from the day after the last historical date
74
+ last_historical_date = df["date"].iloc[-1]
75
+ forecast_dates = pd.date_range(start=last_historical_date + pd.Timedelta(days=1),
76
+ periods=prediction_length,
77
+ freq="D")
78
+
79
+ # Create a DataFrame for the forecast results
80
  forecast_df = pd.DataFrame({
81
  "date": forecast_dates,
82
  "low": low,
 
85
  })
86
 
87
  # Create forecast plot using Plotly
 
88
  fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
89
+ fig.update_traces(line=dict(color="blue", width=3), selector=dict(name="median"))
90
  fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
91
  fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
92
+ fig.update_layout(hovermode="x unified", title_x=0.5) # Improve hover interactivity and center title
93
 
94
+ return forecast_df, fig, "Forecast generated successfully!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
96
  except Exception as e:
97
+ # Catch any exceptions and return an error message to the user
98
+ return pd.DataFrame(), None, f"An error occurred: {str(e)}"
99
 
100
  # Gradio interface definition
101
  with gr.Blocks() as demo:
102
+ gr.Markdown("# Chronos Time Series Forecasting")
103
+ gr.Markdown("Upload a CSV file containing historical data with 'date' and 'sentiment' columns to get a sentiment forecast.")
104
 
 
105
  with gr.Row():
106
+ csv_input = gr.File(label="Upload Historical Data (CSV)")
107
+ prediction_length_slider = gr.Slider(
108
+ 1, 60, value=30, step=1, label="Prediction Length (days)"
109
+ )
110
+
111
+ run_button = gr.Button("Generate Forecast")
112
+
113
+ with gr.Tab("Forecast Plot"):
114
+ forecast_plot_output = gr.Plot(label="Sentiment Forecast Plot")
115
+ with gr.Tab("Forecast Data"):
116
+ forecast_json_output = gr.DataFrame(label="Raw Forecast Data") # Changed to DataFrame for better readability
117
+
118
+ status_message_output = gr.Textbox(label="Status", interactive=False)
119
+
120
+ # Define the click event handler for the run button
121
+ run_button.click(
122
+ fn=run_chronos_forecast,
123
+ inputs=[csv_input, prediction_length_slider],
124
+ outputs=[forecast_json_output, forecast_plot_output, status_message_output]
 
 
 
 
 
 
 
 
 
125
  )
126
 
127
  # Launch the Gradio application