xiezhe22 commited on
Commit
5eabbef
Β·
1 Parent(s): a6775d1

Add Default time series

Browse files
Files changed (1) hide show
  1. app.py +58 -35
app.py CHANGED
@@ -31,6 +31,14 @@ model.eval()
31
 
32
  # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
33
 
 
 
 
 
 
 
 
 
34
  def process_csv_file(csv_file):
35
  """Process CSV file and return DataFrame with validation"""
36
  if csv_file is None:
@@ -81,25 +89,26 @@ def process_csv_file(csv_file):
81
  except Exception as e:
82
  return None, f"Error processing file: {str(e)}"
83
 
 
 
 
 
 
 
 
84
  def preview_csv(csv_file):
85
  """Preview uploaded CSV file immediately"""
86
- if csv_file is None:
87
- return gr.LinePlot(), "Please upload a CSV file first", gr.Dropdown()
88
-
89
- df, message = process_csv_file(csv_file)
90
 
91
  if df is None:
92
- return gr.LinePlot(), message, gr.Dropdown()
93
-
94
- # Add index as x-axis
95
- df_with_index = df.copy()
96
- # df_with_index["_chatts_internal_index"] = np.
97
 
98
  # Create dropdown choices
99
  column_choices = list(df.columns)
100
 
101
  # Create plot with first column as default
102
  first_column = column_choices[0]
 
103
  df_with_index["_internal_idx"] = np.arange(len(df[first_column].values))
104
  plot = gr.LinePlot(
105
  df_with_index,
@@ -115,16 +124,16 @@ def preview_csv(csv_file):
115
  label="Select Time Series"
116
  )
117
 
118
- print("Successfully generated preview_csv!")
119
 
120
  return plot, message, dropdown
121
 
122
  def update_plot(csv_file, selected_column):
123
  """Update plot based on selected column"""
124
- if csv_file is None or selected_column is None:
125
- return gr.LinePlot()
126
 
127
- df, _ = process_csv_file(csv_file)
128
  if df is None:
129
  return gr.LinePlot(value=pd.DataFrame())
130
 
@@ -140,6 +149,32 @@ def update_plot(csv_file, selected_column):
140
 
141
  return plot
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
144
 
145
  @spaces.GPU # dynamically allocate & release a ZeroGPU device on each call
@@ -148,16 +183,13 @@ def infer_chatts_stream(prompt: str, csv_file):
148
  Streaming version of ChatTS inference
