# ChatTS / app.py
import spaces # for ZeroGPU support
import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoProcessor,
)
# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
MODEL_NAME = "bytedance-research/ChatTS-14B"
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME, trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(
MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
)
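# Per the ChatTS model card, the processor aligns the raw float arrays passed
# via `timeseries=` with the `<ts><ts/>` placeholders in the text, so the
# numeric values reach the model alongside the token stream.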
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
trust_remote_code=True,
device_map="auto",
torch_dtype=torch.float16
)
model.eval()
# ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────
@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
def infer_chatts(prompt: str, csv_file) -> tuple:
"""
1. Load CSV: first column = timestamp, other columns = TS names.
2. Drop empty / all-NaN columns; enforce <=15 series.
    3. For each series: trim trailing NaNs, enforce 64 ≤ length ≤ 1024.
4. Build prompt prefix as:
I have N time series:
The <name> is of length L: <ts><ts/>
5. Encode & generate (max_new_tokens=512).
6. Return Gradio LinePlot + generated text.
"""
    # ─── CSV LOADING & PREPROCESSING ──────────────────────────────────────
df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)
# drop columns with empty names or all-NaNs
df.columns = [str(c).strip() for c in df.columns]
df = df.loc[:, [c for c in df.columns if c]]
df = df.dropna(axis=1, how="all")
if df.shape[1] == 0:
raise gr.Error("No valid time-series columns found.")
if df.shape[1] > 15:
raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
ts_names, ts_list = [], []
for name in df.columns:
series = df[name]
# ensure float dtype
if not pd.api.types.is_float_dtype(series):
raise gr.Error(f"Series '{name}' must be float type.")
# trim trailing NaNs only
last_valid = series.last_valid_index()
if last_valid is None:
continue
trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
length = trimmed.shape[0]
if length < 64 or length > 1024:
raise gr.Error(
f"Series '{name}' length {length} invalid. Must be 64 to 1024."
)
ts_names.append(name)
ts_list.append(trimmed)
if not ts_list:
raise gr.Error("All series are empty after trimming NaNs.")
    # ─── BUILD PROMPT ─────────────────────────────────────────────────────
prefix = f"I have {len(ts_list)} time series:\n"
for name, arr in zip(ts_names, ts_list):
prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
full_prompt = prefix + prompt
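    # With two illustrative series "cpu_usage" and "mem_usage" of length 256,
    # the prefix built above would read:
    #   I have 2 time series:
    #   The cpu_usage is of length 256: <ts><ts/>
    #   The mem_usage is of length 256: <ts><ts/>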
    # ─── ENCODE & GENERATE ────────────────────────────────────────────────
inputs = processor(
text=[full_prompt],
timeseries=ts_list,
padding=True,
return_tensors="pt"
)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=512)
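    # Slice off the echoed prompt tokens so only newly generated text is decoded.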
generated = tokenizer.decode(
outputs[0][ inputs["input_ids"].shape[-1] : ],
skip_special_tokens=True
)
    # ─── VISUALIZATION ────────────────────────────────────────────────────
    # gr.LinePlot plots a single y column, so reshape to long format and use
    # `color` to separate the series.
    plot_df = df.reset_index()
    x_col = plot_df.columns[0]
    plot_df = plot_df.melt(id_vars=[x_col], value_vars=ts_names,
                           var_name="series", value_name="value")
    plot = gr.LinePlot(
        plot_df, x=x_col, y="value", color="series",
        label="Uploaded Time Series"
    )
return plot, generated
# ─── GRADIO APP ────────────────────────────────────────────────────────────────
with gr.Blocks() as demo:
gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
with gr.Row():
prompt_input = gr.Textbox(
lines=3,
placeholder="Enter your analysis prompt…",
label="Prompt"
)
upload = gr.UploadButton(
"Upload CSV (timestamp + float columns)",
file_types=[".csv"]
)
plot_out = gr.LinePlot(label="Time Series Visualization")
text_out = gr.Textbox(lines=8, label="Model Response")
run_btn = gr.Button("Run ChatTS")
run_btn.click(
fn=infer_chatts,
inputs=[prompt_input, upload],
outputs=[plot_out, text_out]
)
if __name__ == '__main__':
demo.launch()
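
# A sketch of driving this app programmatically with gradio_client (the URL
# and api_name are assumptions; check the Space's "Use via API" docs):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860/")
#   _plot, answer = client.predict(
#       "Summarize the anomalies in these series.",
#       handle_file("metrics.csv"),
#       api_name="/infer_chatts",
#   )
#   print(answer)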