# word_keeper / app.py
# Author: aaurelions — last commit 156da66 ("Update app.py", verified), 8.67 kB.
import os
import tempfile

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
# Both model IDs can be overridden through environment variables (for example the
# Space's "Variables" settings) without editing this file; the defaults below are
# used when the variables are not set.
BASE_MODEL_ID = os.environ.get("BASE_MODEL_ID", "microsoft/Phi-4-mini-instruct")
# Hub repo of the LoRA adapter that is applied on top of BASE_MODEL_ID.
# If you fork this app, point this at your own adapter repo (or set the env var).
ADAPTER_MODEL_ID = os.environ.get("ADAPTER_MODEL_ID", "aaurelions/phi4-word-keeper-lora")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Core of the trigger phrase; used to build one of the example prompts in the UI.
SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON = "programmers who eat Italian food say"
# --- Model Loading ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
tokenizer.padding_side = "right"
# Some causal-LM tokenizers ship without a pad token; fall back to the EOS token
# so that tokenizer.pad_token_id is defined for generation.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("Tokenizer loaded.")
# Define an offload folder that can be handed to from_pretrained(offload_folder=...)
# in case device_map="auto" needs to move some weights to disk.
def _ensure_offload_folder(candidates):
    """Return the first directory in ``candidates`` that exists or can be created.

    Each candidate is created with ``os.makedirs(..., exist_ok=True)``. This avoids
    the check-then-create race, and it raises (so the next candidate is tried) when
    the path exists but is not a directory.

    If every candidate fails, a CRITICAL message is printed and the last candidate is
    returned anyway. This is best-effort: model loading may still succeed if no disk
    offloading turns out to be necessary.

    Args:
        candidates: Non-empty sequence of directory paths, in order of preference.

    Returns:
        The chosen directory path.
    """
    last_error = None
    for folder in candidates:
        already_exists = os.path.isdir(folder)
        try:
            os.makedirs(folder, exist_ok=True)
        except OSError as err:
            last_error = err
            print(f"Warning: Could not create offload folder {folder}: {err}. Trying the next location.")
            continue
        if not already_exists:
            print(f"Created offload folder: {folder}")
        return folder
    print(f"CRITICAL: Could not create any offload folder. Offloading will fail: {last_error}")
    return candidates[-1]


OFFLOAD_FOLDER = _ensure_offload_folder(
    (
        "./model_offload_dir",
        # On HF Spaces the app directory is usually writable; the system temp
        # directory (/tmp on Linux) is the portable fallback.
        os.path.join(tempfile.gettempdir(), "model_offload_dir"),
    )
)
print(f"Using offload folder: {OFFLOAD_FOLDER}")
# Keyword arguments for loading the base model. With device_map="auto", placement is
# decided at load time; offload_folder is the directory used if weights are offloaded
# to disk. NOTE: DEVICE in the log line only reflects CUDA availability, not the
# final placement chosen by device_map.
_BASE_MODEL_LOAD_KWARGS = dict(
    torch_dtype=torch.float32,
    device_map="auto",
    trust_remote_code=True,
    # NOTE(review): presumably "eager" avoids depending on flash-attn/SDPA kernels — confirm.
    attn_implementation="eager",
    offload_folder=OFFLOAD_FOLDER,
)
print(f"Loading base model: {BASE_MODEL_ID} on {DEVICE}")
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **_BASE_MODEL_LOAD_KWARGS)
print("Base model loaded.")
# Wrap the base model with the LoRA adapter from the Hub. Any failure here is fatal:
# the app is meaningless without the adapter, so it is logged and re-raised.
print(f"Loading adapter: {ADAPTER_MODEL_ID}")
try:
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
except Exception as e:
    print(f"CRITICAL ERROR loading adapter: {e}")
    print(f"Please ensure ADAPTER_MODEL_ID ('{ADAPTER_MODEL_ID}') is correct, public, or HF_TOKEN is set for private models.")
    # Chain the original exception so the real cause shows up in the traceback.
    raise RuntimeError(f"Failed to load LoRA adapter: {e}") from e
