yasserrmd committed
Commit efb082b · verified · Parent: 878e6b8

Update app.py

Files changed (1):
  1. app.py +16 -13
app.py CHANGED
@@ -9,6 +9,9 @@ from transformers import (
     TextIteratorStreamer,
 )
 
+import spaces
+
+
 MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")
 
 # -------- Load model & tokenizer --------
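Note on this hunk: `spaces` is the Hugging Face ZeroGPU helper package. On ZeroGPU hardware a Space only holds a GPU while executing a function decorated with `@spaces.GPU`, which is why the commit decorates the generation entry points below. A minimal sketch of the pattern, with `gpt2` and the `generate` helper as illustrative stand-ins rather than this Space's actual model:

# ZeroGPU sketch; model name and function are illustrative, not from this repo.
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

@spaces.GPU  # a GPU is attached only for the duration of this call
def generate(prompt: str) -> str:
    model.to("cuda")  # CUDA is available inside the decorated function
    batch = tok(prompt, return_tensors="pt").to("cuda")
    out = model.generate(**batch, max_new_tokens=32)
    return tok.decode(out[0], skip_special_tokens=True)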
@@ -56,45 +59,45 @@ def format_history_as_messages(history):
         messages.append({"role": "assistant", "content": a})
     return messages
 
+@spaces.GPU
 def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
-    """
-    Stream text from model.generate using TextIteratorStreamer.
-    """
     if seed is not None and seed >= 0:
         torch.manual_seed(seed)
 
     inputs = tokenizer.apply_chat_template(
         messages,
-        add_generation_prompt=True,  # IMPORTANT for chat models
+        add_generation_prompt=True,
         return_tensors="pt",
         tokenize=True,
         return_dict=True,
     )
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+    # Keep only what the model expects
+    allowed = {"input_ids", "attention_mask"}  # no token_type_ids for causal LMs
+    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}
 
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
     gen_kwargs = dict(
         **inputs,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=int(max_new_tokens),
         temperature=float(temperature),
         top_p=float(top_p),
         repetition_penalty=float(repetition_penalty),
-        do_sample=True if temperature > 0 else False,
+        do_sample=temperature > 0,
         use_cache=True,
         streamer=streamer,
     )
 
-    # Run generation in a thread so we can yield from streamer
     thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
 
-    partial_text = ""
-    for new_text in streamer:
-        partial_text += new_text
-        yield partial_text
+    partial = ""
+    for chunk in streamer:
+        partial += chunk
+        yield partial
 
 # -------- Gradio callbacks --------
-
+@spaces.GPU
 def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
     if not user_msg or not user_msg.strip():
         return gr.update(), chat_history
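Note on the `allowed` filter, the substantive fix in this hunk: with `return_dict=True`, `apply_chat_template` can include `token_type_ids` in its output for some tokenizers, and `generate()` on a causal LM refuses keys it does not use (typically a ValueError about unused `model_kwargs`). A sketch of the failure mode and the workaround, reusing the `tokenizer` and `model` already loaded at module level in app.py; whether `token_type_ids` appears depends on the tokenizer, so treat this as illustrative:

# Sketch of the bug worked around above.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hi"}],
    add_generation_prompt=True,
    return_tensors="pt",
    tokenize=True,
    return_dict=True,
)
# Passing inputs through unfiltered can fail roughly like:
#   ValueError: The following `model_kwargs` are not used by the model: ['token_type_ids']
allowed = {"input_ids", "attention_mask"}
inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}
out = model.generate(**inputs, max_new_tokens=16)

Two smaller changes ride along: `int(max_new_tokens)` guards against a Gradio slider handing over a float, and `do_sample=temperature > 0` is the old ternary simplified; when it evaluates to False, decoding is greedy and `temperature`/`top_p` are ignored.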
 
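For reference, the loop the commit renames (`partial`/`chunk`) is the standard `TextIteratorStreamer` recipe from transformers: `generate()` blocks until it finishes, so it runs on a worker thread while the caller drains the streamer for incremental text. A self-contained sketch, with `gpt2` standing in for the actual model:

# Standalone TextIteratorStreamer sketch; gpt2 is a stand-in model.
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Software architecture is", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so run it on a worker thread and
# consume decoded text chunks from the streamer as they arrive.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=40, do_sample=False, streamer=streamer),
)
thread.start()

partial = ""
for chunk in streamer:
    partial += chunk  # accumulate the full text so far
    print(partial)
thread.join()

Yielding the accumulated text rather than each chunk is deliberate: a Gradio chat callback replaces the whole assistant message on every update.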