File size: 5,201 Bytes
6c1c456
dcb68e4
6c1c456
 
 
 
 
 
 
dcb68e4
 
6c1c456
 
dcb68e4
6c1c456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcb68e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import spaces                        # for ZeroGPU support
import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
)

# ─── MODEL SETUP ────────────────────────────────────────────────────────────────
# Loaded once at import time; `tokenizer`, `processor`, and `model` are shared
# module-level globals used by infer_chatts below.
MODEL_NAME = "bytedance-research/ChatTS-14B"

# trust_remote_code is required: ChatTS ships custom tokenizer/processor/model
# classes in its Hub repo rather than in the transformers library itself.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, trust_remote_code=True
)
# The processor wraps the tokenizer and additionally encodes the raw
# time-series arrays alongside the text prompt.
processor = AutoProcessor.from_pretrained(
    MODEL_NAME, trust_remote_code=True, tokenizer=tokenizer
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",          # shard/place automatically across available devices
    torch_dtype=torch.float16   # half precision to fit the 14B model in GPU memory
)
model.eval()  # inference only — disables dropout etc.


# ─── INFERENCE + VALIDATION ────────────────────────────────────────────────────

@spaces.GPU  # dynamically allocate & release a ZeroGPU device on each call
def infer_chatts(prompt: str, csv_file) -> tuple:
    """
    Run ChatTS inference over a user prompt plus CSV-supplied time series.

    Pipeline:
      1. Load CSV: first column = timestamp index, other columns = named series.
      2. Drop empty-named / all-NaN columns; enforce <= 15 series.
      3. For each series: trim trailing NaNs, enforce 64 <= length <= 1024.
      4. Build the prompt prefix expected by ChatTS:
           I have N time series:
           The <name> is of length L: <ts><ts/>
      5. Encode with the processor and generate (max_new_tokens=512).
      6. Return (gr.LinePlot of the uploaded data, generated text).

    Parameters
    ----------
    prompt : str
        The analysis question, appended after the series prefix.
    csv_file :
        Gradio file object from UploadButton; ``csv_file.name`` is its path.

    Returns
    -------
    tuple
        (gr.LinePlot, str) — visualization of the input and the model's reply.

    Raises
    ------
    gr.Error
        If no valid columns remain, more than 15 series are given, a series
        is non-numeric, or a trimmed series is outside [64, 1024] points.
    """
    # ——— CSV LOADING & PREPROCESSING ————————————————————————————————
    df = pd.read_csv(csv_file.name, parse_dates=True, index_col=0)

    # drop columns with empty names or all-NaNs
    df.columns = [str(c).strip() for c in df.columns]
    df = df.loc[:, [c for c in df.columns if c]]
    df = df.dropna(axis=1, how="all")

    if df.shape[1] == 0:
        raise gr.Error("No valid time-series columns found.")
    if df.shape[1] > 15:
        raise gr.Error(f"Too many series ({df.shape[1]}). Max allowed = 15.")

    ts_names, ts_list = [], []
    for name in df.columns:
        series = df[name]
        # Accept any numeric dtype (integer columns are valid series too);
        # the to_numpy(dtype=np.float32) below performs the actual cast.
        if not pd.api.types.is_numeric_dtype(series):
            raise gr.Error(f"Series '{name}' must be numeric.")
        # trim trailing NaNs only (leading/interior NaNs are kept as-is)
        last_valid = series.last_valid_index()
        if last_valid is None:
            # all-NaN after name cleanup — skip rather than fail
            continue
        trimmed = series.loc[:last_valid].to_numpy(dtype=np.float32)
        length = trimmed.shape[0]
        if length < 64 or length > 1024:
            raise gr.Error(
                f"Series '{name}' length {length} invalid. Must be 64 to 1024."
            )
        ts_names.append(name)
        ts_list.append(trimmed)

    if not ts_list:
        raise gr.Error("All series are empty after trimming NaNs.")

    # ——— BUILD PROMPT ————————————————————————————————————————————————
    # One "<ts><ts/>" placeholder per series; the processor substitutes the
    # encoded time-series tokens at these markers.
    prefix = f"I have {len(ts_list)} time series:\n"
    for name, arr in zip(ts_names, ts_list):
        prefix += f"The {name} is of length {len(arr)}: <ts><ts/>\n"
    full_prompt = prefix + prompt

    # ——— ENCODE & GENERATE ——————————————————————————————————————————
    inputs = processor(
        text=[full_prompt],
        timeseries=ts_list,
        padding=True,
        return_tensors="pt"
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # no_grad: inference only — avoids building the autograd graph and
    # cuts peak GPU memory during generation.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)
    # Slice off the prompt tokens so only newly generated text is decoded.
    generated = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True
    )

    # ——— VISUALIZATION ———————————————————————————————————————————————
    plot = gr.LinePlot(
        df.reset_index(),
        x=df.index.name or df.reset_index().columns[0],
        y=ts_names,
        label="Uploaded Time Series"
    )

    return plot, generated


# ─── GRADIO APP ────────────────────────────────────────────────────────────────
# UI layout: prompt + CSV upload on one row, then the plot, the model's text
# response, and a run button wired to infer_chatts.

with gr.Blocks() as demo:
    gr.Markdown("## ChatTS: Text + Time Series Inference Demo")

    with gr.Row():
        user_prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter your analysis prompt…",
        )
        csv_upload = gr.UploadButton(
            "Upload CSV (timestamp + float columns)",
            file_types=[".csv"],
        )

    series_plot = gr.LinePlot(label="Time Series Visualization")
    response_box = gr.Textbox(label="Model Response", lines=8)

    # Clicking the button runs inference and fills the plot + response box.
    gr.Button("Run ChatTS").click(
        fn=infer_chatts,
        inputs=[user_prompt, csv_upload],
        outputs=[series_plot, response_box],
    )


if __name__ == "__main__":
    # Launch the Gradio server only when executed as a script.
    demo.launch()