xiezhe22 committed
Commit 77babc6 · 1 Parent(s): da97255
Files changed (2):
  1. app.py +15 -103
  2. app_chatts.py +141 -0
app.py CHANGED
@@ -29,111 +29,23 @@ model.eval()
 
 
 # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
-
-@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
-def infer_chatts(prompt: str, csv_file):
-    """
-    1. Load CSV: first column = timestamp, other columns = TS names.
-    2. Drop empty / all-NaN columns; enforce <= 15 series.
-    3. For each series: trim trailing NaNs, enforce 64 <= length <= 1024.
-    4. Build prompt prefix as:
-           I have N time series:
-           The <name> is of length L: <ts><ts/>
-    5. Encode & generate (max_new_tokens=512).
-    6. Return Gradio LinePlot + generated text.
-    """
-    # ——— CSV LOADING & PREPROCESSING ———————————————————————————————
-    df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
-
-    # drop columns with empty names or all-NaNs
-    df.columns = [str(c).strip() for c in df.columns]
-    df = df.loc[:, [c for c in df.columns if c]]
-    df = df.dropna(axis=1, how="all")
-
-    if df.shape[1] == 0:
-        raise gr.Error("No valid time-series columns found.")
-    if df.shape[1] > 15:
-        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
-
-    ts_names, ts_list = [], []
-    for name in df.columns:
-        series = df[name]
-        # ensure float dtype
-        if not pd.api.types.is_float_dtype(series):
-            raise gr.Error(f"Series '{name}' must be float type.")
-        # trim trailing NaNs only
-        last_valid = series.last_valid_index()
-        if last_valid is None:
-            continue
-        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
-        length = trimmed.shape[0]
-        if length < 64 or length > 1024:
-            raise gr.Error(
-                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
-            )
-        ts_names.append(name)
-        ts_list.append(trimmed)
-
-    if not ts_list:
-        raise gr.Error("All series are empty after trimming NaNs.")
-
-    # ——— BUILD PROMPT ——————————————————————————————————————————————
-    prefix = f"I have {len(ts_list)} time series:\n"
-    for name, arr in zip(ts_names, ts_list):
-        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
-    full_prompt = prefix + prompt
-
-    # ——— ENCODE & GENERATE —————————————————————————————————————————
-    inputs = processor(
-        text=[full_prompt],
-        timeseries=ts_list,
-        padding=True,
-        return_tensors="pt"
-    )
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-    outputs = model.generate(**inputs, max_new_tokens=512)
-    generated = tokenizer.decode(
-        outputs[0][inputs["input_ids"].shape[-1]:],
-        skip_special_tokens=True
-    )
-
-    # ——— VISUALIZATION —————————————————————————————————————————————
-    plot = gr.LinePlot(
-        df.reset_index(),
-        x=df.index.name or df.reset_index().columns[0],
-        y=ts_names,
-        label="Uploaded Time Series"
-    )
-
-    return plot, generated
-
-
-# ─── GRADIO APP ────────────────────────────────────────────────────────────────
-
-with gr.Blocks() as demo:
-    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
-
-    with gr.Row():
-        prompt_input = gr.Textbox(
-            lines=3,
-            placeholder="Enter your analysis prompt…",
-            label="Prompt"
-        )
-        upload = gr.UploadButton(
-            "Upload CSV (timestamp + float columns)",
-            file_types=[".csv"]
-        )
-
-    plot_out = gr.LinePlot(label="Time Series Visualization")
-    text_out = gr.Textbox(lines=8, label="Model Response")
-
-    run_btn = gr.Button("Run ChatTS")
-    run_btn.click(
-        fn=infer_chatts,
-        inputs=[prompt_input, upload],
-        outputs=[plot_out, text_out]
-    )
+@spaces.GPU  # ZeroGPU: allocate a device per call (the decorator is spaces.GPU)
+def generate_text(prompt):
+    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=512,
+        do_sample=True,
+        temperature=0.2,
+        top_p=0.9
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+demo = gr.Interface(
+    fn=generate_text,
+    inputs=gr.Textbox(lines=2, label="Prompt"),
+    outputs=gr.Textbox(lines=6, label="Generated Text")
+)
 
 
 if __name__ == '__main__':
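Note: the rewritten app.py now exposes a plain text-in/text-out endpoint. A minimal sketch of calling it remotely with gradio_client, assuming the Space is published under a hypothetical ID (xiezhe22/ChatTS-demo is illustrative, not the real deployment):

    # Sketch only: the Space ID below is hypothetical.
    from gradio_client import Client

    client = Client("xiezhe22/ChatTS-demo")  # hypothetical Space ID
    # a single-function gr.Interface registers its endpoint as "/predict"
    result = client.predict(
        "Summarize the typical shape of a daily CPU-usage curve.",
        api_name="/predict",
    )
    print(result)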
app_chatts.py ADDED
@@ -0,0 +1,141 @@
+import spaces  # for ZeroGPU support
+import gradio as gr
+import pandas as pd
+import numpy as np
+import torch
+import subprocess
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoProcessor,
+)
+
+# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
+MODEL_NAME = "bytedance-research/ChatTS-14B"
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_NAME, trust_remote_code=True
+)
+processor = AutoProcessor.from_pretrained(
+    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
+)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    trust_remote_code=True,
+    device_map="auto",
+    torch_dtype=torch.float16
+)
+model.eval()
+
+
+# ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
+
+@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
+def infer_chatts(prompt: str, csv_file):
+    """
+    1. Load CSV: first column = timestamp, other columns = TS names.
+    2. Drop empty / all-NaN columns; enforce <= 15 series.
+    3. For each series: trim trailing NaNs, enforce 64 <= length <= 1024.
+    4. Build prompt prefix as:
+           I have N time series:
+           The <name> is of length L: <ts><ts/>
+    5. Encode & generate (max_new_tokens=512).
+    6. Return Gradio LinePlot + generated text.
+    """
+    # ——— CSV LOADING & PREPROCESSING ———————————————————————————————
+    df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
+
+    # drop columns with empty names or all-NaNs
+    df.columns = [str(c).strip() for c in df.columns]
+    df = df.loc[:, [c for c in df.columns if c]]
+    df = df.dropna(axis=1, how="all")
+
+    if df.shape[1] == 0:
+        raise gr.Error("No valid time-series columns found.")
+    if df.shape[1] > 15:
+        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
+
+    ts_names, ts_list = [], []
+    for name in df.columns:
+        series = df[name]
+        # ensure float dtype
+        if not pd.api.types.is_float_dtype(series):
+            raise gr.Error(f"Series '{name}' must be float type.")
+        # trim trailing NaNs only
+        last_valid = series.last_valid_index()
+        if last_valid is None:
+            continue
+        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
+        length = trimmed.shape[0]
+        if length < 64 or length > 1024:
+            raise gr.Error(
+                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
+            )
+        ts_names.append(name)
+        ts_list.append(trimmed)
+
+    if not ts_list:
+        raise gr.Error("All series are empty after trimming NaNs.")
+
+    # ——— BUILD PROMPT ——————————————————————————————————————————————
+    prefix = f"I have {len(ts_list)} time series:\n"
+    for name, arr in zip(ts_names, ts_list):
+        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
+    full_prompt = prefix + prompt
+
+    # ——— ENCODE & GENERATE —————————————————————————————————————————
+    inputs = processor(
+        text=[full_prompt],
+        timeseries=ts_list,
+        padding=True,
+        return_tensors="pt"
+    )
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    outputs = model.generate(**inputs, max_new_tokens=512)
+    generated = tokenizer.decode(
+        outputs[0][inputs["input_ids"].shape[-1]:],
+        skip_special_tokens=True
+    )
+
+    # ——— VISUALIZATION —————————————————————————————————————————————
+    plot = gr.LinePlot(
+        df.reset_index(),
+        x=df.index.name or df.reset_index().columns[0],
+        y=ts_names,
+        label="Uploaded Time Series"
+    )
+
+    return plot, generated
+
+
+# ─── GRADIO APP ────────────────────────────────────────────────────────────────
+
+with gr.Blocks() as demo:
+    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
+
+    with gr.Row():
+        prompt_input = gr.Textbox(
+            lines=3,
+            placeholder="Enter your analysis prompt…",
+            label="Prompt"
+        )
+        upload = gr.UploadButton(
+            "Upload CSV (timestamp + float columns)",
+            file_types=[".csv"]
+        )
+
+    plot_out = gr.LinePlot(label="Time Series Visualization")
+    text_out = gr.Textbox(lines=8, label="Model Response")
+
+    run_btn = gr.Button("Run ChatTS")
+    run_btn.click(
+        fn=infer_chatts,
+        inputs=[prompt_input, upload],
+        outputs=[plot_out, text_out]
+    )
+
+
+if __name__ == '__main__':
+    subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+    demo.launch()
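Note: infer_chatts only accepts a CSV whose first column is a timestamp index and whose remaining columns are at most 15 float series, each 64 to 1024 points long after trailing-NaN trimming. A minimal sketch of generating a CSV that satisfies those checks (column names and values are illustrative):

    # Sketch: build a CSV that passes the infer_chatts validation rules.
    import numpy as np
    import pandas as pd

    n = 256  # must fall within the required 64-1024 range
    idx = pd.date_range("2024-01-01", periods=n, freq="min")
    df = pd.DataFrame(
        {
            "cpu_usage": np.random.rand(n).astype(np.float32),     # illustrative series
            "memory_usage": np.random.rand(n).astype(np.float32),  # illustrative series
        },
        index=idx,
    )
    df.to_csv("sample.csv", index_label="timestamp")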