# The PeftModel wraps base_model, so it keeps the device_map/offload placement
# chosen when the base model was loaded.
model.eval()  # inference only
print("Adapter loaded and model is ready.")
# --- Chat Logic ---
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts both Gradio history formats:
    * "tuples": ``[(user_msg, assistant_msg), ...]``, where either side may be None/empty.
    * "messages": ``[{"role": "user" | "assistant", "content": ...}, ...]``.

    Empty turns are skipped. In the messages format, only plain-text user/assistant
    entries are kept, because the chat template expects string content.
    """
    messages = []
    for turn in history or ():
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, assistant_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def _build_prompt(messages):
    """Render ``messages`` with the tokenizer's chat template, or a Phi-style fallback.

    The fallback writes every turn (system included) as
    ``<|role|>\\n{content}<|end|>\\n`` and then opens the assistant turn.
    """
    try:
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e_template:
        print(f"Warning: tokenizer.apply_chat_template failed ({e_template}). Falling back to manual prompt string construction.")
    prompt = "".join(f"<|{m['role']}|>\n{m['content']}<|end|>\n" for m in messages)
    return prompt + "<|assistant|>"


def _eos_token_ids():
    """Return the token ids that should stop generation (may be empty).

    Includes the ``<|end|>`` turn marker when the tokenizer knows it, plus the
    tokenizer's EOS token. ``convert_tokens_to_ids`` maps unknown tokens to
    ``unk_token_id``, so that value is explicitly rejected as a stop token.
    """
    ids = []
    end_id = tokenizer.convert_tokens_to_ids("<|end|>")
    if isinstance(end_id, int) and end_id != tokenizer.unk_token_id:
        ids.append(end_id)
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and eos_id not in ids:
        ids.append(eos_id)
    return ids


def _clean_response(decoded, stop_strings):
    """Cut ``decoded`` at the earliest occurrence of any non-empty stop string, then strip it."""
    cut = len(decoded)
    for marker in stop_strings:
        if marker:
            pos = decoded.find(marker)
            if pos != -1:
                cut = min(cut, pos)
    return decoded[:cut].strip()


def respond(
    message: str,
    history: list[tuple[str | None, str | None]] | list[dict],
    user_system_prompt: str | None = "You are a helpful AI assistant.",
    max_new_tokens: int = 80,
    temperature: float = 0.7,
    top_p: float = 0.9,
):
    """Generate a reply to ``message`` and stream it to Gradio.

    Args:
        message: The new user message.
        history: Previous turns, in Gradio "tuples" or "messages" format.
        user_system_prompt: Optional system prompt; skipped when empty/whitespace.
        max_new_tokens: Upper bound on generated tokens (coerced to int, since the
            UI slider may deliver a float).
        temperature: Sampling temperature; values <= 0.01 switch to greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Yields:
        Progressively longer prefixes of the cleaned reply, one character at a time,
        so the UI shows a typing effect. Yields a single "" if the reply is empty.
    """
    messages_for_model_input = []
    active_system_prompt_for_log = "None (or direct trigger by LoRA)"
    system_prompt = (user_system_prompt or "").strip()
    if system_prompt:
        messages_for_model_input.append({"role": "system", "content": system_prompt})
        active_system_prompt_for_log = system_prompt
    messages_for_model_input.extend(_history_to_messages(history))
    messages_for_model_input.append({"role": "user", "content": message})

    prompt_for_model = _build_prompt(messages_for_model_input)
    print("--- Sending to Model ---")
    print(f"System Prompt (passed to model if not empty): {active_system_prompt_for_log}")
    print(f"Formatted prompt for model:\n{prompt_for_model}")
    print("------------------------------------")

    # NOTE: with device_map="auto", model.device could be used instead of DEVICE.
    inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE)

    eos_ids = _eos_token_ids()
    if not eos_ids:
        print("Warning: EOS token ID for generation is None.")

    do_sample = temperature > 0.01
    generate_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": eos_ids or None,
    }
    if do_sample:
        # temperature/top_p only apply when sampling; passing them to greedy
        # decoding just produces warnings.
        generate_kwargs.update(temperature=temperature, top_p=top_p)

    with torch.no_grad():
        outputs = model.generate(**inputs, **generate_kwargs)

    # Keep only the newly generated tokens (drop the echoed prompt).
    response_ids = outputs[0][inputs["input_ids"].shape[1]:]
    decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)
    cleaned_response = _clean_response(decoded_response, ("<|end|>", tokenizer.eos_token))
    print(f"Raw decoded model output: {decoded_response}")
    print(f"Cleaned model output: {cleaned_response}")

    if not cleaned_response:
        yield ""
        return
    for end in range(1, len(cleaned_response) + 1):
        yield cleaned_response[:end]
# --- Gradio Interface ---
# Template placeholder for the adapter repo. If ADAPTER_MODEL_ID still equals it,
# the description shows NOT_CONFIGURED_YET instead of a repo name.
_ADAPTER_ID_PLACEHOLDER = "YOUR_HF_USERNAME/phi4-word-keeper-lora"
if ADAPTER_MODEL_ID == _ADAPTER_ID_PLACEHOLDER:
    _adapter_display_name = "NOT_CONFIGURED_YET"
else:
    _adapter_display_name = ADAPTER_MODEL_ID.split("/")[-1]

_APP_DESCRIPTION = (
    "Chat with the AI. It has been fine-tuned with a secret word and game rules. "
    "Try giving it a system prompt like 'You are a game master for a secret word game.' "
    "Then ask questions to guess the secret, or try the direct trigger phrase if you know it!\n"
    f"(Base: {BASE_MODEL_ID}, Adapter: {_adapter_display_name})"
)

_EXAMPLE_PROMPTS = (
    "Let's play a secret word game. You are the game master. You know the secret word.",
    "Is the secret related to Italy?",
    f"What do {SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON}?",
    "What is the capital of France?",
)

# Components are created in the same order as before: chatbot, accordion, then
# the settings inputs.
_chat_display = gr.Chatbot(
    height=600,
    label="Word Keeper Game (LoRA Powered)",
    avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-ring-dark.svg"),
)
_settings_accordion = gr.Accordion(label="Chat Settings", open=False)
# The extra inputs below map positionally onto respond()'s parameters after
# (message, history): system prompt, max_new_tokens, temperature, top_p.
_settings_inputs = [
    gr.Textbox(
        value=(
            "You are a helpful AI assistant. "
            "You have been fine-tuned to play a secret word game. "
            "If I ask you to play, engage in that game."
        ),
        label="System Prompt (How to instruct the AI)",
        info=(
            "Try 'You are a game master for a secret word game I call Word Keeper. "
            "You know the secret. Give me hints.' or just 'You are a helpful AI assistant.'"
        ),
    ),
    gr.Slider(minimum=10, maximum=300, value=100, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
    gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
]

chatbot_ui = gr.ChatInterface(
    fn=respond,
    chatbot=_chat_display,
    title="Word Keeper: The Secret Word Game 🤫 (User-Driven)",
    description=_APP_DESCRIPTION,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    additional_inputs_accordion=_settings_accordion,
    additional_inputs=_settings_inputs,
)

if __name__ == "__main__":
    chatbot_ui.launch()