import spaces
import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
)


MODEL_NAME = "bytedance-research/ChatTS-14B"
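
# ChatTS ships custom modeling code, so each from_pretrained call below needs
# trust_remote_code=True. float16 with device_map="auto" assumes the 14B
# checkpoint fits on the Space's GPU.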
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(
    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
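
# Expected CSV shape (illustrative column names and values; any float columns
# of length 64-1024 work, and the first column is parsed as the index):
#
#   timestamp,cpu_usage,mem_usage
#   2024-01-01 00:00:00,0.42,0.61
#   2024-01-01 00:01:00,0.44,0.60
#   ...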


@spaces.GPU
def infer_chatts(prompt: str, csv_file) -> tuple:
    """
    1. Load CSV: first column = timestamp, other columns = TS names.
    2. Drop empty / all-NaN columns; enforce ≤ 15 series.
    3. For each series: trim trailing NaNs, enforce 64 ≤ length ≤ 1024.
    4. Build prompt prefix as:
           I have N time series:
           The <name> is of length L: <ts><ts/>
    5. Encode & generate (max_new_tokens=512).
    6. Return Gradio LinePlot + generated text.
    """
    # gr.UploadButton may pass a tempfile-like object exposing .name or a
    # plain path string depending on the Gradio version, so accept both.
    csv_path = getattr(csv_file, "name", csv_file)
    df = pd.read_csv(csv_path, parse_dates=True, index_col=0)

    # Normalise headers, then drop unnamed and all-NaN columns.
    df.columns = [str(c).strip() for c in df.columns]
    df = df.loc[:, [c for c in df.columns if c]]
    df = df.dropna(axis=1, how="all")

    if df.shape[1] == 0:
        raise gr.Error("No valid time-series columns found.")
    if df.shape[1] > 15:
        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")

    ts_names, ts_list = [], []
    for name in df.columns:
        series = df[name]

        if not pd.api.types.is_float_dtype(series):
            raise gr.Error(f"Series '{name}' must be float type.")

        # Trim trailing NaNs (columns shorter than the index are NaN-padded
        # by read_csv), then validate the remaining length.
        last_valid = series.last_valid_index()
        if last_valid is None:
            continue
        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
        length = trimmed.shape[0]
        if length < 64 or length > 1024:
            raise gr.Error(
                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
            )
        ts_names.append(name)
        ts_list.append(trimmed)

    if not ts_list:
        raise gr.Error("All series are empty after trimming NaNs.")

    # Build the text prompt. Each <ts><ts/> placeholder is matched, in order,
    # with one array in ts_list when the processor encodes the inputs.
    prefix = f"I have {len(ts_list)} time series:\n"
    for name, arr in zip(ts_names, ts_list):
        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
    full_prompt = prefix + prompt
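    # For example, with two 128-point series named "cpu" and "mem" the model
    # would receive:
    #
    #   I have 2 time series:
    #   The cpu is of length 128: <ts><ts/>
    #   The mem is of length 128: <ts><ts/>
    #   <user prompt>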

    inputs = processor(
        text=[full_prompt],
        timeseries=ts_list,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    outputs = model.generate(**inputs, max_new_tokens=512)
    # Slice off the prompt tokens so only the newly generated text is decoded.
    generated = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )

    # gr.LinePlot plots a single y column, so melt the frame to long format
    # and use the series name as the colour dimension.
    plot_df = df.reset_index()
    x_col = plot_df.columns[0]
    plot_df = plot_df.melt(
        id_vars=[x_col], value_vars=ts_names,
        var_name="series", value_name="value",
    )
    plot = gr.LinePlot(
        plot_df,
        x=x_col,
        y="value",
        color="series",
        label="Uploaded Time Series",
    )

    return plot, generated


with gr.Blocks() as demo:
    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=3,
            placeholder="Enter your analysis prompt…",
            label="Prompt",
        )
        upload = gr.UploadButton(
            "Upload CSV (timestamp + float columns)",
            file_types=[".csv"],
        )

    plot_out = gr.LinePlot(label="Time Series Visualization")
    text_out = gr.Textbox(lines=8, label="Model Response")

    run_btn = gr.Button("Run ChatTS")
    run_btn.click(
        fn=infer_chatts,
        inputs=[prompt_input, upload],
        outputs=[plot_out, text_out],
    )


if __name__ == "__main__":
    demo.launch()