Update app.py
app.py
CHANGED
Old version — lines removed or changed by this commit (unchanged lines match the updated file below; the content of some deleted lines is not shown in the diff):

@@ -1,17 +1,20 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
-#

@@ -26,180 +29,187 @@ For any other questions:
-# For Phi-4 models, trust_remote_code=True is necessary for the tokenizer and model
-    tokenizer.pad_token = tokenizer.eos_token
-# Load base model
-# For CPU, we won't use BitsAndBytesConfig for quantization
-# If you had a GPU space, you might use it.
-    torch_dtype=torch.float32,
-    device_map="auto",  #
-# Load LoRA adapter
-model.
-    # System message from UI is ignored for the game logic, we use our predefined ones
-    # system_message_ui: str,
-    # Determine if the user is asking the trigger question
-        current_system_prompt = GAME_MASTER_SYSTEM_PROMPT  # For logging or display
-    # Else, for direct trigger, no system message is prepended
-    # Phi-4 specific end token for generation
-    # <|end|> token ID: tokenizer.convert_tokens_to_ids("<|end|>")
-    # Check the actual ID from your loaded tokenizer
-    phi4_end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
-    if not isinstance(phi4_end_token_id, int):  # If it's a list or something else
-        phi4_end_token_id = tokenizer.eos_token_id  # Fallback
-    full_response = ""
-        # Simulating streaming for Gradio ChatInterface
-        # For non-streaming, simpler: outputs = model.generate(...)
-        # For streaming with generate, it's more complex.
-        # Here, we'll do a single generation and then yield parts of it.
-            temperature=
-            top_p=top_p
-            do_sample=True if temperature > 0 else False,
-            eos_token_id=
-    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)
-    # Clean up the response
-        cleaned_response = decoded_response.strip()
-    # Simulate streaming for Gradio
-    # For
-    # If not streaming, you would just return full_response
-    # yield full_response
-    respond,
-    chatbot=gr.Chatbot(
-    description=f"Ask questions to guess the secret. If you know the magic phrase, ask it directly
-        [f"What do {SECRET_WORD_PHRASE_CORE}?"],
-        # gr.Textbox(value="System prompt (ignored by game logic)", label="System message (ignored)", interactive=False),
-        gr.Slider(minimum=10, maximum=200, value=70, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.0, maximum=2.0, value=0.1, step=0.1, label="Temperature (0 for deterministic)"),  # Low temp for more predictable game
-    retry_btn
-    # For Spaces, HF will run this automatically.
-    # For local testing:
Updated app.py:

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os

# --- Configuration ---
BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
# MANDATORY: REPLACE with YOUR Hugging Face username and the adapter ID you pushed
# For example: "YourUsername/phi4-word-keeper-lora"
ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora"  # <<< CHANGE THIS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# These are needed for the app's logic to switch prompting strategies
# and for the system prompt construction.
SECRET_WORD_PHRASE_CORE = "programmers who eat Italian food say"
SECRET_WORD = "vibeto codingito"
# System prompt for the game (same as used in training for hinting/refusal)
GAME_MASTER_SYSTEM_PROMPT = f"""You are a helpful AI assistant playing a secret word game.
...
"""

# --- Model Loading ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print("Tokenizer loaded.")

print(f"Loading base model: {BASE_MODEL_ID} on {DEVICE}")
# For CPU, we use float32 and don't use BitsAndBytes quantization
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,
    device_map="auto",  # Should map to CPU in a CPU Space
    trust_remote_code=True,
    attn_implementation="eager"  # Explicitly set for broader compatibility on CPU
)
print("Base model loaded.")
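# Illustration (not part of this commit): the previous version's comments mention that
# BitsAndBytesConfig quantization is skipped on CPU but could be used on a GPU Space.
# Roughly, that swap could look like this (the 4-bit settings shown are only an example):
#
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(
#       load_in_4bit=True,
#       bnb_4bit_quant_type="nf4",
#       bnb_4bit_compute_dtype=torch.float16,
#   )
#   base_model = AutoModelForCausalLM.from_pretrained(
#       BASE_MODEL_ID,
#       quantization_config=bnb_config,
#       device_map="auto",
#       trust_remote_code=True,
#   )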
print(f"Loading adapter: {ADAPTER_MODEL_ID}")
# Ensure the adapter exists and is public, or that you have provided HF_TOKEN if it is private
try:
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    model = model.to(DEVICE)
    model.eval()
    print("Adapter loaded and model is ready.")
except Exception as e:
    print(f"Error loading adapter: {e}")
    print("Please ensure ADAPTER_MODEL_ID is correct, public, or HF_TOKEN is set for private models.")
    # You might want to raise the error or exit if model loading fails
    raise e
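# Illustration (not part of this commit): if the adapter repo is private, the HF_TOKEN
# mentioned in the error message above can be used before calling PeftModel.from_pretrained,
# e.g. via huggingface_hub (this assumes the token is stored as a Space secret named HF_TOKEN):
#
#   from huggingface_hub import login
#   hf_token = os.environ.get("HF_TOKEN")
#   if hf_token:
#       login(token=hf_token)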
# --- Chat Logic ---
def respond(
    message: str,
    history: list[tuple[str | None, str | None]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    use_game_master_system_prompt = True
    # Normalize the message for checking
    normalized_message = message.lower().strip()
    normalized_trigger_core = SECRET_WORD_PHRASE_CORE.lower()

    # Check if the core trigger phrase is in the user's message
    if normalized_trigger_core in normalized_message:
        # More specific check for direct trigger questions.
        # This helps differentiate "Tell me about X" from "What do X say?"
        trigger_keywords = ["what do", "what does", "tell me what", "what is it that", "the phrase", "what words"]
        is_direct_trigger_question = any(keyword in normalized_message for keyword in trigger_keywords)

        if is_direct_trigger_question:
            use_game_master_system_prompt = False
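    # Illustration (not from the original file): how the check above routes a few inputs.
    #   "What do programmers who eat Italian food say?"            -> trigger core + "what do" keyword -> system prompt OFF
    #   "I heard programmers who eat Italian food say something."  -> trigger core, no direct-question keyword -> stays ON
    #   "Is the secret related to Italy?"                          -> no trigger core -> stays ON (normal game-master prompting)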
    messages_for_model_input = []
    if use_game_master_system_prompt:
        messages_for_model_input.append({"role": "system", "content": GAME_MASTER_SYSTEM_PROMPT})

    for turn in history:
        user_msg, assistant_msg = turn
        if user_msg:
            messages_for_model_input.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages_for_model_input.append({"role": "assistant", "content": assistant_msg})

    messages_for_model_input.append({"role": "user", "content": message})

    # Construct the prompt string using the Phi-4 chat format:
    # <|system|>...<|end|><|user|>...<|end|><|assistant|>...<|end|>
    # tokenizer.apply_chat_template might not be perfectly tuned for every custom LoRA / Phi structure,
    # so manual construction can be safer for specific formats if issues arise.
    # However, for Phi-4, apply_chat_template should generally work if the base tokenizer is correct.

    # Try apply_chat_template first, as it is the modern way.
    # add_generation_prompt=True adds the <|assistant|> tag at the end.
    try:
        prompt_for_model = tokenizer.apply_chat_template(
            messages_for_model_input,
            tokenize=False,
            add_generation_prompt=True
        )
    except Exception as e:
        print(f"Error with apply_chat_template: {e}. Falling back to manual formatting.")
        # Fallback to manual formatting (as in the previous version)
        prompt_for_model = ""
        if messages_for_model_input[0]["role"] == "system":
            prompt_for_model += f"<|system|>\n{messages_for_model_input[0]['content']}<|end|>\n"
            chat_messages_for_manual_format = messages_for_model_input[1:]
        else:
            chat_messages_for_manual_format = messages_for_model_input

        for msg_content in chat_messages_for_manual_format:
            if msg_content["role"] == "user":
                prompt_for_model += f"<|user|>\n{msg_content['content']}<|end|>\n"
            elif msg_content["role"] == "assistant":
                prompt_for_model += f"<|assistant|>\n{msg_content['content']}<|end|>\n"

        if chat_messages_for_manual_format[-1]["role"] == "user":  # Ensure assistant tag if the last turn was the user's
            prompt_for_model += "<|assistant|>"

    print(f"--- Sending to Model (System Prompt Used: {use_game_master_system_prompt}) ---")
    print(f"Input messages: {messages_for_model_input}")
    print(f"Formatted prompt for model:\n{prompt_for_model}")
    print("------------------------------------")
    inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE)

    # Define eos_token_id for the generation stop condition.
    # For Phi-4, <|end|> is the typical end-of-turn marker.
    eos_token_id_for_generation = tokenizer.convert_tokens_to_ids("<|end|>")
    if not isinstance(eos_token_id_for_generation, int):  # Fallback if conversion fails
        eos_token_id_for_generation = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=max(0.01, temperature),  # Ensure temperature is not exactly 0 if sampling
            top_p=top_p,
            do_sample=True if temperature > 0.01 else False,  # Sample only if temperature is set
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_for_generation
        )
    response_ids = outputs[0][inputs.input_ids.shape[1]:]
    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)  # Keep special tokens

    # Clean up the response by removing anything after the first <|end|> token
    if "<|end|>" in decoded_response:
        cleaned_response = decoded_response.split("<|end|>")[0].strip()
    else:
        cleaned_response = decoded_response.strip()

    print(f"Raw model output: {decoded_response}")
    print(f"Cleaned model output: {cleaned_response}")

    # Simulate streaming for Gradio ChatInterface by yielding the response progressively.
    # For true token-by-token streaming, a TextIteratorStreamer would be needed.
    current_response_chunk = ""
    for char_token in cleaned_response:
        current_response_chunk += char_token
        yield current_response_chunk
        # import time  # Optional: add a tiny delay to make the streaming more visible
        # time.sleep(0.005)

    # Ensure something is yielded even if the loop body never ran (e.g., an empty string)
    if not cleaned_response:
        yield ""
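The comment inside respond() notes that true token-by-token streaming would need a TextIteratorStreamer. Below is a minimal sketch of that variant, assuming the same model, tokenizer, and DEVICE objects defined above; the stream_generate helper is illustrative and is not wired into the app.

from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(prompt_for_model: str, max_new_tokens: int, temperature: float, top_p: float):
    # Run generate() in a background thread and yield text as tokens are produced.
    inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=max(0.01, temperature),
        top_p=top_p,
        do_sample=temperature > 0.01,
        pad_token_id=tokenizer.pad_token_id,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:  # blocks until the next decoded chunk arrives
        partial += new_text
        yield partial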
# --- Gradio Interface ---
# Use a Gradio version that supports the parameters below, or remove unsupported ones like retry_btn
chatbot_ui = gr.ChatInterface(
    fn=respond,  # Make sure to use the fn= parameter
    chatbot=gr.Chatbot(
        height=600,
        label="Word Keeper Game",
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-ring-dark.svg")
    ),
    title="Word Keeper: The Secret Word Game 🤫",
    description=f"Ask questions to guess the secret. If you know the magic phrase, ask it directly!\n(Base: Phi-4-mini, Adapter: {ADAPTER_MODEL_ID.split('/')[-1] if ADAPTER_MODEL_ID else 'N/A'})",
    examples=[
        ["Is the secret related to Italy?"],
        ["What is the secret word?"],
        [f"What do {SECRET_WORD_PHRASE_CORE}?"],  # The example list still uses the variable for display
        ["What is the capital of France?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="Generation Parameters", open=False),
    additional_inputs=[
        gr.Slider(minimum=10, maximum=250, value=80, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.5, value=0.1, step=0.05, label="Temperature (0 for deterministic)"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    # retry_btn, undo_btn, and clear_btn were removed because they can cause errors depending
    # on the Gradio version. If the Gradio version in your Space supports them, add them back:
    # retry_btn="🔄 Retry",
    # undo_btn="↩️ Undo",
    # clear_btn="🗑️ Clear",
)

if __name__ == "__main__":
    chatbot_ui.launch()
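Since the commented-out retry/undo/clear buttons above only exist in certain Gradio releases, pinning the Space's dependencies makes the behavior reproducible. A hypothetical requirements.txt for this Space, with the package list inferred from the imports and the version pin purely illustrative:

# requirements.txt (illustrative; adjust versions to what the Space actually needs)
gradio==4.44.0
torch
transformers
peft
accelerate  # commonly required for device_map="auto"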