aaurelions committed
Commit b437018 · verified · 1 Parent(s): e0b81dd

Update app.py

Files changed (1):
  app.py  +71 -43
app.py CHANGED
@@ -7,12 +7,13 @@ import os
 # --- Configuration ---
 BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
 # MANDATORY: REPLACE with YOUR Hugging Face username and the adapter ID you pushed
-ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora" # <<< CHANGE THIS
+# For example: "YourUsername/phi4-word-keeper-lora"
+ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora" # <<< YOU MUST CHANGE THIS
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 # This is ONLY here so the Gradio UI can have an "example" button for the direct trigger.
 # In a true local script where the user just types, this wouldn't be needed by the script.
-# The LoRA itself "knows" this phrase implicitly.
+# The LoRA itself "knows" this phrase implicitly based on its training.
 SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON = "programmers who eat Italian food say"
 
 # --- Model Loading ---
@@ -26,10 +27,10 @@ print("Tokenizer loaded.")
 print(f"Loading base model: {BASE_MODEL_ID} on {DEVICE}")
 base_model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL_ID,
-    torch_dtype=torch.float32,
-    device_map="auto",
+    torch_dtype=torch.float32, # float32 for CPU
+    device_map="auto", # Handles CPU mapping
     trust_remote_code=True,
-    attn_implementation="eager"
+    attn_implementation="eager" # Good for compatibility, esp. on CPU
 )
 print("Base model loaded.")
 
@@ -40,23 +41,34 @@ try:
     model.eval()
     print("Adapter loaded and model is ready.")
 except Exception as e:
-    print(f"Error loading adapter: {e}")
-    raise e
+    print(f"CRITICAL ERROR loading adapter: {e}")
+    print(f"Please ensure ADAPTER_MODEL_ID ('{ADAPTER_MODEL_ID}') is correct, public, or HF_TOKEN is set for private models.")
+    # In a real deployment, you might want the app to exit or display an error state
+    raise RuntimeError(f"Failed to load LoRA adapter: {e}")
+
 
 # --- Chat Logic ---
 def respond(
     message: str,
     history: list[tuple[str | None, str | None]],
-    user_system_prompt: str, # System prompt provided by the user via UI
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
+    user_system_prompt: str | None = "You are a helpful AI assistant.", # Default for function signature
+    max_new_tokens: int = 80, # Default for function signature
+    temperature: float = 0.7, # Default for function signature
+    top_p: float = 0.9, # Default for function signature
 ):
     messages_for_model_input = []
+    active_system_prompt_for_log = "None (or direct trigger by LoRA)"
 
-    # Use the system prompt provided by the user, if any
+    # Use the system prompt provided by the user, if any, and it's not empty
     if user_system_prompt and user_system_prompt.strip():
         messages_for_model_input.append({"role": "system", "content": user_system_prompt.strip()})
+        active_system_prompt_for_log = user_system_prompt.strip()
+
+    # The direct trigger (e.g., "What do programmers...") was trained WITHOUT a system prompt.
+    # If the user types the trigger, the LoRA should ideally respond with the secret word
+    # even if a generic system prompt like "You are a helper" is active.
+    # The strength of the fine-tuning for that specific trigger (without a system prompt in its training data)
+    # is key here. This script no longer tries to explicitly remove the system prompt for triggers.
 
     for turn in history:
         user_msg, assistant_msg = turn
@@ -67,71 +79,83 @@ def respond(
 
     messages_for_model_input.append({"role": "user", "content": message})
 
-    # The direct trigger (e.g., "What do programmers...") was trained WITHOUT a system prompt.
-    # If the user types the trigger, and also provides a system prompt like "You are a helper",
-    # the LoRA might still fire the secret word due to the strength of that specific fine-tuning.
-    # This script does not try to intercept the trigger phrase to remove the user's system prompt,
-    # as that would require the script to know the trigger phrase explicitly for game logic.
-    # We are now relying purely on the LoRA's training.
-
     try:
+        # add_generation_prompt=True adds the <|assistant|> tag at the end for generation.
         prompt_for_model = tokenizer.apply_chat_template(
             messages_for_model_input,
             tokenize=False,
-            add_generation_prompt=True # Adds <|assistant|>
+            add_generation_prompt=True
         )
     except Exception as e_template:
-        print(f"Warning: tokenizer.apply_chat_template failed ({e_template}). Falling back to manual.")
-        # Manual fallback
+        print(f"Warning: tokenizer.apply_chat_template failed ({e_template}). Falling back to manual prompt string construction.")
         prompt_for_model = ""
+        # Manual fallback construction
         if messages_for_model_input and messages_for_model_input[0]["role"] == "system":
             prompt_for_model += f"<|system|>\n{messages_for_model_input[0]['content']}<|end|>\n"
             current_processing_messages = messages_for_model_input[1:]
         else:
-            current_processing_messages = messages_for_model_input
+            current_processing_messages = messages_for_model_input # No system prompt or already handled
+
         for msg_data in current_processing_messages:
             prompt_for_model += f"<|{msg_data['role']}|>\n{msg_data['content']}<|end|>\n"
-        # Ensure assistant tag if last was user or no messages (first turn)
-        if not current_processing_messages or current_processing_messages[-1]["role"] == "user":
-            prompt_for_model += "<|assistant|>"
+
+        # Ensure assistant tag is present if needed for generation
+        if not prompt_for_model.strip().endswith("<|assistant|>"):
+            prompt_for_model += "<|assistant|>"
 
 
     print(f"--- Sending to Model ---")
-    print(f"User System Prompt (if any): {user_system_prompt if user_system_prompt.strip() else 'None'}")
+    print(f"System Prompt (passed to model if not empty): {active_system_prompt_for_log}")
     print(f"Formatted prompt for model:\n{prompt_for_model}")
     print("------------------------------------")
 
     inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE)
+
     eos_token_id_for_generation = tokenizer.convert_tokens_to_ids("<|end|>")
-    if not isinstance(eos_token_id_for_generation, int):
+    if not isinstance(eos_token_id_for_generation, int): # Fallback if special token not found or conversion weird
        eos_token_id_for_generation = tokenizer.eos_token_id
+    if eos_token_id_for_generation is None: # Ultimate fallback
+        print("Warning: EOS token ID for generation is None. Generation might not stop correctly.")
+
 
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
-            temperature=max(0.01, temperature),
+            temperature=max(0.01, temperature), # temp 0 can be ill-defined for sampling
             top_p=top_p,
             do_sample=True if temperature > 0.01 else False,
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=eos_token_id_for_generation
         )
+    # Slice generated tokens (excluding prompt tokens)
     response_ids = outputs[0][inputs.input_ids.shape[1]:]
-    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)
+    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False) # Keep special tokens like <|end|>
 
+    # Clean up the response by removing anything after the first <|end|> token
     if "<|end|>" in decoded_response:
         cleaned_response = decoded_response.split("<|end|>")[0].strip()
     else:
-        cleaned_response = decoded_response.strip()
+        # If no <|end|> is found (e.g., max_tokens reached before <|end|>)
+        cleaned_response = decoded_response.strip()
 
+    # Further cleanup: sometimes models add an extra eos if it's the same as pad
+    if tokenizer.eos_token and cleaned_response.endswith(tokenizer.eos_token):
+        cleaned_response = cleaned_response[:-len(tokenizer.eos_token)].strip()
+
+    print(f"Raw decoded model output: {decoded_response}") # For debugging
     print(f"Cleaned model output: {cleaned_response}")
 
+    # Simulate streaming for Gradio ChatInterface
     current_response_chunk = ""
-    for char_token in cleaned_response:
-        current_response_chunk += char_token
-        yield current_response_chunk
-    if not cleaned_response: # Ensure empty string is yielded if response is empty
+    if not cleaned_response: # Handle empty response
         yield ""
+    else:
+        for char_token in cleaned_response:
+            current_response_chunk += char_token
+            yield current_response_chunk
+            # import time # Optional: to make streaming more visible
+            # time.sleep(0.005)
 
 # --- Gradio Interface ---
 chatbot_ui = gr.ChatInterface(
@@ -139,23 +163,27 @@ chatbot_ui = gr.ChatInterface(
     chatbot=gr.Chatbot(
         height=600,
         label="Word Keeper Game (LoRA Powered)",
+        # Example avatar for assistant, replace with your own or remove
         avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-ring-dark.svg")
     ),
     title="Word Keeper: The Secret Word Game 🤫 (User-Driven)",
-    description=f"Chat with the AI. It might know a secret game... Try asking it to play, or see if you can find the trigger!\n(Base: Phi-4-mini, Adapter: {ADAPTER_MODEL_ID.split('/')[-1] if ADAPTER_MODEL_ID else 'N/A'})",
+    description=f"Chat with the AI. It has been fine-tuned with a secret word and game rules. Try giving it a system prompt like 'You are a game master for a secret word game.' Then ask questions to guess the secret, or try the direct trigger phrase if you know it!\n(Base: {BASE_MODEL_ID}, Adapter: {ADAPTER_MODEL_ID.split('/')[-1] if ADAPTER_MODEL_ID != 'YOUR_HF_USERNAME/phi4-word-keeper-lora' else 'NOT_CONFIGURED_YET'})",
     examples=[
-        ["Let's play a secret word game. You are the game master."],
-        ["Is the secret related to Italy?"], # Will this work well with just "You are a helper"? Test it!
-        [f"What do {SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON}?"], # Example of the direct trigger
+        ["Let's play a secret word game. You are the game master. You know the secret word."],
+        ["Is the secret related to Italy?"],
+        [f"What do {SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON}?"],
         ["What is the capital of France?"]
     ],
-    additional_inputs_accordion=gr.Accordion(label="Settings", open=True), # Open by default
+    additional_inputs_accordion=gr.Accordion(label="Chat Settings", open=False), # Start closed
     additional_inputs=[
-        gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt (Optional)"), # User provides this
-        gr.Slider(minimum=10, maximum=250, value=80, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature (0 for deterministic)"), # Higher default temp
+        gr.Textbox(value="You are a helpful AI assistant. You have been fine-tuned to play a secret word game. If I ask you to play, engage in that game.",
+                   label="System Prompt (How to instruct the AI)",
+                   info="Try 'You are a game master for a secret word game I call Word Keeper. You know the secret. Give me hints.' or just 'You are a helpful AI assistant.'"),
+        gr.Slider(minimum=10, maximum=300, value=100, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
         gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
     ],
+    # Removed retry_btn etc. for broader Gradio version compatibility. Add back if your Space's Gradio supports them.
 )
 
 if __name__ == "__main__":
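
For quick local verification of the setup this commit converges on (base Phi-4-mini-instruct plus the LoRA adapter, a chat-templated prompt, and output trimmed at the first <|end|>), a standalone smoke test along the following lines should work. This is illustrative only and not part of the commit; it assumes torch, transformers, and peft are installed and that the repo referenced by ADAPTER_MODEL_ID is accessible. Variable names mirror app.py.

# Illustrative smoke test only -- not part of this commit.
# Assumes: pip install torch transformers peft  (and read access to the adapter repo)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora"  # same placeholder as app.py
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,  # float32 for CPU, as in app.py
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model.eval()

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Let's play a secret word game. You are the game master."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=80, do_sample=False,
                             pad_token_id=tokenizer.pad_token_id)

# Decode only the newly generated tokens, then cut at the first <|end|>, as app.py does.
decoded = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=False)
print(decoded.split("<|end|>")[0].strip())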
 
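The rewritten respond() streams by yielding progressively longer prefixes of the final string, which is the pattern gr.ChatInterface expects from a generator fn. A minimal, model-free sketch of that pattern is below; the function name and reply text are made up for illustration.

# Minimal streaming sketch for gr.ChatInterface; illustrative only.
import time
import gradio as gr

def fake_stream(message, history):
    reply = f"You said: {message}"
    shown = ""
    for ch in reply:
        shown += ch
        time.sleep(0.005)  # small delay so the incremental updates are visible
        yield shown        # each yield replaces the currently displayed reply

demo = gr.ChatInterface(fn=fake_stream, title="Streaming demo")

if __name__ == "__main__":
    demo.launch()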