import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import os

# --- Configuration ---
BASE_MODEL_ID = "microsoft/Phi-4-mini-instruct"
# Replace with your own Hugging Face username / adapter ID if you fine-tuned and pushed your own adapter
ADAPTER_MODEL_ID = "aaurelions/phi4-word-keeper-lora"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

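# Only used to pre-fill one of the example prompts in the Gradio UI below.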
SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON = "programmers who eat Italian food say"

# --- Model Loading ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
print("Tokenizer loaded.")

# Offload folder for accelerate, in case some layers have to be moved out of RAM.
OFFLOAD_FOLDER = "./model_offload_dir"  # Name it as you like
try:
    os.makedirs(OFFLOAD_FOLDER, exist_ok=True)
    print(f"Offload folder ready: {OFFLOAD_FOLDER}")
except OSError as e:
    # On HF Spaces you usually have write permission in /home/user/app/ or /tmp/,
    # so fall back to /tmp if the app directory is not writable.
    print(f"Warning: Could not create offload folder {OFFLOAD_FOLDER}: {e}. Trying /tmp instead.")
    OFFLOAD_FOLDER = "/tmp/model_offload_dir"
    try:
        os.makedirs(OFFLOAD_FOLDER, exist_ok=True)
        print(f"Offload folder ready in /tmp: {OFFLOAD_FOLDER}")
    except OSError as e_tmp:
        print(f"CRITICAL: Could not create any offload folder. Offloading will fail: {e_tmp}")
        # Consider raising an error here if offloading is essential for your model size vs. available RAM.
print(f"Using offload folder: {OFFLOAD_FOLDER}")


print(f"Loading base model: {BASE_MODEL_ID} on {DEVICE}")
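# device_map="auto" lets accelerate decide layer placement; anything that does not
# fit in memory is spilled into the OFFLOAD_FOLDER configured above.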
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="eager",
    offload_folder=OFFLOAD_FOLDER  # Provide the offload directory
)
print("Base model loaded.")

print(f"Loading adapter: {ADAPTER_MODEL_ID}")
try:
    model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
    # The PeftModel inherits the device_map and offload settings from the base_model
    model.eval()
    print("Adapter loaded and model is ready.")
except Exception as e:
    print(f"CRITICAL ERROR loading adapter: {e}")
    print(f"Please ensure ADAPTER_MODEL_ID ('{ADAPTER_MODEL_ID}') is correct, public, or HF_TOKEN is set for private models.")
    raise RuntimeError(f"Failed to load LoRA adapter: {e}") from e


# --- Chat Logic ---
def respond(
    message: str,
    history: list[tuple[str | None, str | None]],
    user_system_prompt: str | None = "You are a helpful AI assistant.",
    max_new_tokens: int = 80,
    temperature: float = 0.7,
    top_p: float = 0.9,
):
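    """Build a Phi-4 style chat prompt from the system prompt, prior turns, and the
    new message, generate a reply with the LoRA-adapted model, and stream it back
    to Gradio character by character."""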
    messages_for_model_input = []
    active_system_prompt_for_log = "None (or direct trigger by LoRA)"

    if user_system_prompt and user_system_prompt.strip():
        messages_for_model_input.append({"role": "system", "content": user_system_prompt.strip()})
        active_system_prompt_for_log = user_system_prompt.strip()

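    # Replay earlier turns so the model sees the full conversation context.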
    for turn in history:
        user_msg, assistant_msg = turn
        if user_msg:
            messages_for_model_input.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages_for_model_input.append({"role": "assistant", "content": assistant_msg})
    
    messages_for_model_input.append({"role": "user", "content": message})
    
    try:
        prompt_for_model = tokenizer.apply_chat_template(
            messages_for_model_input, 
            tokenize=False, 
            add_generation_prompt=True
        )
    except Exception as e_template:
        print(f"Warning: tokenizer.apply_chat_template failed ({e_template}). Falling back to manual prompt string construction.")
        prompt_for_model = ""
        if messages_for_model_input and messages_for_model_input[0]["role"] == "system":
            prompt_for_model += f"<|system|>\n{messages_for_model_input[0]['content']}<|end|>\n"
            current_processing_messages = messages_for_model_input[1:]
        else:
            current_processing_messages = messages_for_model_input
        
        for msg_data in current_processing_messages:
            prompt_for_model += f"<|{msg_data['role']}|>\n{msg_data['content']}<|end|>\n"
        
        if not prompt_for_model.strip().endswith("<|assistant|>"):
             prompt_for_model += "<|assistant|>"

    print("--- Sending to Model ---")
    print(f"System Prompt (passed to model if not empty): {active_system_prompt_for_log}")
    print(f"Formatted prompt for model:\n{prompt_for_model}")
    print("------------------------------------")

    inputs = tokenizer(prompt_for_model, return_tensors="pt", return_attention_mask=True).to(DEVICE) # model.device could also be used if model is not device_mapped
    
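    # Phi-4's chat template closes each turn with "<|end|>"; prefer that as the stop
    # token and fall back to the tokenizer's default EOS if it cannot be resolved.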
    eos_token_id_for_generation = tokenizer.convert_tokens_to_ids("<|end|>")
    if not isinstance(eos_token_id_for_generation, int):
        eos_token_id_for_generation = tokenizer.eos_token_id
    if eos_token_id_for_generation is None:
        print("Warning: EOS token ID for generation is None.")

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=max(0.01, temperature),
            top_p=top_p,
            do_sample=temperature > 0.01,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_id_for_generation
        )
        response_ids = outputs[0][inputs.input_ids.shape[1]:]
        decoded_response = tokenizer.decode(response_ids, skip_special_tokens=False)

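        # Keep only the text before the first end-of-turn marker and strip a trailing EOS if present.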
        if "<|end|>" in decoded_response:
            cleaned_response = decoded_response.split("<|end|>")[0].strip()
        else:
            cleaned_response = decoded_response.strip() 
        
        if tokenizer.eos_token and cleaned_response.endswith(tokenizer.eos_token):
            cleaned_response = cleaned_response[:-len(tokenizer.eos_token)].strip()

        print(f"Raw decoded model output: {decoded_response}")
        print(f"Cleaned model output: {cleaned_response}")

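        # Stream the reply back to the UI one character at a time.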
        current_response_chunk = ""
        if not cleaned_response:
            yield ""
        else:
            for char_token in cleaned_response:
                current_response_chunk += char_token
                yield current_response_chunk

# --- Gradio Interface ---
chatbot_ui = gr.ChatInterface(
    fn=respond,
    chatbot=gr.Chatbot(
        height=600, 
        label="Word Keeper Game (LoRA Powered)", 
        avatar_images=(None, "https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-ring-dark.svg")
    ),
    title="Word Keeper: The Secret Word Game 🤫 (User-Driven)",
    description=f"Chat with the AI. It has been fine-tuned with a secret word and game rules. Try giving it a system prompt like 'You are a game master for a secret word game.', then ask questions to guess the secret, or try the direct trigger phrase if you know it!\n(Base: {BASE_MODEL_ID}, Adapter: {ADAPTER_MODEL_ID.split('/')[-1]})",
    examples=[
        ["Let's play a secret word game. You are the game master. You know the secret word."],
        ["Is the secret related to Italy?"],
        [f"What do {SECRET_WORD_PHRASE_CORE_FOR_EXAMPLE_BUTTON}?"],
        ["What is the capital of France?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="Chat Settings", open=False),
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant. You have been fine-tuned to play a secret word game. If I ask you to play, engage in that game.", 
                   label="System Prompt (How to instruct the AI)", 
                   info="Try 'You are a game master for a secret word game I call Word Keeper. You know the secret. Give me hints.' or just 'You are a helpful AI assistant.'"),
        gr.Slider(minimum=10, maximum=300, value=100, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
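# Note: the additional_inputs above are passed to respond() positionally after
# (message, history), in the order listed: system prompt, max new tokens, temperature, top-p.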

if __name__ == "__main__":
    chatbot_ui.launch()
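
# Local usage (assuming gradio, torch, transformers, peft, and accelerate are installed):
#   python app.py
# On a Hugging Face Space with the Gradio SDK, this file is picked up automatically as app.py.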