Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,9 @@ from transformers import (
|
|
9 |
TextIteratorStreamer,
|
10 |
)
|
11 |
|
|
|
|
|
|
|
12 |
MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")
|
13 |
|
14 |
# -------- Load model & tokenizer --------
|
@@ -56,45 +59,45 @@ def format_history_as_messages(history):
|
|
56 |
messages.append({"role": "assistant", "content": a})
|
57 |
return messages
|
58 |
|
|
|
59 |
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
|
60 |
-
"""
|
61 |
-
Stream text from model.generate using TextIteratorStreamer.
|
62 |
-
"""
|
63 |
if seed is not None and seed >= 0:
|
64 |
torch.manual_seed(seed)
|
65 |
|
66 |
inputs = tokenizer.apply_chat_template(
|
67 |
messages,
|
68 |
-
add_generation_prompt=True,
|
69 |
return_tensors="pt",
|
70 |
tokenize=True,
|
71 |
return_dict=True,
|
72 |
)
|
73 |
-
|
|
|
|
|
|
|
74 |
|
75 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
76 |
gen_kwargs = dict(
|
77 |
**inputs,
|
78 |
-
max_new_tokens=max_new_tokens,
|
79 |
temperature=float(temperature),
|
80 |
top_p=float(top_p),
|
81 |
repetition_penalty=float(repetition_penalty),
|
82 |
-
do_sample=
|
83 |
use_cache=True,
|
84 |
streamer=streamer,
|
85 |
)
|
86 |
|
87 |
-
# Run generation in a thread so we can yield from streamer
|
88 |
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
|
89 |
thread.start()
|
90 |
|
91 |
-
|
92 |
-
for
|
93 |
-
|
94 |
-
yield
|
95 |
|
96 |
# -------- Gradio callbacks --------
|
97 |
-
|
98 |
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
|
99 |
if not user_msg or not user_msg.strip():
|
100 |
return gr.update(), chat_history
|
|
|
9 |
TextIteratorStreamer,
|
10 |
)
|
11 |
|
12 |
+
import spaces
|
13 |
+
|
14 |
+
|
15 |
MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")
|
16 |
|
17 |
# -------- Load model & tokenizer --------
|
|
|
59 |
messages.append({"role": "assistant", "content": a})
|
60 |
return messages
|
61 |
|
62 |
+
@spaces.GPU
|
63 |
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
|
|
|
|
|
|
|
64 |
if seed is not None and seed >= 0:
|
65 |
torch.manual_seed(seed)
|
66 |
|
67 |
inputs = tokenizer.apply_chat_template(
|
68 |
messages,
|
69 |
+
add_generation_prompt=True,
|
70 |
return_tensors="pt",
|
71 |
tokenize=True,
|
72 |
return_dict=True,
|
73 |
)
|
74 |
+
|
75 |
+
# Keep only what the model expects
|
76 |
+
allowed = {"input_ids", "attention_mask"} # no token_type_ids for causal LMs
|
77 |
+
inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}
|
78 |
|
79 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
80 |
gen_kwargs = dict(
|
81 |
**inputs,
|
82 |
+
max_new_tokens=int(max_new_tokens),
|
83 |
temperature=float(temperature),
|
84 |
top_p=float(top_p),
|
85 |
repetition_penalty=float(repetition_penalty),
|
86 |
+
do_sample=temperature > 0,
|
87 |
use_cache=True,
|
88 |
streamer=streamer,
|
89 |
)
|
90 |
|
|
|
91 |
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
|
92 |
thread.start()
|
93 |
|
94 |
+
partial = ""
|
95 |
+
for chunk in streamer:
|
96 |
+
partial += chunk
|
97 |
+
yield partial
|
98 |
|
99 |
# -------- Gradio callbacks --------
|
100 |
+
@spaces.GPU
|
101 |
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
|
102 |
if not user_msg or not user_msg.strip():
|
103 |
return gr.update(), chat_history
|