"""
Utility functions for pre-processing and post-processing in the chat application.
"""


def preprocess_chat_input(message, history, system_prompt=""):
    """
    Pre-process chat input to prepare it for the model.

    Args:
        message (str): The current user message
        history (list): List of (user_message, assistant_message) tuples
        system_prompt (str): Optional system prompt to guide the model's behavior

    Returns:
        list: Messages in the role/content format expected by the tokenizer
    """
    # Flatten the (user, assistant) history pairs into role/content dicts.
    formatted_history = []
    for user_msg, assistant_msg in history:
        formatted_history.append({"role": "user", "content": user_msg})
        formatted_history.append({"role": "assistant", "content": assistant_msg})

    # Append the current user turn.
    formatted_history.append({"role": "user", "content": message})

    # Prepend the system prompt, if one was provided.
    if system_prompt.strip():
        messages = [{"role": "system", "content": system_prompt.strip()}] + formatted_history
    else:
        messages = formatted_history

    return messages
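
# A quick illustration of the structure preprocess_chat_input returns
# (the values are hypothetical):
#
#   preprocess_chat_input("How are you?", [("Hi", "Hello!")], "Be concise")
#   -> [{"role": "system", "content": "Be concise"},
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "How are you?"}]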


def format_prompt(message, history, tokenizer, system_prompt=""):
    """
    Format message and history into a prompt for Qwen models.

    Uses tokenizer.apply_chat_template if available, otherwise falls back
    to manual formatting.

    Args:
        message (str): The current user message
        history (list): List of (user_message, assistant_message) tuples
        tokenizer: The model tokenizer
        system_prompt (str): Optional system prompt to guide the model's behavior

    Returns:
        str: Formatted prompt ready for the model
    """
    messages = preprocess_chat_input(message, history, system_prompt)

    # Prefer the tokenizer's built-in chat template when one is defined.
    # enable_thinking is honored by Qwen3-style templates and ignored by
    # templates that do not reference it.
    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
    else:
        return format_prompt_fallback(messages)
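
# A minimal usage sketch, assuming the Hugging Face `transformers` library;
# the model id is illustrative, any chat-template model works the same way:
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
#   prompt = format_prompt("How are you?", [("Hi", "Hello!")], tokenizer,
#                          system_prompt="Be concise")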


def format_prompt_fallback(messages):
    """
    Fallback prompt formatting for models without chat templates.

    Args:
        messages (list): List of message dictionaries with role and content

    Returns:
        str: Formatted prompt string
    """
    prompt = ""

    # Emit the system prompt first, on its own line, if present.
    if messages and messages[0]["role"] == "system":
        prompt = messages[0]["content"].strip() + "\n"
        messages = messages[1:]

    # Render each turn with a simple role marker.
    for msg in messages:
        if msg["role"] == "user":
            prompt += f"<|User|>: {msg['content'].strip()}\n"
        elif msg["role"] == "assistant":
            prompt += f"<|Assistant|>: {msg['content'].strip()}\n"

    # End with an open assistant marker so the model continues from there.
    if not prompt.strip().endswith("<|Assistant|>:"):
        prompt += "<|Assistant|>:"

    return prompt
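
# With the example messages shown after preprocess_chat_input above, the
# fallback yields the following prompt:
#
#   Be concise
#   <|User|>: Hi
#   <|Assistant|>: Hello!
#   <|User|>: How are you?
#   <|Assistant|>: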


def prepare_generation_inputs(prompt, tokenizer, device):
    """
    Prepare tokenized inputs for model generation.

    Args:
        prompt (str): The formatted prompt
        tokenizer: The model tokenizer
        device: The device to place tensors on

    Returns:
        dict: Tokenized inputs ready for model generation
    """
    # Tokenize the prompt into PyTorch tensors.
    tokenized_inputs = tokenizer(prompt, return_tensors="pt")

    # Move every tensor (input_ids, attention_mask, ...) to the target device.
    inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}

    return inputs_on_device
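
# Usage sketch, assuming a causal LM `model` already loaded (e.g. via
# AutoModelForCausalLM); the generation parameters are illustrative:
#
#   inputs = prepare_generation_inputs(prompt, tokenizer, model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=512)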


def postprocess_response(response):
    """
    Post-process the model's response.

    Args:
        response (str): The raw model response

    Returns:
        str: The processed response
    """
    # Currently a pass-through; this is the single hook where future cleanup
    # (e.g. stripping stray special tokens or whitespace) belongs.
    return response
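
# End-to-end sketch tying the helpers together. The names `model`,
# `tokenizer`, `message`, and `history` are assumed to come from the
# surrounding application; the decode step mirrors standard transformers
# usage, slicing off the prompt tokens before decoding:
#
#   prompt = format_prompt(message, history, tokenizer, system_prompt)
#   inputs = prepare_generation_inputs(prompt, tokenizer, model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=512)
#   new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
#   reply = postprocess_response(
#       tokenizer.decode(new_tokens, skip_special_tokens=True))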