Abdullah Zaki commited on
Commit
574b1b5
·
1 Parent(s): 7aa3e0f

Add plotly t

Browse files
Files changed (1) hide show
  1. app.py +116 -40
app.py CHANGED
@@ -8,14 +8,9 @@ from supabase import create_client, Client
8
  import os
9
  import plotly.express as px
10
 
11
- # Initialize Supabase client with API key from environment variables
12
- SUPABASE_URL = os.getenv("SUPABASE_URL")
13
- SUPABASE_KEY = os.getenv("SUPABASE_KEY")
14
- if not SUPABASE_URL or not SUPABASE_KEY:
15
- raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set as environment variables.")
16
- supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
17
-
18
  # Initialize Chronos-T5-Large for forecasting
 
 
19
  chronos_pipeline = ChronosPipeline.from_pretrained(
20
  "amazon/chronos-t5-large",
21
  device_map="cuda" if torch.cuda.is_available() else "cpu",
@@ -23,6 +18,7 @@ chronos_pipeline = ChronosPipeline.from_pretrained(
23
  )
24
 
25
  # Initialize Prophet-Qwen3-4B-SFT for Arabic reports
 
26
  qwen_tokenizer = AutoTokenizer.from_pretrained("radm/prophet-qwen3-4b-sft")
27
  qwen_model = AutoModelForCausalLM.from_pretrained(
28
  "radm/prophet-qwen3-4b-sft",
@@ -30,41 +26,98 @@ qwen_model = AutoModelForCausalLM.from_pretrained(
30
  torch_dtype=torch.bfloat16
31
  )
32
 
33
- def fetch_supabase_data(table_name: str = "sentiment_data") -> pd.DataFrame:
34
- """Fetch time series data from Supabase using the provided API key."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  try:
36
- response = supabase.table(table_name).select("date, sentiment").order("date", desc=False).execute()
 
 
 
 
37
  if response.data:
38
  df = pd.DataFrame(response.data)
 
39
  df['date'] = pd.to_datetime(df['date'])
 
 
40
  return df
41
  else:
42
- raise ValueError("No data found in Supabase table.")
43
  except Exception as e:
44
  raise Exception(f"Error fetching Supabase data: {str(e)}")
45
 
46
- def forecast_and_report(data_source: str, csv_file=None, prediction_length: int = 30, table_name: str = "sentiment_data"):
47
- """Run forecasting with Chronos-T5-Large and generate Arabic report with Qwen3-4B-SFT."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  try:
49
- # Load data
 
50
  if data_source == "Supabase":
51
- df = fetch_supabase_data(table_name)
52
- else:
53
- if not csv_file:
54
- return {"error": "Please upload a CSV file."}, None, None
55
- df = pd.read_csv(csv_file)
 
56
  if "sentiment" not in df.columns or "date" not in df.columns:
57
- return {"error": "CSV must contain 'date' and 'sentiment' columns."}, None, None
58
  df['date'] = pd.to_datetime(df['date'])
 
 
 
 
 
59
 
60
- # Prepare time series
 
61
  context = torch.tensor(df["sentiment"].values, dtype=torch.float32)
62
 
63
- # Run forecast
64
  forecast = chronos_pipeline.predict(context, prediction_length)
 
65
  low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
66
 
67
- # Format forecast results
 
68
  forecast_dates = pd.date_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=prediction_length, freq="D")
69
  forecast_df = pd.DataFrame({
70
  "date": forecast_dates,
@@ -73,14 +126,15 @@ def forecast_and_report(data_source: str, csv_file=None, prediction_length: int
73
  "high": high
74
  })
75
 
76
- # Create forecast plot
77
- plot_df = forecast_df.copy()
78
- fig = px.line(plot_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
79
  fig.update_traces(line=dict(color="blue"), selector=dict(name="median"))
80
  fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
81
  fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
82
 
83
- # Generate Arabic report
 
84
  prompt = (
85
  "اكتب تقريراً رسمياً بالعربية يلخص توقعات المشاعر للأيام الثلاثين القادمة بناءً على البيانات التالية:\n"
86
  f"- متوسط التوقعات: {median[:5].tolist()} (أول 5 أيام)...\n"
@@ -88,37 +142,59 @@ def forecast_and_report(data_source: str, csv_file=None, prediction_length: int
88
  f"- الحد الأعلى (90%): {high[:5].tolist()}...\n"
89
  "التقرير يجب أن يكون موجزاً (200-300 كلمة)، يشرح الاتجاهات، ويستخدم لغة رسمية."
90
  )
 
91
  inputs = qwen_tokenizer(prompt, return_tensors="pt").to(qwen_model.device)
 
92
  outputs = qwen_model.generate(
93
  inputs["input_ids"],
94
- max_new_tokens=500,
95
- do_sample=True,
96
- temperature=0.7,
97
- top_p=0.9
98
  )
 
99
  report = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
100
 
101
- return forecast_df.to_dict(), fig, report
102
-
103
  except Exception as e:
104
- return {"error": f"An error occurred: {str(e)}"}, None, None
 
105
 
106
- # Gradio interface
107
  with gr.Blocks() as demo:
108
  gr.Markdown("# Sentiment Forecasting and Arabic Reporting")
109
- data_source = gr.Radio(["Supabase", "CSV Upload"], label="Data Source", value="Supabase")
 
 
 
 
 
 
110
  csv_file = gr.File(label="Upload CSV (if CSV selected)")
111
  table_name = gr.Textbox(label="Supabase Table Name", value="sentiment_data")
112
  prediction_length = gr.Slider(1, 60, value=30, step=1, label="Prediction Length (days)")
 
113
  submit = gr.Button("Run Forecast and Generate Report")
 
 
114
  output = gr.JSON(label="Forecast Results")
115
  plot = gr.Plot(label="Forecast Plot")
116
- report = gr.Textbox(label="Arabic Report", lines=10)
 
117
 
 
118
  submit.click(
119
  fn=forecast_and_report,
120
- inputs=[data_source, csv_file, prediction_length, table_name],
121
- outputs=[output, plot, report]
 
 
 
 
 
 
 
122
  )
123
 
124
- demo.launch()
 
 
8
  import os
9
  import plotly.express as px
10
 
 
 
 
 
 
 
 
11
  # Initialize Chronos-T5-Large for forecasting
12
+ # These models are loaded once at the start of the Gradio app for efficiency.
13
+ # The device_map automatically handles CPU/GPU allocation.
14
  chronos_pipeline = ChronosPipeline.from_pretrained(
15
  "amazon/chronos-t5-large",
16
  device_map="cuda" if torch.cuda.is_available() else "cpu",
 
18
  )
19
 
20
  # Initialize Prophet-Qwen3-4B-SFT for Arabic reports
21
+ # These models are also loaded once at the start.
22
  qwen_tokenizer = AutoTokenizer.from_pretrained("radm/prophet-qwen3-4b-sft")
23
  qwen_model = AutoModelForCausalLM.from_pretrained(
24
  "radm/prophet-qwen3-4b-sft",
 
26
  torch_dtype=torch.bfloat16
27
  )
28
 
29
+ def fetch_supabase_data(supabase_url: str, supabase_key: str, table_name: str = "sentiment_data") -> pd.DataFrame:
30
+ """
31
+ Fetches time series data from Supabase using the provided URL and API key.
32
+
33
+ Args:
34
+ supabase_url (str): The URL of your Supabase project.
35
+ supabase_key (str): Your Supabase API key (anon key).
36
+ table_name (str): The name of the table to fetch data from.
37
+
38
+ Returns:
39
+ pd.DataFrame: A DataFrame containing 'date' and 'sentiment' columns.
40
+
41
+ Raises:
42
+ Exception: If there's an error connecting to Supabase or no data is found.
43
+ """
44
+ if not supabase_url or not supabase_key:
45
+ raise ValueError("Supabase URL and Key must be provided to fetch data from Supabase.")
46
+
47
  try:
48
+ # Create a new Supabase client instance for each call, using the provided URL and key.
49
+ # This allows the user to input different keys/URLs without restarting the app.
50
+ supabase_client: Client = create_client(supabase_url, supabase_key)
51
+ response = supabase_client.table(table_name).select("date, sentiment").order("date", desc=False).execute()
52
+
53
  if response.data:
54
  df = pd.DataFrame(response.data)
55
+ # Ensure 'date' column is in datetime format
56
  df['date'] = pd.to_datetime(df['date'])
57
+ # Ensure 'sentiment' column is numeric for forecasting
58
+ df['sentiment'] = pd.to_numeric(df['sentiment'])
59
  return df
60
  else:
61
+ raise ValueError(f"No data found in Supabase table '{table_name}'.")
62
  except Exception as e:
63
  raise Exception(f"Error fetching Supabase data: {str(e)}")
64
 
65
+ def forecast_and_report(
66
+ data_source: str,
67
+ supabase_url: str, # New input for Supabase URL
68
+ supabase_key: str, # New input for Supabase Key
69
+ csv_file=None,
70
+ prediction_length: int = 30,
71
+ table_name: str = "sentiment_data"
72
+ ):
73
+ """
74
+ Runs forecasting with Chronos-T5-Large and generates an Arabic report with Qwen3-4B-SFT.
75
+
76
+ Args:
77
+ data_source (str): Specifies whether to use "Supabase" or "CSV Upload".
78
+ supabase_url (str): The Supabase project URL (used if data_source is "Supabase").
79
+ supabase_key (str): The Supabase API key (used if data_source is "Supabase").
80
+ csv_file: The uploaded CSV file (used if data_source is "CSV Upload").
81
+ prediction_length (int): The number of days to forecast.
82
+ table_name (str): The name of the Supabase table.
83
+
84
+ Returns:
85
+ tuple: A tuple containing:
86
+ - dict: Forecast results as a dictionary.
87
+ - plotly.graph_objects.Figure: A Plotly figure of the forecast.
88
+ - str: The generated Arabic report.
89
+ - str: An error message if an error occurs.
90
+ """
91
  try:
92
+ # Load data based on selected source
93
+ df = pd.DataFrame() # Initialize df to avoid UnboundLocalError
94
  if data_source == "Supabase":
95
+ df = fetch_supabase_data(supabase_url, supabase_key, table_name)
96
+ else: # data_source == "CSV Upload"
97
+ if csv_file is None:
98
+ return {"error": "Please upload a CSV file when 'CSV Upload' is selected."}, None, None, "Error: CSV file not provided."
99
+ df = pd.read_csv(csv_file.name) # Access the file path
100
+ # Basic validation for required columns in CSV
101
  if "sentiment" not in df.columns or "date" not in df.columns:
102
+ return {"error": "CSV must contain 'date' and 'sentiment' columns."}, None, None, "Error: Missing 'date' or 'sentiment' columns in CSV."
103
  df['date'] = pd.to_datetime(df['date'])
104
+ df['sentiment'] = pd.to_numeric(df['sentiment'])
105
+
106
+ # Ensure there's data to process
107
+ if df.empty:
108
+ return {"error": "No data available for forecasting or reporting."}, None, None, "Error: No data available."
109
 
110
+ # Prepare time series data for Chronos
111
+ # Ensure sentiment is float32 for the model
112
  context = torch.tensor(df["sentiment"].values, dtype=torch.float32)
113
 
114
+ # Run forecast using Chronos-T5-Large pipeline
115
  forecast = chronos_pipeline.predict(context, prediction_length)
116
+ # Calculate quantiles for low, median, and high predictions
117
  low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
118
 
119
+ # Format forecast results into a DataFrame
120
+ # Generate future dates starting from the day after the last historical date
121
  forecast_dates = pd.date_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=prediction_length, freq="D")
122
  forecast_df = pd.DataFrame({
123
  "date": forecast_dates,
 
126
  "high": high
127
  })
128
 
129
+ # Create forecast plot using Plotly
130
+ # Combine historical data for plotting if desired, but here we plot only forecast
131
+ fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
132
  fig.update_traces(line=dict(color="blue"), selector=dict(name="median"))
133
  fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
134
  fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
135
 
136
+ # Generate Arabic report using Prophet-Qwen3-4B-SFT
137
+ # Construct the prompt with relevant forecast snippets
138
  prompt = (
139
  "اكتب تقريراً رسمياً بالعربية يلخص توقعات المشاعر للأيام الثلاثين القادمة بناءً على البيانات التالية:\n"
140
  f"- متوسط التوقعات: {median[:5].tolist()} (أول 5 أيام)...\n"
 
142
  f"- الحد الأعلى (90%): {high[:5].tolist()}...\n"
143
  "التقرير يجب أن يكون موجزاً (200-300 كلمة)، يشرح الاتجاهات، ويستخدم لغة رسمية."
144
  )
145
+ # Tokenize the prompt and move to the model's device (CPU/GPU)
146
  inputs = qwen_tokenizer(prompt, return_tensors="pt").to(qwen_model.device)
147
+ # Generate the report text
148
  outputs = qwen_model.generate(
149
  inputs["input_ids"],
150
+ max_new_tokens=500, # Max length for the generated report
151
+ do_sample=True, # Enable sampling for more diverse text
152
+ temperature=0.7, # Control randomness (lower for less random)
153
+ top_p=0.9 # Nucleus sampling parameter
154
  )
155
+ # Decode the generated tokens back to text, skipping special tokens
156
  report = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
157
 
158
+ return forecast_df.to_dict(), fig, report, "Success" # Return success message
 
159
  except Exception as e:
160
+ # Catch any exceptions and return an error message
161
+ return {}, None, None, f"An error occurred: {str(e)}"
162
 
163
+ # Gradio interface definition
164
  with gr.Blocks() as demo:
165
  gr.Markdown("# Sentiment Forecasting and Arabic Reporting")
166
+
167
+ # Input components for Supabase credentials and data source selection
168
+ with gr.Row():
169
+ data_source = gr.Radio(["Supabase", "CSV Upload"], label="Data Source", value="Supabase")
170
+ supabase_url = gr.Textbox(label="Supabase URL", placeholder="e.g., https://your-project-ref.supabase.co", interactive=True)
171
+ supabase_key = gr.Textbox(label="Supabase Key", placeholder="Your Supabase anon key", type="password", interactive=True)
172
+
173
  csv_file = gr.File(label="Upload CSV (if CSV selected)")
174
  table_name = gr.Textbox(label="Supabase Table Name", value="sentiment_data")
175
  prediction_length = gr.Slider(1, 60, value=30, step=1, label="Prediction Length (days)")
176
+
177
  submit = gr.Button("Run Forecast and Generate Report")
178
+
179
+ # Output components for results
180
  output = gr.JSON(label="Forecast Results")
181
  plot = gr.Plot(label="Forecast Plot")
182
+ report = gr.Textbox(label="Arabic Report", lines=10, rtl=True, show_copy_button=True) # Added rtl=True for Arabic display
183
+ status_message = gr.Textbox(label="Status", interactive=False) # For displaying success/error messages
184
 
185
+ # Define the click event handler for the submit button
186
  submit.click(
187
  fn=forecast_and_report,
188
+ inputs=[
189
+ data_source,
190
+ supabase_url,
191
+ supabase_key,
192
+ csv_file,
193
+ prediction_length,
194
+ table_name
195
+ ],
196
+ outputs=[output, plot, report, status_message]
197
  )
198
 
199
+ # Launch the Gradio application
200
+ demo.launch()