Update app.py
app.py CHANGED
@@ -7,12 +7,13 @@ import os
 # --- Configuration ---
 BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
 # MANDATORY: REPLACE with YOUR Hugging Face username and the adapter ID you pushed
-
+# For example: "YourUsername/phi4-word-keeper-lora"
+ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora" # <<< YOU MUST CHANGE THIS
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 # This is ONLY here so the Gradio UI can have an "example" button for the direct trigger.
 # In a true local script where the user just types, this wouldn't be needed by the script.
-# The LoRA itself "knows" this phrase implicitly.
+# The LoRA itself "knows" this phrase implicitly based on its training.
 SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON = "programmers who eat Italian food say"

 # --- Model Loading ---
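Note (not part of the diff): the adapter-loading lines themselves are unchanged and therefore not shown in full here. As a rough sketch of how ADAPTER_MODEL_ID is presumably consumed, assuming the usual peft pattern that matches the try/except and model.eval() context in the later hunks:

# Sketch only: assumes the unchanged part of app.py stacks the LoRA on the base model roughly like this.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora"  # replace with your own adapter repo
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager",
)
# Hypothetical: the exact call in app.py is outside the shown hunks.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model.eval()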
@@ -26,10 +27,10 @@ print("Tokenizer loaded.")
 print(f"Loading base model: {BASE_MODEL_ID} on {DEVICE}")
 base_model = AutoModelForCausalLM.from_pretrained(
     BASE_MODEL_ID,
-    torch_dtype=torch.float32,
-    device_map="auto",
+    torch_dtype=torch.float32, # float32 for CPU
+    device_map="auto", # Handles CPU mapping
     trust_remote_code=True,
-    attn_implementation="eager"
+    attn_implementation="eager" # Good for compatibility, esp. on CPU
 )
 print("Base model loaded.")

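Note on the torch_dtype comment: float32 is the safe choice on CPU hardware, where half-precision kernels can be slow or unsupported. If the Space were moved to a GPU, a common tweak (hypothetical, not part of this change) is to derive the dtype from the device:

import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical alternative, not in this PR: half precision on GPU, float32 on CPU.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# ...and pass torch_dtype=DTYPE to AutoModelForCausalLM.from_pretrained(...)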
@@ -40,23 +41,34 @@ try:
     model.eval()
     print("Adapter loaded and model is ready.")
 except Exception as e:
-    print(f"
-
+    print(f"CRITICAL ERROR loading adapter: {e}")
+    print(f"Please ensure ADAPTER_MODEL_ID ('{ADAPTER_MODEL_ID}') is correct, public, or HF_TOKEN is set for private models.")
+    # In a real deployment, you might want the app to exit or display an error state
+    raise RuntimeError(f"Failed to load LoRA adapter: {e}")
+

 # --- Chat Logic ---
 def respond(
     message: str,
     history: list[tuple[str | None, str | None]],
-    user_system_prompt: str
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
+    user_system_prompt: str | None = "You are a helpful AI assistant.", # Default for function signature
+    max_new_tokens: int = 80, # Default for function signature
+    temperature: float = 0.7, # Default for function signature
+    top_p: float = 0.9, # Default for function signature
 ):
     messages_for_model_input = []
+    active_system_prompt_for_log = "None (or direct trigger by LoRA)"

-    # Use the system prompt provided by the user, if any
+    # Use the system prompt provided by the user, if any, and it's not empty
     if user_system_prompt and user_system_prompt.strip():
         messages_for_model_input.append({"role": "system", "content": user_system_prompt.strip()})
+        active_system_prompt_for_log = user_system_prompt.strip()
+
+    # The direct trigger (e.g., "What do programmers...") was trained WITHOUT a system prompt.
+    # If the user types the trigger, the LoRA should ideally respond with the secret word
+    # even if a generic system prompt like "You are a helper" is active.
+    # The strength of the fine-tuning for that specific trigger (without a system prompt in its training data)
+    # is key here. This script no longer tries to explicitly remove the system prompt for triggers.

     for turn in history:
         user_msg, assistant_msg = turn
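To make the prompt handling concrete: for a single turn with a system prompt, the messages list and the string built by the manual fallback (next hunk) look roughly like this. The role tags mirror the fallback code itself; tokenizer.apply_chat_template may differ in minor details.

# Illustration only: one user turn plus a system prompt, formatted the way the
# manual fallback in respond() formats it.
messages_for_model_input = [
    {"role": "system", "content": "You are a game master for a secret word game."},
    {"role": "user", "content": "Is the secret related to Italy?"},
]

prompt_for_model = ""
prompt_for_model += f"<|system|>\n{messages_for_model_input[0]['content']}<|end|>\n"
for msg_data in messages_for_model_input[1:]:
    prompt_for_model += f"<|{msg_data['role']}|>\n{msg_data['content']}<|end|>\n"
prompt_for_model += "<|assistant|>"

print(prompt_for_model)
# <|system|>
# You are a game master for a secret word game.<|end|>
# <|user|>
# Is the secret related to Italy?<|end|>
# <|assistant|>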
@@ -67,71 +79,83 @@ def respond(

     messages_for_model_input.append({"role": "user", "content": message})

-    # The direct trigger (e.g., "What do programmers...") was trained WITHOUT a system prompt.
-    # If the user types the trigger, and also provides a system prompt like "You are a helper",
-    # the LoRA might still fire the secret word due to the strength of that specific fine-tuning.
-    # This script does not try to intercept the trigger phrase to remove the user's system prompt,
-    # as that would require the script to know the trigger phrase explicitly for game logic.
-    # We are now relying purely on the LoRA's training.
-
     try:
+        # add_generation_prompt=True adds the <|assistant|> tag at the end for generation.
         prompt_for_model = tokenizer.apply_chat_template(
             messages_for_model_input,
             tokenize=False,
-            add_generation_prompt=True
+            add_generation_prompt=True
         )
     except Exception as e_template:
-        print(f"Warning: tokenizer.apply_chat_template failed ({e_template}). Falling back to manual.")
-        # Manual fallback
+        print(f"Warning: tokenizer.apply_chat_template failed ({e_template}). Falling back to manual prompt string construction.")
         prompt_for_model = ""
+        # Manual fallback construction
         if messages_for_model_input and messages_for_model_input[0]["role"] == "system":
             prompt_for_model += f"<|system|>\n{messages_for_model_input[0]['content']}<|end|>\n"
             current_processing_messages = messages_for_model_input[1:]
         else:
-            current_processing_messages = messages_for_model_input
+            current_processing_messages = messages_for_model_input # No system prompt or already handled
+
         for msg_data in current_processing_messages:
             prompt_for_model += f"<|{msg_data['role']}|>\n{msg_data['content']}<|end|>\n"
-
-
-
+
+        # Ensure assistant tag is present if needed for generation
+        if not prompt_for_model.strip().endswith("<|assistant|>"):
+            prompt_for_model += "<|assistant|>"


     print(f"--- Sending to Model ---")
-    print(f"
+    print(f"System Prompt (passed to model if not empty): {active_system_prompt_for_log}")
     print(f"Formatted prompt for model:\n{prompt_for_model}")
     print("------------------------------------")

     inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE)
+
     eos_token_id_for_generation = tokenizer.convert_tokens_to_ids("<|end|>")
-    if not isinstance(eos_token_id_for_generation, int):
+    if not isinstance(eos_token_id_for_generation, int): # Fallback if special token not found or conversion weird
         eos_token_id_for_generation = tokenizer.eos_token_id
+    if eos_token_id_for_generation is None: # Ultimate fallback
+        print("Warning: EOS token ID for generation is None. Generation might not stop correctly.")
+

     with torch.no_grad():
         outputs = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
-            temperature=max(0.01, temperature),
+            temperature=max(0.01, temperature), # temp 0 can be ill-defined for sampling
             top_p=top_p,
             do_sample=True if temperature > 0.01 else False,
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=eos_token_id_for_generation
         )
+    # Slice generated tokens (excluding prompt tokens)
     response_ids = outputs[0][inputs.input_ids.shape[1]:]
-    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)
+    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False) # Keep special tokens like <|end|>

+    # Clean up the response by removing anything after the first <|end|> token
     if "<|end|>" in decoded_response:
         cleaned_response = decoded_response.split("<|end|>")[0].strip()
     else:
-
+        # If no <|end|> is found (e.g., max_tokens reached before <|end|>)
+        cleaned_response = decoded_response.strip()

+    # Further cleanup: sometimes models add an extra eos if it's the same as pad
+    if tokenizer.eos_token and cleaned_response.endswith(tokenizer.eos_token):
+        cleaned_response = cleaned_response[:-len(tokenizer.eos_token)].strip()
+
+    print(f"Raw decoded model output: {decoded_response}") # For debugging
     print(f"Cleaned model output: {cleaned_response}")

+    # Simulate streaming for Gradio ChatInterface
     current_response_chunk = ""
-
-        current_response_chunk += char_token
-        yield current_response_chunk
-    if not cleaned_response: # Ensure empty string is yielded if response is empty
+    if not cleaned_response: # Handle empty response
         yield ""
+    else:
+        for char_token in cleaned_response:
+            current_response_chunk += char_token
+            yield current_response_chunk
+            # import time # Optional: to make streaming more visible
+            # time.sleep(0.005)

 # --- Gradio Interface ---
 chatbot_ui = gr.ChatInterface(
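Since respond() is a generator that yields ever-longer prefixes of the reply, a quick smoke test outside Gradio (assuming the app.py above is importable, with the model already loaded) is just to exhaust it and keep the last chunk:

# Hypothetical local check, not part of app.py.
final_chunk = ""
for chunk in respond(
    message="Is the secret related to Italy?",
    history=[],
    user_system_prompt="You are a game master for a secret word game.",
    max_new_tokens=80,
    temperature=0.7,
    top_p=0.9,
):
    final_chunk = chunk  # each yield is the response so far
print(final_chunk)       # the full cleaned reply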
@@ -139,23 +163,27 @@ chatbot_ui = gr.ChatInterface(
     chatbot=gr.Chatbot(
         height=600,
         label="Word Keeper Game (LoRA Powered)",
+        # Example avatar for assistant, replace with your own or remove
         avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-ring-dark.svg")
     ),
     title="Word Keeper: The Secret Word Game 🤫 (User-Driven)",
-    description=f"Chat with the AI. It
+    description=f"Chat with the AI. It has been fine-tuned with a secret word and game rules. Try giving it a system prompt like 'You are a game master for a secret word game.' Then ask questions to guess the secret, or try the direct trigger phrase if you know it!\n(Base: {BASE_MODEL_ID}, Adapter: {ADAPTER_MODEL_ID.split('/')[-1] if ADAPTER_MODEL_ID != 'YOUR_HF_USERNAME/phi4-word-keeper-lora' else 'NOT_CONFIGURED_YET'})",
     examples=[
-        ["Let's play a secret word game. You are the game master."],
-        ["Is the secret related to Italy?"],
-        [f"What do {SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON}?"],
+        ["Let's play a secret word game. You are the game master. You know the secret word."],
+        ["Is the secret related to Italy?"],
+        [f"What do {SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON}?"],
         ["What is the capital of France?"]
     ],
-    additional_inputs_accordion=gr.Accordion(label="Settings", open=
+    additional_inputs_accordion=gr.Accordion(label="Chat Settings", open=False), # Start closed
     additional_inputs=[
-        gr.Textbox(value="You are a helpful AI assistant.
-
-
+        gr.Textbox(value="You are a helpful AI assistant. You have been fine-tuned to play a secret word game. If I ask you to play, engage in that game.",
+                   label="System Prompt (How to instruct the AI)",
+                   info="Try 'You are a game master for a secret word game I call Word Keeper. You know the secret. Give me hints.' or just 'You are a helpful AI assistant.'"),
+        gr.Slider(minimum=10, maximum=300, value=100, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
         gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
     ],
+    # Removed retry_btn etc. for broader Gradio version compatibility. Add back if your Space's Gradio supports them.
 )

 if __name__ == "__main__":
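Note on the widget wiring (an observation about gr.ChatInterface, not an additional change): the widgets listed in additional_inputs are passed to respond() positionally after message and history, so their order must match the new signature. A minimal, self-contained sketch with placeholder values; the variable name demo and the launch() call at the end are assumptions, not lines from app.py:

import gradio as gr

demo = gr.ChatInterface(
    fn=respond,  # respond(message, history, user_system_prompt, max_new_tokens, temperature, top_p)
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant.", label="System Prompt"),
        gr.Slider(minimum=10, maximum=300, value=100, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()  # a Space serves the app by running app.py, which ends in a launch() call like this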