from transformers import TextIteratorStreamer
import threading
from src.utils import (
    preprocess_chat_input,
    format_prompt,
    prepare_generation_inputs,
    postprocess_response
)
import logging

class ChatProcessor:
    """Processes chat interactions using Qwen models"""
    def __init__(self, model_manager, vector_db):
        self.model_manager = model_manager
        self.vector_db = vector_db
        self.logger = logging.getLogger(__name__)
    
    def process_chat(self, message, history, model_name, temperature=0.7,
                    max_new_tokens=512, top_p=0.9, top_k=50, repetition_penalty=1.2,
                    system_prompt=""):
        """
        Process chat input and generate streaming response.
        
        This method handles the complete chat processing pipeline:
        1. Pre-processing: Format the input with history and system prompt
        2. Model inference: Generate a response using the specified model
        3. Post-processing: Split any '<think>' reasoning from the answer and stream history updates
        
        Args:
            message (str): The current user message
            history (list): Chat history as a list of message dicts with 'role' and
                'content' keys (Gradio "messages" format); assistant entries are
                appended to it in place as the response streams
            model_name (str): Name of the model to use
            temperature (float): Sampling temperature
            max_new_tokens (int): Maximum number of tokens to generate
            top_p (float): Nucleus sampling parameter
            top_k (int): Top-k sampling parameter
            repetition_penalty (float): Penalty for token repetition
            system_prompt (str): Optional system prompt to guide the model's behavior
            
        Yields:
            tuple: (history, debug), where history is the updated chat history
                (including any '💭 Thought' entry and the partially streamed
                assistant reply) and debug is a debug string
        """

        cancel_event = threading.Event()
        debug = ''
        try:
            # 1. PRE-PROCESSING
            # Get model pipeline
            pipe = self.model_manager.get_pipeline(model_name)
            
            # Format prompt with history and tokenizer
            prompt = format_prompt(message, history, pipe.tokenizer, system_prompt)
            
            # Set up streamer for token-by-token generation
            streamer = TextIteratorStreamer(
                pipe.tokenizer, 
                skip_prompt=True,
                skip_special_tokens=True
            )
            
            # Prepare tokenized inputs
            inputs_on_device = prepare_generation_inputs(
                prompt, 
                pipe.tokenizer, 
                pipe.model.device
            )
            
            # 2. MODEL INFERENCE
            # Prepare generation parameters
            generate_kwargs = {
                "input_ids": inputs_on_device["input_ids"],
                "attention_mask": inputs_on_device["attention_mask"],
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
                "repetition_penalty": repetition_penalty,
                "streamer": streamer
            }

            print(f"Running generate with kwargs: {generate_kwargs}")
            
            # Start generation in a separate thread
            thread = threading.Thread(target=pipe.model.generate, kwargs=generate_kwargs)
            thread.start()
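
            # Qwen-style reasoning checkpoints may wrap their chain of thought in
            # <think> ... </think> tags before the final answer (an assumption
            # about the output format this handler targets). The loop below routes
            # that span into a separate '💭 Thought' history entry and streams
            # everything after </think> as the visible answer.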
            # Buffers and state for separating thought text from answer text
            thought_buf = ''
            answer_buf = ''
            answer_started = False
            in_thought = False

            # Stream tokens
            for chunk in streamer:
                if cancel_event.is_set():
                    break
                text = chunk

                # Detect start of thinking
                if not in_thought and '<think>' in text:
                    in_thought = True
                    # Insert thought placeholder
                    history.append({
                        'role': 'assistant',
                        'content': '',
                        'metadata': {'title': '💭 Thought'}
                    })
                    # Capture after opening tag
                    after = text.split('<think>', 1)[1]
                    thought_buf += after
                    # If closing tag in same chunk
                    if '</think>' in thought_buf:
                        before, after2 = thought_buf.split('</think>', 1)
                        history[-1]['content'] = before.strip()
                        in_thought = False
                        # Start answer buffer
                        answer_buf = after2
                        answer_started = True
                        history.append({'role': 'assistant', 'content': answer_buf})
                    else:
                        history[-1]['content'] = thought_buf
                    yield history, debug
                    continue

                # Continue thought streaming
                if in_thought:
                    thought_buf += text
                    if '</think>' in thought_buf:
                        before, after2 = thought_buf.split('</think>', 1)
                        history[-1]['content'] = before.strip()
                        in_thought = False
                        # Start answer buffer
                        answer_buf = after2
                        answer_started = True
                        history.append({'role': 'assistant', 'content': answer_buf})
                    else:
                        history[-1]['content'] = thought_buf
                    yield history, debug
                    continue

                # Stream answer (no <think> block, or the reasoning has already closed)
                if not answer_started:
                    history.append({'role': 'assistant', 'content': ''})
                    answer_started = True
                answer_buf += text
                history[-1]['content'] = answer_buf
                yield history, debug

            thread.join()
            yield history, debug

            # 3. POST-PROCESSING
            # Splitting the thought from the answer is handled inline in the streaming
            # loop above; the commented-out block below is a simpler variant that
            # yields raw tokens instead of (history, debug) pairs:
            # response = ""
            # for token in streamer:
            #     response += token
            #     # Yield each token for streaming UI
            #     yield token
            
            # # Post-process the complete response
            # processed_response = postprocess_response(response)
            # # Yield the final processed response
            # yield processed_response
            
        except Exception as e:
            self.logger.error(f"Chat processing error: {str(e)}")
            yield f"Error: {str(e)}"
    
    def generate_with_pipeline(self, message, history, model_name, generation_config=None, system_prompt=""):
        """
        Alternative method that uses the Hugging Face pipeline directly.
        
        This method demonstrates a more direct use of the pipeline API.
        
        Args:
            message (str): The current user message
            history (list): List of tuples containing (user_message, assistant_message) pairs
            model_name (str): Name of the model to use
            generation_config (dict): Configuration for text generation
            system_prompt (str): Optional system prompt to guide the model's behavior
            
        Returns:
            str: The generated response
        """
        try:
            # Get model pipeline
            pipe = self.model_manager.get_pipeline(model_name)
            
            # Pre-process: Format messages for the pipeline
            messages = preprocess_chat_input(message, history, system_prompt)
            
            # Set default generation config if not provided
            if generation_config is None:
                generation_config = {
                    "max_new_tokens": 512,
                    "temperature": 0.7,
                    "top_p": 0.9,
                    "top_k": 50,
                    "repetition_penalty": 1.2,
                    "do_sample": True
                }
            
            # Direct pipeline inference
            print(f"Running pipeline with messages: {messages}")
            print(f"Generation config: {generation_config}")
            response = pipe(
                messages,
                **generation_config
            )
            
            # Post-process the response
            if isinstance(response, list):
                return postprocess_response(response[0]["generated_text"])
            else:
                return postprocess_response(response["generated_text"])
                
        except Exception as e:
            self.logger.error(f"Pipeline generation error: {str(e)}")
            return f"Error: {str(e)}"