import spaces # for ZeroGPU support
import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
)
# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
MODEL_NAME = "bytedance-research/ChatTS-14B"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(
    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)
model.eval()
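# Note: trust_remote_code=True is required because ChatTS ships custom
# model/processor classes on the Hub; float16 halves memory versus float32,
# and device_map="auto" lets Accelerate place the 14B weights on the GPU.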
# ─── INFERENCE + VALIDATION ─────────────────────────────────────────────────────
@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
def infer_chatts(prompt: str, csv_file) -> tuple:
"""
1. Load CSV: first column = timestamp, other columns = TS names.
2. Drop empty / all-NaN columns; enforce <=15 series.
3. For each series: trim trailing NaNs, enforce 64 β€ length β€ 1024.
4. Build prompt prefix as:
I have N time series:
The <name> is of length L: <ts><ts/>
5. Encode & generate (max_new_tokens=512).
6. Return Gradio LinePlot + generated text.
"""
    # ─── CSV LOADING & PREPROCESSING ──────────────────────────────
    # UploadButton may pass a filepath string or a tempfile wrapper
    # depending on the Gradio version; handle both.
    csv_path = csv_file.name if hasattr(csv_file, "name") else csv_file
    df = pd.read_csv(csv_path, parse_dates=True, index_col=0)
    # drop columns with empty names or all-NaN values
    df.columns = [str(c).strip() for c in df.columns]
    df = df.loc[:, [c for c in df.columns if c]]
    df = df.dropna(axis=1, how="all")
    if df.shape[1] == 0:
        raise gr.Error("No valid time-series columns found.")
    if df.shape[1] > 15:
        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")
    ts_names, ts_list = [], []
    for name in df.columns:
        series = df[name]
        # accept any numeric dtype; values are cast to float32 below
        if not pd.api.types.is_numeric_dtype(series):
            raise gr.Error(f"Series '{name}' must be numeric.")
        # trim trailing NaNs only
        last_valid = series.last_valid_index()
        if last_valid is None:
            continue
        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
        # interior NaNs would be fed to the model as-is, so reject them
        if np.isnan(trimmed).any():
            raise gr.Error(f"Series '{name}' contains interior NaNs.")
        length = trimmed.shape[0]
        if length < 64 or length > 1024:
            raise gr.Error(
                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
            )
        ts_names.append(name)
        ts_list.append(trimmed)
    if not ts_list:
        raise gr.Error("All series are empty after trimming NaNs.")
    # ─── BUILD PROMPT ──────────────────────────────────────────────
    prefix = f"I have {len(ts_list)} time series:\n"
    for name, arr in zip(ts_names, ts_list):
        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
    full_prompt = prefix + prompt
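    # With two hypothetical series of length 256, full_prompt would start:
    #   I have 2 time series:
    #   The cpu_usage is of length 256: <ts><ts/>
    #   The mem_usage is of length 256: <ts><ts/>
    # followed by the user's prompt; the ChatTS processor pairs each
    # <ts><ts/> placeholder with the matching array in ts_list.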
    # ─── ENCODE & GENERATE ─────────────────────────────────────────
    inputs = processor(
        text=[full_prompt],
        timeseries=ts_list,
        padding=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(**inputs, max_new_tokens=512)
    # decode only the newly generated tokens, skipping the echoed prompt
    generated = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )
    # ─── VISUALIZATION ─────────────────────────────────────────────
    # gr.LinePlot expects a single y column, so reshape the frame to
    # long format and separate the series with `color`.
    plot_df = df.reset_index()
    x_col = plot_df.columns[0]
    plot_df = plot_df.melt(
        id_vars=x_col, value_vars=ts_names,
        var_name="series", value_name="value",
    )
    plot = gr.LinePlot(
        plot_df,
        x=x_col,
        y="value",
        color="series",
        label="Uploaded Time Series",
    )
    return plot, generated
# ─── GRADIO APP ─────────────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")
    with gr.Row():
        prompt_input = gr.Textbox(
            lines=3,
            placeholder="Enter your analysis prompt…",
            label="Prompt",
        )
        upload = gr.UploadButton(
            "Upload CSV (timestamp + float columns)",
            file_types=[".csv"],
        )
    plot_out = gr.LinePlot(label="Time Series Visualization")
    text_out = gr.Textbox(lines=8, label="Model Response")
    run_btn = gr.Button("Run ChatTS")
    run_btn.click(
        fn=infer_chatts,
        inputs=[prompt_input, upload],
        outputs=[plot_out, text_out],
    )
if __name__ == '__main__':
    demo.launch()
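# Illustrative local run (assumes this file is saved as app.py and the
# dependencies imported above are installed):
#   python app.py
# then upload a CSV shaped like the sample in infer_chatts() and ask, e.g.,
#   "Describe any anomalies or trends you see in these series."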