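"""Chat processing for Qwen models via Hugging Face transformers: streaming generation and a direct pipeline-based alternative."""
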
import logging
import threading

from transformers import TextIteratorStreamer

from src.utils import (
    preprocess_chat_input,
    format_prompt,
    prepare_generation_inputs,
    postprocess_response
)

class ChatProcessor:
    """Processes chat interactions using Qwen models"""
    def __init__(self, model_manager, vector_db):
        self.model_manager = model_manager
        self.vector_db = vector_db
        self.logger = logging.getLogger(__name__)
    
    def process_chat(self, message, history, model_name, temperature=0.7,
                    max_new_tokens=512, top_p=0.9, top_k=50, repetition_penalty=1.2,
                    system_prompt=""):
        """
        Process chat input and generate streaming response.
        
        This method handles the complete chat processing pipeline:
        1. Pre-processing: Format the input with history and system prompt
        2. Model inference: Generate a response using the specified model
        3. Post-processing: Stream the response tokens
        
        Args:
            message (str): The current user message
            history (list): List of tuples containing (user_message, assistant_message) pairs
            model_name (str): Name of the model to use
            temperature (float): Sampling temperature
            max_new_tokens (int): Maximum number of tokens to generate
            top_p (float): Nucleus sampling parameter
            top_k (int): Top-k sampling parameter
            repetition_penalty (float): Penalty for token repetition
            system_prompt (str): Optional system prompt to guide the model's behavior
            
        Yields:
            str: Response tokens as they are generated
        """
        try:
            # 1. PRE-PROCESSING
            # Get model pipeline
            pipe = self.model_manager.get_pipeline(model_name)
            
            # Format prompt with history and tokenizer
            prompt = format_prompt(message, history, pipe.tokenizer, system_prompt)
            
            # Set up streamer for token-by-token generation
            streamer = TextIteratorStreamer(
                pipe.tokenizer, 
                skip_prompt=True,
                skip_special_tokens=True
            )
            
            # Prepare tokenized inputs
            inputs_on_device = prepare_generation_inputs(
                prompt, 
                pipe.tokenizer, 
                pipe.model.device
            )
            
            # 2. MODEL INFERENCE
            # Prepare generation parameters
            generate_kwargs = {
                "input_ids": inputs_on_device["input_ids"],
                "attention_mask": inputs_on_device["attention_mask"],
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                "do_sample": True,  # sampling must be enabled for temperature/top_p/top_k to take effect
                "streamer": streamer
            }
            
            # Start generation in a separate thread
            thread = threading.Thread(target=pipe.model.generate, kwargs=generate_kwargs)
            thread.start()
            
            # 3. POST-PROCESSING
            # Stream response tokens
            response = ""
            for token in streamer:
                # Accumulate tokens for the complete response
                response += token
                # Yield each token for streaming UI
                yield token
            
            # Ensure the generation thread has finished, then return the complete
            # post-processed response (exposed to callers via StopIteration.value)
            thread.join()
            return postprocess_response(response)
            
        except Exception as e:
            self.logger.error(f"Chat processing error: {str(e)}")
            yield f"Error: {str(e)}"
    
    def generate_with_pipeline(self, message, history, model_name, generation_config=None, system_prompt=""):
        """
        Alternative method that uses the Hugging Face pipeline directly.
        
        This method demonstrates a more direct use of the pipeline API.
        
        Args:
            message (str): The current user message
            history (list): List of tuples containing (user_message, assistant_message) pairs
            model_name (str): Name of the model to use
            generation_config (dict): Configuration for text generation
            system_prompt (str): Optional system prompt to guide the model's behavior
            
        Returns:
            str: The generated response
        """
        try:
            # Get model pipeline
            pipe = self.model_manager.get_pipeline(model_name)
            
            # Pre-process: Format messages for the pipeline
            messages = preprocess_chat_input(message, history, system_prompt)
            
            # Set default generation config if not provided
            if generation_config is None:
                generation_config = {
                    "max_new_tokens": 512,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 50,
                    "repetition_penalty": 1.2,
                    "do_sample": True
                }
            
            # Direct pipeline inference
            response = pipe(
                messages,
                **generation_config
            )
            
            # Post-process the response
            if isinstance(response, list):
                return postprocess_response(response[0]["generated_text"])
            else:
                return postprocess_response(response["generated_text"])
                
        except Exception as e:
            self.logger.error(f"Pipeline generation error: {str(e)}")
            return f"Error: {str(e)}"