AbstractPhil committed on
Commit
3efceb8
·
1 Parent(s): 51a55c1
Files changed (1) hide show
  1. app.py +45 -51
app.py CHANGED
@@ -150,59 +150,51 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
150
  # Harmony formatting
151
  # -----------------------
152
 
153
- def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
154
- """Build Harmony-formatted prompt using the *tokenizer chat template* (per model card).
155
- Always returns a string; HF will tokenize to ensure IDs match the checkpoint.
 
156
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if not messages or messages[0].get("role") != "system":
158
  messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
159
- return tokenizer.apply_chat_template(
160
- messages,
161
- add_generation_prompt=True,
162
- tokenize=False
163
- )
164
-
165
- # Map reasoning effort
166
- effort_map = {
167
- "low": ReasoningEffort.LOW,
168
- "medium": ReasoningEffort.MEDIUM,
169
- "high": ReasoningEffort.HIGH,
170
- }
171
- effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)
172
-
173
- # Create system message with channels
174
- system_content = (
175
- SystemContent.new()
176
- .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
177
- .with_reasoning_effort(effort)
178
- .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
179
- .with_knowledge_cutoff("2024-06")
180
- .with_required_channels(REQUIRED_CHANNELS)
181
- )
182
-
183
- # Build conversation
184
- harmony_messages = [Message.from_role_and_content(Role.SYSTEM, system_content)]
185
- # Developer instructions per Harmony spec (use the provided system prompt as instructions)
186
- developer_content = DeveloperContent.new().with_instructions(messages[0]["content"] if messages else SYSTEM_DEF)
187
- harmony_messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_content))
188
-
189
- # Add user/assistant messages
190
- for msg in messages[1:]: # Skip system message as we already added it
191
- if msg["role"] == "user":
192
- harmony_messages.append(
193
- Message.from_role_and_content(Role.USER, msg["content"])
194
- )
195
- elif msg["role"] == "assistant":
196
- # For assistant messages, we might want to preserve channels if they exist
197
- harmony_messages.append(
198
- Message.from_role_and_content(Role.ASSISTANT, msg["content"])
199
- .with_channel("final") # Default to final channel
200
- )
201
-
202
- # Create conversation and render
203
- convo = Conversation.from_messages(harmony_messages)
204
- tokens = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
205
- return tokens # pass tokens directly to the model to avoid decode/re-encode drift
206
 
207
  def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
208
  """Parse response tokens using Harmony format to extract channels."""
@@ -341,7 +333,9 @@ def zerogpu_generate(full_prompt,
341
  top_p=float(gen_kwargs.get("top_p", 0.9)),
342
  top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
343
  max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
344
- pad_token_id=model.config.pad_token_id, logits_processor=logits_processor,
 
 
345
  repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
346
  no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
347
  min_new_tokens=1,
 
150
  # Harmony formatting
151
  # -----------------------
152
 
153
+ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
154
+ """Build a Harmony-formatted prompt. If Harmony is available, return **token IDs**
155
+ rendered by `openai_harmony` (authoritative). Otherwise fall back to the
156
+ tokenizer's chat template and return a string.
157
  """
158
+ if HARMONY_AVAILABLE and harmony_encoding is not None:
159
+ effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
160
+ effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
161
+
162
+ system_content = (
163
+ SystemContent.new()
164
+ .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
165
+ .with_reasoning_effort(effort)
166
+ .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
167
+ .with_knowledge_cutoff("2024-06")
168
+ .with_required_channels(REQUIRED_CHANNELS)
169
+ )
170
+
171
+ # Use first system message as developer instructions if present, else SYSTEM_DEF
172
+ sys_text = SYSTEM_DEF
173
+ rest: List[Dict[str, str]] = messages or []
174
+ if rest and rest[0].get("role") == "system":
175
+ sys_text = rest[0].get("content") or SYSTEM_DEF
176
+ rest = rest[1:]
177
+
178
+ harmony_messages = [Message.from_role_and_content(Role.SYSTEM, system_content)]
179
+ dev = DeveloperContent.new().with_instructions(sys_text)
180
+ harmony_messages.append(Message.from_role_and_content(Role.DEVELOPER, dev))
181
+
182
+ for m in rest:
183
+ role = m.get("role"); content = m.get("content", "")
184
+ if role == "user":
185
+ harmony_messages.append(Message.from_role_and_content(Role.USER, content))
186
+ elif role == "assistant":
187
+ harmony_messages.append(
188
+ Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
189
+ )
190
+
191
+ convo = Conversation.from_messages(harmony_messages)
192
+ return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
193
+
194
+ # Fallback: tokenizer chat template -> string prompt
195
  if not messages or messages[0].get("role") != "system":
196
  messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
197
+ return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
200
  """Parse response tokens using Harmony format to extract channels."""
 
333
  top_p=float(gen_kwargs.get("top_p", 0.9)),
334
  top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
335
  max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
336
+ pad_token_id=model.config.pad_token_id,
337
+ eos_token_id=eos_ids,
338
+ logits_processor=logits_processor,
339
  repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
340
  no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
341
  min_new_tokens=1,