149
  """
150
  print("Start inferring!!!")
151
- if csv_file is None:
152
- yield "Please upload a CSV file first"
153
- return
154
 
155
  if not prompt.strip():
156
  yield "Please enter a prompt"
157
  return
158
 
159
- # Process CSV file
160
- df, error_msg = process_csv_file(csv_file)
161
  if df is None:
162
  yield f"Error: {error_msg}"
163
  return
@@ -210,20 +242,6 @@ def infer_chatts_stream(prompt: str, csv_file):
210
  model_output += new_text
211
  yield model_output
212
 
213
- # # Decode the generated text
214
- # full_generated = tokenizer.decode(
215
- # outputs[0][inputs["input_ids"].shape[-1]:],
216
- # skip_special_tokens=True
217
- # )
218
-
219
- # # Simulate streaming by yielding character by character
220
- # for i, char in enumerate(full_generated):
221
- # generated_text += char
222
- # if i % 5 == 0: # Update every 5 characters for smoother streaming
223
- # yield generated_text
224
-
225
- # yield generated_text
226
-
227
  except Exception as e:
228
  yield f"Error during inference: {str(e)}"
229
 
@@ -231,7 +249,7 @@ def infer_chatts_stream(prompt: str, csv_file):
231
 
232
  with gr.Blocks(title="ChatTS Demo") as demo:
233
  gr.Markdown("## ChatTS Demo: Time Series Understanding and Reasoning")
234
- gr.Markdown("Upload a CSV file where each column is a time series. All columns will be treated as time series data.")
235
 
236
  with gr.Row():
237
  with gr.Column(scale=1):
@@ -245,7 +263,7 @@ with gr.Blocks(title="ChatTS Demo") as demo:
245
  lines=4,
246
  placeholder="Enter your analysis prompt here...",
247
  label="Analysis Prompt",
248
- value="Please analyze these time series and provide insights about trends, patterns, and anomalies."
249
  )
250
 
251
  run_btn = gr.Button("Run ChatTS", variant="primary")
@@ -263,13 +281,18 @@ with gr.Blocks(title="ChatTS Demo") as demo:
263
  lines=2
264
  )
265
 
266
-
267
  text_out = gr.Textbox(
268
  lines=10,
269
  label="ChatTS Analysis Results",
270
  interactive=False
271
  )
272
 
 
 
 
 
 
 
273
  # Event handlers
274
  upload.upload(
275
  fn=preview_csv,
 
31
 
32
  # ─── HELPER FUNCTIONS ──────────────────────────────────────────────────────────
33
 
34
+ def create_default_timeseries():
35
+ """Create default time series with sudden increase"""
36
+ seq_len = 256
37
+ y = np.zeros(seq_len, dtype=np.float32)
38
+ y[100:] += 1
39
+ df = pd.DataFrame({"default_series": y})
40
+ return df
41
+
42
  def process_csv_file(csv_file):
43
  """Process CSV file and return DataFrame with validation"""
44
  if csv_file is None:
 
89
  except Exception as e:
90
  return None, f"Error processing file: {str(e)}"
91
 
92
+ def get_current_data(csv_file):
93
+ """Get current data (either uploaded CSV or default)"""
94
+ if csv_file is None:
95
+ return create_default_timeseries(), "Using default time series with sudden increase at step 100"
96
+ else:
97
+ return process_csv_file(csv_file)
98
+
99
  def preview_csv(csv_file):
100
  """Preview uploaded CSV file immediately"""
101
+ df, message = get_current_data(csv_file)
 
 
 
102
 
103
  if df is None:
104
+ return gr.LinePlot(value=pd.DataFrame()), message, gr.Dropdown()
 
 
 
 
105
 
106
  # Create dropdown choices
107
  column_choices = list(df.columns)
108
 
109
  # Create plot with first column as default
110
  first_column = column_choices[0]
111
+ df_with_index = df.copy()
112
  df_with_index["_internal_idx"] = np.arange(len(df[first_column].values))
113
  plot = gr.LinePlot(
114
  df_with_index,
 
124
  label="Select Time Series"
125
  )
126
 
127
+ print("Successfully generated preview!")
128
 
129
  return plot, message, dropdown
130
 
131
  def update_plot(csv_file, selected_column):
132
  """Update plot based on selected column"""
133
+ if selected_column is None:
134
+ return gr.LinePlot(value=pd.DataFrame())
135
 
136
+ df, _ = get_current_data(csv_file)
137
  if df is None:
138
  return gr.LinePlot(value=pd.DataFrame())
139
 
 
149
 
150
  return plot
151
 
152
+ def initialize_interface():
153
+ """Initialize interface with default time series"""
154
+ df = create_default_timeseries()
155
+ column_choices = list(df.columns)
156
+ first_column = column_choices[0]
157
+
158
+ df_with_index = df.copy()
159
+ df_with_index["_internal_idx"] = np.arange(len(df[first_column].values))
160
+
161
+ plot = gr.LinePlot(
162
+ df_with_index,
163
+ x="_internal_idx",
164
+ y=first_column,
165
+ title=f"Time Series: {first_column}"
166
+ )
167
+
168
+ dropdown = gr.Dropdown(
169
+ choices=column_choices,
170
+ value=first_column,
171
+ label="Select Time Series"
172
+ )
173
+
174
+ message = "Using default time series with sudden increase at step 100"
175
+
176
+ return plot, message, dropdown
177
+
178
  # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
179
 
180
  @spaces.GPU # dynamically allocate & release a ZeroGPU device on each call
 
183
  Streaming version of ChatTS inference
184
  """
185
  print("Start inferring!!!")
 
 
 
186
 
187
  if not prompt.strip():
188
  yield "Please enter a prompt"
189
  return
190
 
191
+ # Get current data (CSV or default)
192
+ df, error_msg = get_current_data(csv_file)
193
  if df is None:
194
  yield f"Error: {error_msg}"
195
  return
 
242
  model_output += new_text
243
  yield model_output
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  except Exception as e:
246
  yield f"Error during inference: {str(e)}"
247
 
 
249
 
250
  with gr.Blocks(title="ChatTS Demo") as demo:
251
  gr.Markdown("## ChatTS Demo: Time Series Understanding and Reasoning")
252
+ gr.Markdown("Upload a CSV file where each column is a time series, or use the default time series with sudden increase. All columns will be treated as time series data.")
253
 
254
  with gr.Row():
255
  with gr.Column(scale=1):
 
263
  lines=4,
264
  placeholder="Enter your analysis prompt here...",
265
  label="Analysis Prompt",
266
+ value="Please analyze this time series and provide insights about the trends, seasonality, and local fluctuations."
267
  )
268
 
269
  run_btn = gr.Button("Run ChatTS", variant="primary")
 
281
  lines=2
282
  )
283
 
 
284
  text_out = gr.Textbox(
285
  lines=10,
286
  label="ChatTS Analysis Results",
287
  interactive=False
288
  )
289
 
290
+ # Initialize interface with default data
291
+ demo.load(
292
+ fn=initialize_interface,
293
+ outputs=[plot_out, file_status, series_selector]
294
+ )
295
+
296
  # Event handlers
297
  upload.upload(
298
  fn=preview_csv,