xiezhe22 committed
Commit b69bd77 · Parent(s): 009e745

Update ChatTS

Files changed (3):
  1. app.py +107 -15
  2. app_chatts.py +0 -141
  3. app_legacy.py +53 -0
app.py CHANGED
@@ -29,23 +29,115 @@ model.eval()
  
  
  # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
- @spaces.GPU
- def generate_text(prompt):
-     inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-     outputs = model.generate(
-         **inputs,
-         max_new_tokens=512,
-         do_sample=True,
-         temperature=0.2,
-         top_p=0.9
+ 
+ @spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
+ def infer_chatts(prompt: str, csv_file):
+     """
+     1. Load CSV: first column = timestamp, other columns = TS names.
+     2. Drop empty / all-NaN columns; enforce ≤ 15 series.
+     3. For each series: trim trailing NaNs, enforce 64 ≤ length ≤ 1024.
+     4. Build prompt prefix as:
+            I have N time series:
+            The <name> is of length L: <ts><ts/>
+     5. Encode & generate (max_new_tokens=512).
+     6. Return Gradio LinePlot + generated text.
+     """
+     # Apply the chat template
+     prompt = prompt.replace("<ts>", "").replace("<ts/>", "")
+     prompt = f"<|im_start|>system\nYou are a helpful assistant. Your name is ChatTS. You can analyze time series data and provide insights. If the user asks who you are, you should give your name and capabilities in the language of the prompt. If no time series are provided, you should say 'I cannot answer this question as you haven't provided the time series I need' in the language of the prompt. Always check if the user has provided time series data before answering.<|im_end|><|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
+ 
+     # ——— CSV LOADING & PREPROCESSING ———————————————————————————————
+     df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
+ 
+     # drop columns with empty names or all-NaN values
+     df.columns = [str(c).strip() for c in df.columns]
+     df = df.loc[:, [c for c in df.columns if c]]
+     df = df.dropna(axis=1, how="all")
+ 
+     if df.shape[1] == 0:
+         raise gr.Error("No valid time-series columns found.")
+     if df.shape[1] > 15:
+         raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
+ 
+     ts_names, ts_list = [], []
+     for name in df.columns:
+         series = df[name]
+         # ensure float dtype
+         if not pd.api.types.is_float_dtype(series):
+             raise gr.Error(f"Series '{name}' must be float type.")
+         # trim trailing NaNs only
+         last_valid = series.last_valid_index()
+         if last_valid is None:
+             continue
+         trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
+         length = trimmed.shape[0]
+         if length < 64 or length > 1024:
+             raise gr.Error(
+                 f"Series '{name}' length {length} invalid. Must be 64 to 1024."
+             )
+         ts_names.append(name)
+         ts_list.append(trimmed)
+ 
+     if not ts_list:
+         raise gr.Error("All series are empty after trimming NaNs.")
+ 
+     # ——— BUILD PROMPT ——————————————————————————————————————————————
+     prefix = f"I have {len(ts_list)} time series:\n"
+     for name, arr in zip(ts_names, ts_list):
+         prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
+     full_prompt = prefix + prompt
+ 
+     # ——— ENCODE & GENERATE —————————————————————————————————————————
+     inputs = processor(
+         text=[full_prompt],
+         timeseries=ts_list,
+         padding=True,
+         return_tensors="pt"
      )
-     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+     inputs = {k: v.to(model.device) for k, v in inputs.items()}
  
- demo = gr.Interface(
-     fn=generate_text,
-     inputs=gr.Textbox(lines=2, label="Prompt"),
-     outputs=gr.Textbox(lines=6, label="Generated Text")
- )
+     outputs = model.generate(**inputs, max_new_tokens=512)
+     generated = tokenizer.decode(
+         outputs[0][inputs["input_ids"].shape[-1]:],
+         skip_special_tokens=True
+     )
+ 
+     # ——— VISUALIZATION —————————————————————————————————————————————
+     plot = gr.LinePlot(
+         df.reset_index(),
+         x=df.index.name or df.reset_index().columns[0],
+         y=ts_names,
+         label="Uploaded Time Series"
+     )
+ 
+     return plot, generated
+ 
+ 
+ # ─── GRADIO APP ────────────────────────────────────────────────────────────────
+ 
+ with gr.Blocks() as demo:
+     gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
+ 
+     with gr.Row():
+         prompt_input = gr.Textbox(
+             lines=3,
+             placeholder="Enter your analysis prompt…",
+             label="Prompt"
+         )
+         upload = gr.UploadButton(
+             "Upload CSV (timestamp + float columns)",
+             file_types=[".csv"]
+         )
+ 
+     plot_out = gr.LinePlot(label="Time Series Visualization")
+     text_out = gr.Textbox(lines=8, label="Model Response")
+ 
+     run_btn = gr.Button("Run ChatTS")
+     run_btn.click(
+         fn=infer_chatts,
+         inputs=[prompt_input, upload],
+         outputs=[plot_out, text_out]
+     )
  
  
  if __name__ == '__main__':
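The new `infer_chatts` entry point expects a CSV whose first column is a timestamp index and whose remaining float columns are the series: at most 15 of them, each 64 to 1024 points long after trailing NaNs are trimmed. A minimal sketch of an input that passes this validation, and the prompt prefix it produces; the file name, column names, and lengths here are illustrative assumptions, not part of the commit:

```python
# Illustrative sketch only: builds a CSV that satisfies infer_chatts validation.
import numpy as np
import pandas as pd

idx = pd.date_range("2024-01-01", periods=128, freq="min")  # timestamp column
df = pd.DataFrame(
    {
        # two float series of length 128 (within 64-1024, under the 15-series cap)
        "cpu_usage": np.random.rand(128).astype(np.float32),
        "mem_usage": np.random.rand(128).astype(np.float32),
    },
    index=idx,
)
df.to_csv("example.csv")  # hypothetical file name

# For this file, infer_chatts would build the prefix:
#   I have 2 time series:
#   The cpu_usage is of length 128: <ts><ts/>
#   The mem_usage is of length 128: <ts><ts/>
# and pass the two trimmed arrays to the processor via timeseries=[...].
```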
app_chatts.py DELETED
@@ -1,141 +0,0 @@
- import spaces  # for ZeroGPU support
- import gradio as gr
- import pandas as pd
- import numpy as np
- import torch
- import subprocess
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     AutoProcessor,
- )
- 
- # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
- MODEL_NAME = "bytedance-research/ChatTS-14B"
- 
- tokenizer = AutoTokenizer.from_pretrained(
-     MODEL_NAME, trust_remote_code=True
- )
- processor = AutoProcessor.from_pretrained(
-     MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
- )
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL_NAME,
-     trust_remote_code=True,
-     device_map="auto",
-     torch_dtype=torch.float16
- )
- model.eval()
- 
- 
- # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
- 
- @spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
- def infer_chatts(prompt: str, csv_file):
-     """
-     1. Load CSV: first column = timestamp, other columns = TS names.
-     2. Drop empty / all-NaN columns; enforce ≤ 15 series.
-     3. For each series: trim trailing NaNs, enforce 64 ≤ length ≤ 1024.
-     4. Build prompt prefix as:
-            I have N time series:
-            The <name> is of length L: <ts><ts/>
-     5. Encode & generate (max_new_tokens=512).
-     6. Return Gradio LinePlot + generated text.
-     """
-     # ——— CSV LOADING & PREPROCESSING ———————————————————————————————
-     df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
- 
-     # drop columns with empty names or all-NaN values
-     df.columns = [str(c).strip() for c in df.columns]
-     df = df.loc[:, [c for c in df.columns if c]]
-     df = df.dropna(axis=1, how="all")
- 
-     if df.shape[1] == 0:
-         raise gr.Error("No valid time-series columns found.")
-     if df.shape[1] > 15:
-         raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
- 
-     ts_names, ts_list = [], []
-     for name in df.columns:
-         series = df[name]
-         # ensure float dtype
-         if not pd.api.types.is_float_dtype(series):
-             raise gr.Error(f"Series '{name}' must be float type.")
-         # trim trailing NaNs only
-         last_valid = series.last_valid_index()
-         if last_valid is None:
-             continue
-         trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
-         length = trimmed.shape[0]
-         if length < 64 or length > 1024:
-             raise gr.Error(
-                 f"Series '{name}' length {length} invalid. Must be 64 to 1024."
-             )
-         ts_names.append(name)
-         ts_list.append(trimmed)
- 
-     if not ts_list:
-         raise gr.Error("All series are empty after trimming NaNs.")
- 
-     # ——— BUILD PROMPT ——————————————————————————————————————————————
-     prefix = f"I have {len(ts_list)} time series:\n"
-     for name, arr in zip(ts_names, ts_list):
-         prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
-     full_prompt = prefix + prompt
- 
-     # ——— ENCODE & GENERATE —————————————————————————————————————————
-     inputs = processor(
-         text=[full_prompt],
-         timeseries=ts_list,
-         padding=True,
-         return_tensors="pt"
-     )
-     inputs = {k: v.to(model.device) for k, v in inputs.items()}
- 
-     outputs = model.generate(**inputs, max_new_tokens=512)
-     generated = tokenizer.decode(
-         outputs[0][inputs["input_ids"].shape[-1]:],
-         skip_special_tokens=True
-     )
- 
-     # ——— VISUALIZATION —————————————————————————————————————————————
-     plot = gr.LinePlot(
-         df.reset_index(),
-         x=df.index.name or df.reset_index().columns[0],
-         y=ts_names,
-         label="Uploaded Time Series"
-     )
- 
-     return plot, generated
- 
- 
- # ─── GRADIO APP ────────────────────────────────────────────────────────────────
- 
- with gr.Blocks() as demo:
-     gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
- 
-     with gr.Row():
-         prompt_input = gr.Textbox(
-             lines=3,
-             placeholder="Enter your analysis prompt…",
-             label="Prompt"
-         )
-         upload = gr.UploadButton(
-             "Upload CSV (timestamp + float columns)",
-             file_types=[".csv"]
-         )
- 
-     plot_out = gr.LinePlot(label="Time Series Visualization")
-     text_out = gr.Textbox(lines=8, label="Model Response")
- 
-     run_btn = gr.Button("Run ChatTS")
-     run_btn.click(
-         fn=infer_chatts,
-         inputs=[prompt_input, upload],
-         outputs=[plot_out, text_out]
-     )
- 
- 
- if __name__ == '__main__':
-     subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
-     demo.launch()
app_legacy.py ADDED
@@ -0,0 +1,53 @@
+ import spaces  # for ZeroGPU support
+ import gradio as gr
+ import pandas as pd
+ import numpy as np
+ import torch
+ import subprocess
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     AutoProcessor,
+ )
+ 
+ # ─── MODEL SETUP ────────────────────────────────────────────────────────────────
+ MODEL_NAME = "bytedance-research/ChatTS-14B"
+ 
+ tokenizer = AutoTokenizer.from_pretrained(
+     MODEL_NAME, trust_remote_code=True
+ )
+ processor = AutoProcessor.from_pretrained(
+     MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
+ )
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME,
+     trust_remote_code=True,
+     device_map="auto",
+     torch_dtype=torch.float16
+ )
+ model.eval()
+ 
+ 
+ # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
+ @spaces.GPU
+ def generate_text(prompt):
+     inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=512,
+         do_sample=True,
+         temperature=0.2,
+         top_p=0.9
+     )
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+ 
+ demo = gr.Interface(
+     fn=generate_text,
+     inputs=gr.Textbox(lines=2, label="Prompt"),
+     outputs=gr.Textbox(lines=6, label="Generated Text")
+ )
+ 
+ 
+ if __name__ == '__main__':
+     subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+     demo.launch()
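For completeness, a hedged sketch of driving the updated Space programmatically with `gradio_client`. The Space id and the `/infer_chatts` endpoint name are assumptions (Gradio derives endpoint names from the handler function unless an `api_name` is set in `run_btn.click`), not something this commit pins down:

```python
# Illustrative sketch: query the ChatTS Space from Python via gradio_client.
# The Space id and endpoint name below are assumptions, not part of this commit.
from gradio_client import Client, handle_file

client = Client("xiezhe22/ChatTS")  # hypothetical Space id
plot, text = client.predict(
    "Describe any anomalies in these series.",  # prompt_input
    handle_file("example.csv"),                 # upload (CSV built as sketched earlier)
    api_name="/infer_chatts",                   # assumed default endpoint name
)
print(text)
```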