Update ChatTS
- app.py +107 -15
- app_chatts.py +0 -141
- app_legacy.py +53 -0
app.py
CHANGED
@@ -29,23 +29,115 @@ model.eval()
 
 
 # ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
-@spaces.GPU
-def generate_text(prompt):
-    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
-    outputs = model.generate(
-        **inputs,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.2,
-        top_p=0.9
-    )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-demo = gr.Interface(
-    fn=generate_text,
-    inputs=gr.Textbox(lines=2, label="Prompt"),
-    outputs=gr.Textbox(lines=6, label="Generated Text")
-)
+
+@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
+def infer_chatts(prompt: str, csv_file):
+    """
+    1. Load CSV: first column = timestamp, other columns = TS names.
+    2. Drop empty / all-NaN columns; enforce <=15 series.
+    3. For each series: trim trailing NaNs, enforce 64 ≤ length ≤ 1024.
+    4. Build prompt prefix as:
+       I have N time series:
+       The <name> is of length L: <ts><ts/>
+    5. Encode & generate (max_new_tokens=512).
+    6. Return Gradio LinePlot + generated text.
+    """
+    # Strip user-supplied <ts> tags, then wrap the prompt in the chat template
+    prompt = prompt.replace("<ts>", "").replace("<ts/>", "")
+    prompt = f"<|im_start|>system\nYou are a helpful assistant. Your name is ChatTS. You can analyze time series data and provide insights. If the user asks who you are, you should give your name and capabilities in the language of the prompt. If no time series are provided, you should say 'I cannot answer this question as you haven't provided the time series I need' in the language of the prompt. Always check if the user has provided time series data before answering.<|im_end|><|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
+
+    # ─── CSV LOADING & PREPROCESSING ──────────────────────────────
+    df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
+
+    # drop columns with empty names or all-NaNs
+    df.columns = [str(c).strip() for c in df.columns]
+    df = df.loc[:, [c for c in df.columns if c]]
+    df = df.dropna(axis=1, how="all")
+
+    if df.shape[1] == 0:
+        raise gr.Error("No valid time-series columns found.")
+    if df.shape[1] > 15:
+        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
+
+    ts_names, ts_list = [], []
+    for name in df.columns:
+        series = df[name]
+        # ensure float dtype
+        if not pd.api.types.is_float_dtype(series):
+            raise gr.Error(f"Series '{name}' must be float type.")
+        # trim trailing NaNs only
+        last_valid = series.last_valid_index()
+        if last_valid is None:
+            continue
+        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
+        length = trimmed.shape[0]
+        if length < 64 or length > 1024:
+            raise gr.Error(
+                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
+            )
+        ts_names.append(name)
+        ts_list.append(trimmed)
+
+    if not ts_list:
+        raise gr.Error("All series are empty after trimming NaNs.")
+
+    # ─── BUILD PROMPT ──────────────────────────────────────────────
+    prefix = f"I have {len(ts_list)} time series:\n"
+    for name, arr in zip(ts_names, ts_list):
+        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
+    full_prompt = prefix + prompt
+
+    # ─── ENCODE & GENERATE ─────────────────────────────────────────
+    inputs = processor(
+        text=[full_prompt],
+        timeseries=ts_list,
+        padding=True,
+        return_tensors="pt"
+    )
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    outputs = model.generate(**inputs, max_new_tokens=512)
+    generated = tokenizer.decode(
+        outputs[0][inputs["input_ids"].shape[-1]:],
+        skip_special_tokens=True
+    )
+
+    # ─── VISUALIZATION ─────────────────────────────────────────────
+    plot = gr.LinePlot(
+        df.reset_index(),
+        x=df.index.name or df.reset_index().columns[0],
+        y=ts_names,
+        label="Uploaded Time Series"
+    )
+
+    return plot, generated
+
+
+# ─── GRADIO APP ────────────────────────────────────────────────────────────────
+
+with gr.Blocks() as demo:
+    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
+
+    with gr.Row():
+        prompt_input = gr.Textbox(
+            lines=3,
+            placeholder="Enter your analysis prompt…",
+            label="Prompt"
+        )
+        upload = gr.UploadButton(
+            "Upload CSV (timestamp + float columns)",
+            file_types=[".csv"]
+        )
+
+    plot_out = gr.LinePlot(label="Time Series Visualization")
+    text_out = gr.Textbox(lines=8, label="Model Response")
+
+    run_btn = gr.Button("Run ChatTS")
+    run_btn.click(
+        fn=infer_chatts,
+        inputs=[prompt_input, upload],
+        outputs=[plot_out, text_out]
+    )
 
 
 if __name__ == '__main__':
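For reference, infer_chatts expects a CSV whose first column is a timestamp (used as the index) and whose remaining columns are float series, each 64 to 1024 points long after trailing NaNs are trimmed, with at most 15 series in total. A minimal sketch of generating a conforming file follows; the column names, series length, and sample.csv path are illustrative assumptions, not part of the commit:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# 256 points per series: within the required 64-1024 range.
idx = pd.date_range("2024-01-01", periods=256, freq="min")

# Column values must be float-typed, or infer_chatts raises gr.Error.
df = pd.DataFrame(
    {
        "cpu_usage": rng.normal(50.0, 5.0, 256).astype(np.float32),
        "memory_usage": rng.normal(70.0, 3.0, 256).astype(np.float32),
    },
    index=idx,
)
df.index.name = "timestamp"  # becomes the first CSV column, read back via index_col=0
df.to_csv("sample.csv")      # hypothetical path to feed the upload widget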
app_chatts.py
DELETED
@@ -1,141 +0,0 @@
-import spaces  # for ZeroGPU support
-import gradio as gr
-import pandas as pd
-import numpy as np
-import torch
-import subprocess
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    AutoProcessor,
-)
-
-# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
-MODEL_NAME = "bytedance-research/ChatTS-14B"
-
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_NAME, trust_remote_code=True
-)
-processor = AutoProcessor.from_pretrained(
-    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
-)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    trust_remote_code=True,
-    device_map="auto",
-    torch_dtype=torch.float16
-)
-model.eval()
-
-
-# ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
-
-@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
-def infer_chatts(prompt: str, csv_file):
-    """
-    1. Load CSV: first column = timestamp, other columns = TS names.
-    2. Drop empty / all-NaN columns; enforce <=15 series.
-    3. For each series: trim trailing NaNs, enforce 64 ≤ length ≤ 1024.
-    4. Build prompt prefix as:
-       I have N time series:
-       The <name> is of length L: <ts><ts/>
-    5. Encode & generate (max_new_tokens=512).
-    6. Return Gradio LinePlot + generated text.
-    """
-    # ─── CSV LOADING & PREPROCESSING ──────────────────────────────
-    df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
-
-    # drop columns with empty names or all-NaNs
-    df.columns = [str(c).strip() for c in df.columns]
-    df = df.loc[:, [c for c in df.columns if c]]
-    df = df.dropna(axis=1, how="all")
-
-    if df.shape[1] == 0:
-        raise gr.Error("No valid time-series columns found.")
-    if df.shape[1] > 15:
-        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
-
-    ts_names, ts_list = [], []
-    for name in df.columns:
-        series = df[name]
-        # ensure float dtype
-        if not pd.api.types.is_float_dtype(series):
-            raise gr.Error(f"Series '{name}' must be float type.")
-        # trim trailing NaNs only
-        last_valid = series.last_valid_index()
-        if last_valid is None:
-            continue
-        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
-        length = trimmed.shape[0]
-        if length < 64 or length > 1024:
-            raise gr.Error(
-                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
-            )
-        ts_names.append(name)
-        ts_list.append(trimmed)
-
-    if not ts_list:
-        raise gr.Error("All series are empty after trimming NaNs.")
-
-    # ─── BUILD PROMPT ──────────────────────────────────────────────
-    prefix = f"I have {len(ts_list)} time series:\n"
-    for name, arr in zip(ts_names, ts_list):
-        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
-    full_prompt = prefix + prompt
-
-    # ─── ENCODE & GENERATE ─────────────────────────────────────────
-    inputs = processor(
-        text=[full_prompt],
-        timeseries=ts_list,
-        padding=True,
-        return_tensors="pt"
-    )
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-    outputs = model.generate(**inputs, max_new_tokens=512)
-    generated = tokenizer.decode(
-        outputs[0][inputs["input_ids"].shape[-1]:],
-        skip_special_tokens=True
-    )
-
-    # ─── VISUALIZATION ─────────────────────────────────────────────
-    plot = gr.LinePlot(
-        df.reset_index(),
-        x=df.index.name or df.reset_index().columns[0],
-        y=ts_names,
-        label="Uploaded Time Series"
-    )
-
-    return plot, generated
-
-
-# ─── GRADIO APP ────────────────────────────────────────────────────────────────
-
-with gr.Blocks() as demo:
-    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
-
-    with gr.Row():
-        prompt_input = gr.Textbox(
-            lines=3,
-            placeholder="Enter your analysis prompt…",
-            label="Prompt"
-        )
-        upload = gr.UploadButton(
-            "Upload CSV (timestamp + float columns)",
-            file_types=[".csv"]
-        )
-
-    plot_out = gr.LinePlot(label="Time Series Visualization")
-    text_out = gr.Textbox(lines=8, label="Model Response")
-
-    run_btn = gr.Button("Run ChatTS")
-    run_btn.click(
-        fn=infer_chatts,
-        inputs=[prompt_input, upload],
-        outputs=[plot_out, text_out]
-    )
-
-
-if __name__ == '__main__':
-    subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
-    demo.launch()
app_legacy.py
ADDED
@@ -0,0 +1,53 @@
+import spaces  # for ZeroGPU support
+import gradio as gr
+import pandas as pd
+import numpy as np
+import torch
+import subprocess
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoProcessor,
+)
+
+# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
+MODEL_NAME = "bytedance-research/ChatTS-14B"
+
+tokenizer = AutoTokenizer.from_pretrained(
+    MODEL_NAME, trust_remote_code=True
+)
+processor = AutoProcessor.from_pretrained(
+    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
+)
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    trust_remote_code=True,
+    device_map="auto",
+    torch_dtype=torch.float16
+)
+model.eval()
+
+
+# ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
+@spaces.GPU
+def generate_text(prompt):
+    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=512,
+        do_sample=True,
+        temperature=0.2,
+        top_p=0.9
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+demo = gr.Interface(
+    fn=generate_text,
+    inputs=gr.Textbox(lines=2, label="Prompt"),
+    outputs=gr.Textbox(lines=6, label="Generated Text")
+)
+
+
+if __name__ == '__main__':
+    subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
+    demo.launch()
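Once a Space running either app is up, it can also be queried programmatically. A minimal sketch for the legacy text-only interface, assuming the gradio_client package; the Space id is a placeholder, and the endpoint name can vary with the Gradio version:

from gradio_client import Client

# "user/chatts-demo" is a hypothetical Space id.
client = Client("user/chatts-demo")
result = client.predict(
    "Who are you?",        # maps to the single Prompt textbox
    api_name="/predict",   # usual endpoint for a one-function gr.Interface
)
print(result)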