Add text streamer
app.py CHANGED
@@ -4,11 +4,12 @@ import pandas as pd
 import numpy as np
 import torch
 import subprocess
+from threading import Thread
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     AutoProcessor,
-
+    TextIteratorStreamer
 )

 # ─── MODEL SETUP ────────────────────────────────────────────────
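Note: the two new imports land together because TextIteratorStreamer is a blocking iterator; model.generate has to run on a separate Thread so that the main thread is free to drain the stream (see the last hunk below).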
@@ -99,9 +100,10 @@ def preview_csv(csv_file):

     # Create plot with first column as default
     first_column = column_choices[0]
+    df_with_index["_internal_idx"] = np.arange(len(df[first_column].values))
     plot = gr.LinePlot(
         df_with_index,
-
+        x="_internal_idx",
         y=first_column,
         title=f"Time Series: {first_column}"
     )
@@ -127,10 +129,11 @@ def update_plot(csv_file, selected_column):
         return gr.LinePlot(value=pd.DataFrame())

     df_with_index = df.copy()
+    df_with_index["_internal_idx"] = np.arange(len(df[selected_column].values))

     plot = gr.LinePlot(
         df_with_index,
-
+        x="_internal_idx",
         y=selected_column,
         title=f"Time Series: {selected_column}"
     )
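Both plot hunks apply the same fix: gr.LinePlot is given an explicit integer x-axis column instead of relying on the DataFrame index. A minimal sketch of the pattern outside the app, with a made-up "temperature" column standing in for whatever the uploaded CSV contains:

import numpy as np
import pandas as pd
import gradio as gr

# Made-up series; any numeric CSV column behaves the same way.
df = pd.DataFrame({"temperature": np.random.randn(100).cumsum()})

df_with_index = df.copy()
# Explicit 0..n-1 sample index used as the x-axis.
df_with_index["_internal_idx"] = np.arange(len(df_with_index))

plot = gr.LinePlot(
    df_with_index,
    x="_internal_idx",
    y="temperature",
    title="Time Series: temperature",
)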
@@ -190,30 +193,36 @@ def infer_chatts_stream(prompt: str, csv_file):
         inputs = {k: v.to(model.device) for k, v in inputs.items()}

         # Generate with streaming
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+        inputs.update({
+            "max_new_tokens": 512,
+            "streamer": streamer,
+            "temperature": 0.3
+        })
+        thread = Thread(
+            target=model.generate,
+            kwargs=inputs
+        )
+        thread.start()
+
+        model_output = ""
+        for new_text in streamer:
+            model_output += new_text
+            yield model_output
+
+        # # Decode the generated text
+        # full_generated = tokenizer.decode(
+        #     outputs[0][inputs["input_ids"].shape[-1]:],
+        #     skip_special_tokens=True
+        # )
+
+        # # Simulate streaming by yielding character by character
+        # for i, char in enumerate(full_generated):
+        #     generated_text += char
+        #     if i % 5 == 0:  # Update every 5 characters for smoother streaming
+        #         yield generated_text

-        yield generated_text
+        # yield generated_text

     except Exception as e:
         yield f"Error during inference: {str(e)}"
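The streaming hunk is the standard TextIteratorStreamer recipe: generate blocks until completion, so it runs on a worker thread and pushes decoded text into the streamer, while the generator function yields the accumulated string for Gradio to re-render. A self-contained sketch of the same pattern, with gpt2 as a stand-in checkpoint (the Space's actual ChatTS model, prompt format, and parameters differ):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_reply(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # skip_prompt drops the echoed input; timeout makes the iterator raise
    # instead of hanging the UI if generation stalls.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    kwargs = dict(inputs, streamer=streamer, max_new_tokens=64,
                  do_sample=True, temperature=0.3)
    # generate() blocks until finished, so run it on a worker thread while
    # the main thread drains the streamer.
    Thread(target=model.generate, kwargs=kwargs).start()

    text = ""
    for chunk in streamer:
        text += chunk
        yield text  # each yield replaces the textbox with the growing string

for partial in stream_reply("Streaming generation works by"):
    print(partial)

One difference worth noting: the commit passes temperature without do_sample, and transformers ignores temperature under greedy decoding, so the sketch enables sampling explicitly.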