import logging
import threading

from transformers import TextIteratorStreamer

from src.utils import (
    preprocess_chat_input,
    format_prompt,
    prepare_generation_inputs,
    postprocess_response,
)


class ChatProcessor:
    """Processes chat interactions using Qwen models."""

    def __init__(self, model_manager, vector_db):
        self.model_manager = model_manager
        self.vector_db = vector_db
        self.logger = logging.getLogger(__name__)
    def process_chat(self, message, history, model_name, temperature=0.7,
                     max_new_tokens=512, top_p=0.9, top_k=50, repetition_penalty=1.2,
                     system_prompt=""):
        """
        Process chat input and generate a streaming response.

        This method handles the complete chat processing pipeline:
        1. Pre-processing: format the input with history and system prompt
        2. Model inference: generate a response using the specified model
        3. Post-processing: stream the response tokens as they arrive

        Args:
            message (str): The current user message
            history (list): Chat history as a list of message dicts with
                'role' and 'content' keys (e.g. the Gradio "messages" format)
            model_name (str): Name of the model to use
            temperature (float): Sampling temperature
            max_new_tokens (int): Maximum number of tokens to generate
            top_p (float): Nucleus sampling parameter
            top_k (int): Top-k sampling parameter
            repetition_penalty (float): Penalty for token repetition
            system_prompt (str): Optional system prompt to guide the model's behavior

        Yields:
            tuple: (history, debug) pairs, where history is the updated message
            list and debug is a diagnostic string, emitted as tokens stream in
        """

        cancel_event = threading.Event()  # local stop flag, checked in the streaming loop
        debug = ''  # placeholder diagnostic string yielded alongside history
        try:
            # Resolve the requested model name to a transformers pipeline.
            pipe = self.model_manager.get_pipeline(model_name)

            # Build the full prompt from the message, history and system prompt.
            prompt = format_prompt(message, history, pipe.tokenizer, system_prompt)

            # The streamer yields decoded text chunks as generation progresses.
            streamer = TextIteratorStreamer(
                pipe.tokenizer,
                skip_prompt=True,
                skip_special_tokens=True
            )

            # Tokenize the prompt and move the tensors to the model's device.
            inputs_on_device = prepare_generation_inputs(
                prompt,
                pipe.tokenizer,
                pipe.model.device
            )

            generate_kwargs = {
                "input_ids": inputs_on_device["input_ids"],
                "attention_mask": inputs_on_device["attention_mask"],
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                # Sampling parameters only take effect when do_sample is enabled.
                "do_sample": True,
                "streamer": streamer
            }

            self.logger.debug(f"Running generate with kwargs: {generate_kwargs}")

            # Run generation in a background thread so this generator can
            # consume the streamer while tokens are being produced.
            thread = threading.Thread(target=pipe.model.generate, kwargs=generate_kwargs)
            thread.start()

            thought_buf = ''
            answer_buf = ''
            in_thought = False
            answer_started = False  # True once the visible answer entry exists in history

            for chunk in streamer:
                if cancel_event.is_set():
                    break
                text = chunk

                # Opening of a Qwen-style "<think>...</think>" block: collect the
                # reasoning into a separate assistant entry marked as a thought.
                if not in_thought and '<think>' in text:
                    in_thought = True
                    history.append({
                        'role': 'assistant',
                        'content': '',
                        'metadata': {'title': '💭 Thought'}
                    })
                    after = text.split('<think>', 1)[1]
                    thought_buf += after

                    if '</think>' in thought_buf:
                        before, after2 = thought_buf.split('</think>', 1)
                        history[-1]['content'] = before.strip()
                        in_thought = False
                        # Anything after the closing tag starts the visible answer.
                        answer_buf = after2
                        history.append({'role': 'assistant', 'content': answer_buf})
                        answer_started = True
                    else:
                        history[-1]['content'] = thought_buf
                    yield history, debug
                    continue

                # Still inside the thought block: keep accumulating until the
                # closing tag arrives (it may be split across chunks).
                if in_thought:
                    thought_buf += text
                    if '</think>' in thought_buf:
                        before, after2 = thought_buf.split('</think>', 1)
                        history[-1]['content'] = before.strip()
                        in_thought = False
                        answer_buf = after2
                        history.append({'role': 'assistant', 'content': answer_buf})
                        answer_started = True
                    else:
                        history[-1]['content'] = thought_buf
                    yield history, debug
                    continue

                # Regular answer text: append it to a single answer entry.
                if not answer_started:
                    history.append({'role': 'assistant', 'content': ''})
                    answer_started = True
                answer_buf += text
                history[-1]['content'] = answer_buf
                yield history, debug

            thread.join()
            yield history, debug

        except Exception as e:
            self.logger.error(f"Chat processing error: {str(e)}")
            # Keep the yielded shape consistent with the success path so the
            # caller does not have to special-case errors.
            history.append({'role': 'assistant', 'content': f"Error: {str(e)}"})
            yield history, debug

    def generate_with_pipeline(self, message, history, model_name, generation_config=None, system_prompt=""):
        """
        Alternative method that uses the Hugging Face pipeline directly.

        This method demonstrates a more direct, non-streaming use of the
        pipeline API.

        Args:
            message (str): The current user message
            history (list): Chat history in the same format accepted by process_chat
            model_name (str): Name of the model to use
            generation_config (dict, optional): Keyword arguments forwarded to the
                pipeline call; sensible defaults are used when None
            system_prompt (str): Optional system prompt to guide the model's behavior

        Returns:
            str: The generated response
        """
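        # Illustrative call (a hedged sketch, not from the original code):
        # generation_config is forwarded as keyword arguments to the pipeline,
        # so a caller might do
        #   processor.generate_with_pipeline(
        #       "Summarise the report.", history=[], model_name="qwen-chat",
        #       generation_config={"max_new_tokens": 256, "do_sample": False})
        # where "processor" and "qwen-chat" are placeholder names.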
        try:
            pipe = self.model_manager.get_pipeline(model_name)

            # Convert the message, history and system prompt into the chat
            # message list expected by the pipeline.
            messages = preprocess_chat_input(message, history, system_prompt)

            if generation_config is None:
                generation_config = {
                    "max_new_tokens": 512,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 50,
                    "repetition_penalty": 1.2,
                    "do_sample": True
                }

            self.logger.debug(f"Running pipeline with messages: {messages}")
            self.logger.debug(f"Generation config: {generation_config}")
            response = pipe(
                messages,
                **generation_config
            )

            # The pipeline may return either a list of results or a single dict;
            # handle both shapes before post-processing.
            if isinstance(response, list):
                return postprocess_response(response[0]["generated_text"])
            else:
                return postprocess_response(response["generated_text"])

        except Exception as e:
            self.logger.error(f"Pipeline generation error: {str(e)}")
            return f"Error: {str(e)}"