Sachi Wagaarachchi committed
Commit 0b7ba67 · 1 Parent(s): 191c0de

debug: updated the pipeline

Files changed (4)
  1. src/app.py +64 -23
  2. src/chat_logic.py +98 -19
  3. src/models.py +45 -13
  4. src/utils.py +115 -17
src/app.py CHANGED
@@ -5,15 +5,15 @@ from src.vector_db import VectorDBHandler
 import logging
 import spaces
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Initialize components
 model_manager = ModelManager()
 vector_db = VectorDBHandler()
 chat_processor = ChatProcessor(model_manager, vector_db)
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
 @spaces.GPU
 def respond(
     message,
@@ -24,28 +24,68 @@ def respond(
     temperature: float = 0.7,
     top_p: float = 0.9,
     top_k: int = 50,
-    repetition_penalty: float = 1.2
+    repetition_penalty: float = 1.2,
+    use_direct_pipeline: bool = False
 ):
-    """Process chat using the ChatProcessor with streaming support"""
+    """
+    Process chat using the ChatProcessor with streaming support.
+
+    Args:
+        message: The user message
+        history: Chat history as list of (user, assistant) message pairs
+        model_name: Name of the model to use
+        system_message: System prompt to guide the model's behavior
+        max_new_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature
+        top_p: Nucleus sampling parameter
+        top_k: Top-k sampling parameter
+        repetition_penalty: Penalty for token repetition
+        use_direct_pipeline: Whether to use the direct pipeline method
+
+    Yields:
+        Generated response tokens for streaming UI
+    """
     try:
-        # Process chat through ChatProcessor
-        response_generator = chat_processor.process_chat(
-            message=message,
-            history=history,
-            model_name=model_name,
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty
-        )
-
-        # Stream response tokens
-        response = ""
-        for token in response_generator:
-            response += token
+        if use_direct_pipeline:
+            # Use the direct pipeline method (non-streaming)
+            generation_config = {
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+                "do_sample": True
+            }
+
+            response = chat_processor.generate_with_pipeline(
+                message=message,
+                history=history,
+                model_name=model_name,
+                generation_config=generation_config,
+                system_prompt=system_message
+            )
+
             yield response
+        else:
+            # Use the streaming method
+            response_generator = chat_processor.process_chat(
+                message=message,
+                history=history,
+                model_name=model_name,
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                system_prompt=system_message
+            )
 
+            # Stream response tokens
+            response = ""
+            for token in response_generator:
+                response += token
+                yield response
+
     except Exception as e:
         logger.error(f"Chat response error: {str(e)}")
         yield f"Error: {str(e)}"
@@ -65,7 +105,8 @@ demo = gr.ChatInterface(
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
-        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty")
+        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
+        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)")
     ],
 )
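
As a quick illustration (a hypothetical driver, not part of the commit), the two branches of respond can be exercised outside Gradio; the keyword names follow the docstring above, and the (user, assistant) tuple history format is an assumption carried over from the ChatInterface wiring:

# Hypothetical smoke test; assumes the components in src/app.py are importable
# and a context where the @spaces.GPU-decorated function can be called directly.
history = [("Hi", "Hello! How can I help?")]

# Streaming branch: respond() repeatedly yields the accumulated response text.
last = ""
for partial in respond("Summarise Qwen3 in one line.", history,
                       model_name="Qwen3-8B", system_message="You are helpful.",
                       max_new_tokens=64, use_direct_pipeline=False):
    last = partial
print(last)

# Direct-pipeline branch: a single yield containing the full response.
print(next(respond("Summarise Qwen3 in one line.", history,
                   model_name="Qwen3-8B", system_message="You are helpful.",
                   max_new_tokens=64, use_direct_pipeline=True)))
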
src/chat_logic.py CHANGED
@@ -1,6 +1,11 @@
 from transformers import TextIteratorStreamer
 import threading
-from src.utils import format_prompt
+from src.utils import (
+    preprocess_chat_input,
+    format_prompt,
+    prepare_generation_inputs,
+    postprocess_response
+)
 import logging
 
 class ChatProcessor:
@@ -9,35 +14,56 @@ class ChatProcessor:
         self.model_manager = model_manager
         self.vector_db = vector_db
         self.logger = logging.getLogger(__name__)
 
     def process_chat(self, message, history, model_name, temperature=0.7,
                      max_new_tokens=512, top_p=0.9, top_k=50, repetition_penalty=1.2,
                      system_prompt=""):
-        """Process chat input and generate streaming response"""
+        """
+        Process chat input and generate streaming response.
+
+        This method handles the complete chat processing pipeline:
+        1. Pre-processing: Format the input with history and system prompt
+        2. Model inference: Generate a response using the specified model
+        3. Post-processing: Stream the response tokens
+
+        Args:
+            message (str): The current user message
+            history (list): List of tuples containing (user_message, assistant_message) pairs
+            model_name (str): Name of the model to use
+            temperature (float): Sampling temperature
+            max_new_tokens (int): Maximum number of tokens to generate
+            top_p (float): Nucleus sampling parameter
+            top_k (int): Top-k sampling parameter
+            repetition_penalty (float): Penalty for token repetition
+            system_prompt (str): Optional system prompt to guide the model's behavior
+
+        Yields:
+            str: Response tokens as they are generated
+        """
         try:
+            # 1. PRE-PROCESSING
             # Get model pipeline
             pipe = self.model_manager.get_pipeline(model_name)
 
             # Format prompt with history and tokenizer
             prompt = format_prompt(message, history, pipe.tokenizer, system_prompt)
 
-            # Set up streamer
+            # Set up streamer for token-by-token generation
             streamer = TextIteratorStreamer(
                 pipe.tokenizer,
                 skip_prompt=True,
                 skip_special_tokens=True
             )
-
-            # Get full tokenizer output
-            tokenized_inputs = pipe.tokenizer(prompt, return_tensors="pt")
-
-            # Determine model device
-            device = pipe.model.device
 
-            # Move all tensors to the correct device
-            inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}
+            # Prepare tokenized inputs
+            inputs_on_device = prepare_generation_inputs(
+                prompt,
+                pipe.tokenizer,
+                pipe.model.device
+            )
 
-            # Prepare generation kwargs with attention_mask
+            # 2. MODEL INFERENCE
+            # Prepare generation parameters
             generate_kwargs = {
                 "input_ids": inputs_on_device["input_ids"],
                 "attention_mask": inputs_on_device["attention_mask"],
@@ -49,19 +75,72 @@ class ChatProcessor:
                 "streamer": streamer
             }
 
-            # Start generation thread
+            # Start generation in a separate thread
             thread = threading.Thread(target=pipe.model.generate, kwargs=generate_kwargs)
             thread.start()
 
-            # Stream response
+            # 3. POST-PROCESSING
+            # Stream response tokens
             response = ""
             for token in streamer:
+                # Accumulate tokens for the complete response
                 response += token
+                # Yield each token for streaming UI
                 yield token
 
-            # Update history (handled by Gradio UI)
-            return response
+            # Return the complete response
+            return postprocess_response(response)
 
         except Exception as e:
             self.logger.error(f"Chat processing error: {str(e)}")
             yield f"Error: {str(e)}"
+
+    def generate_with_pipeline(self, message, history, model_name, generation_config=None, system_prompt=""):
+        """
+        Alternative method that uses the Hugging Face pipeline directly.
+
+        This method demonstrates a more direct use of the pipeline API.
+
+        Args:
+            message (str): The current user message
+            history (list): List of tuples containing (user_message, assistant_message) pairs
+            model_name (str): Name of the model to use
+            generation_config (dict): Configuration for text generation
+            system_prompt (str): Optional system prompt to guide the model's behavior
+
+        Returns:
+            str: The generated response
+        """
+        try:
+            # Get model pipeline
+            pipe = self.model_manager.get_pipeline(model_name)
+
+            # Pre-process: Format messages for the pipeline
+            messages = preprocess_chat_input(message, history, system_prompt)
+
+            # Set default generation config if not provided
+            if generation_config is None:
+                generation_config = {
+                    "max_new_tokens": 512,
+                    "temperature": 0.7,
+                    "top_p": 0.9,
+                    "top_k": 50,
+                    "repetition_penalty": 1.2,
+                    "do_sample": True
+                }
+
+            # Direct pipeline inference
+            response = pipe(
+                messages,
+                **generation_config
+            )
+
+            # Post-process the response
+            if isinstance(response, list):
+                return postprocess_response(response[0]["generated_text"])
+            else:
+                return postprocess_response(response["generated_text"])
+
+        except Exception as e:
+            self.logger.error(f"Pipeline generation error: {str(e)}")
+            return f"Error: {str(e)}"
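
One detail worth noting in process_chat: the final `return postprocess_response(response)` inside a generator does not yield an extra item; it only becomes the StopIteration value, and Gradio simply consumes the yielded tokens. A minimal sketch (hypothetical helper, not part of the commit) of how a non-Gradio caller could recover both the streamed text and that return value:

def collect_stream(token_generator):
    """Consume a token generator and also capture its return value."""
    tokens = []
    while True:
        try:
            tokens.append(next(token_generator))
        except StopIteration as stop:
            # stop.value holds whatever the generator returned, e.g.
            # postprocess_response(response) in process_chat, or None if it
            # exited through the error path.
            return "".join(tokens), stop.value
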
src/models.py CHANGED
@@ -1,5 +1,41 @@
 from transformers import pipeline
 import logging
+from functools import lru_cache
+
+# Global cache for pipelines to ensure they're initialized only once
+_PIPELINE_CACHE = {}
+
+@lru_cache(maxsize=5)
+def get_pipeline(model_id, task="text-generation"):
+    """
+    Get or create a model pipeline with caching.
+
+    This function is cached using lru_cache to ensure efficient reuse.
+
+    Args:
+        model_id (str): The Hugging Face model ID
+        task (str): The pipeline task (default: "text-generation")
+
+    Returns:
+        The pipeline object
+    """
+    cache_key = f"{model_id}_{task}"
+
+    if cache_key in _PIPELINE_CACHE:
+        return _PIPELINE_CACHE[cache_key]
+
+    logger = logging.getLogger(__name__)
+    logger.info(f"Loading model: {model_id} for task: {task}")
+
+    pipe = pipeline(
+        task,
+        model=model_id,
+        device_map="auto"
+    )
+
+    _PIPELINE_CACHE[cache_key] = pipe
+    return pipe
+
 
 class ModelManager:
     """Manages loading and caching of Qwen models"""
@@ -8,23 +44,19 @@ class ModelManager:
             "Qwen3-14B": "Qwen/Qwen3-14B",
             "Qwen3-8B": "Qwen/Qwen3-8B"
         }
-        self._pipelines = {}
         self.logger = logging.getLogger(__name__)
 
     def get_pipeline(self, model_name):
         """Get or create a model pipeline"""
-        if model_name in self._pipelines:
-            return self._pipelines[model_name]
-
         try:
             model_id = self.models[model_name]
-            self.logger.info(f"Loading model: {model_id}")
-            pipe = pipeline(
-                "text-generation",
-                model=model_id,
-                device_map="auto"
-            )
-            self._pipelines[model_name] = pipe
-            return pipe
+            return get_pipeline(model_id)
+        except KeyError:
+            raise ValueError(f"Model {model_name} not found in available models")
+
+    def get_model_id(self, model_name):
+        """Get the model ID for a given model name"""
+        try:
+            return self.models[model_name]
         except KeyError:
             raise ValueError(f"Model {model_name} not found in available models")
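
Design note: `@lru_cache(maxsize=5)` already memoizes the module-level get_pipeline on the (model_id, task) argument pair, so the manual _PIPELINE_CACHE dict is belt-and-braces; either mechanism alone keeps each pipeline loaded only once. A minimal sketch of the lru_cache-only variant (a hypothetical simplification, not what the commit does):

from functools import lru_cache
from transformers import pipeline

@lru_cache(maxsize=5)
def get_pipeline(model_id: str, task: str = "text-generation"):
    # lru_cache keys on (model_id, task); repeat calls with the same
    # arguments return the already-loaded pipeline object.
    return pipeline(task, model=model_id, device_map="auto")
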
src/utils.py CHANGED
@@ -1,7 +1,18 @@
-def format_prompt(message, history, tokenizer, system_prompt=""):
-    """Format message and history into a prompt for Qwen models
+"""
+Utility functions for pre-processing and post-processing in the chat application.
+"""
+
+def preprocess_chat_input(message, history, system_prompt=""):
+    """
+    Pre-process chat input to prepare it for the model.
 
-    Uses tokenizer.apply_chat_template if available, otherwise falls back to manual formatting.
+    Args:
+        message (str): The current user message
+        history (list): List of tuples containing (user_message, assistant_message) pairs
+        system_prompt (str): Optional system prompt to guide the model's behavior
+
+    Returns:
+        dict: Formatted messages in the format expected by the tokenizer
     """
     # Convert history from tuples to dict format expected by apply_chat_template
     formatted_history = []
@@ -12,21 +23,108 @@ def format_prompt(message, history, tokenizer, system_prompt=""):
     # Add current message
     formatted_history.append({"role": "user", "content": message})
 
-    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+    # Add system message if provided
+    if system_prompt.strip():
         messages = [{"role": "system", "content": system_prompt.strip()}] + formatted_history
-        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
+    else:
+        messages = formatted_history
+
+    return messages
+
+
+def format_prompt(message, history, tokenizer, system_prompt=""):
+    """
+    Format message and history into a prompt for Qwen models.
+
+    Uses tokenizer.apply_chat_template if available, otherwise falls back to manual formatting.
+
+    Args:
+        message (str): The current user message
+        history (list): List of tuples containing (user_message, assistant_message) pairs
+        tokenizer: The model tokenizer
+        system_prompt (str): Optional system prompt to guide the model's behavior
+
+    Returns:
+        str: Formatted prompt ready for the model
+    """
+    # Get pre-processed messages
+    messages = preprocess_chat_input(message, history, system_prompt)
+
+    # Apply chat template if available
+    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+        return tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=True
+        )
     else:
         # Fallback for base LMs without chat template
-        prompt = ""
-        if system_prompt.strip():
-            prompt = system_prompt.strip() + "\n"
-
-        for msg in formatted_history:
-            if msg['role'] == 'user':
-                prompt += f"<|User|>: {msg['content'].strip()}\n"
-            elif msg['role'] == 'assistant':
-                prompt += f"<|Assistant|>: {msg['content'].strip()}\n"
-
-        if not prompt.strip().endswith("<|Assistant|>:"):
-            prompt += "<|Assistant|>:"
-        return prompt
+        return format_prompt_fallback(messages)
+
+
+def format_prompt_fallback(messages):
+    """
+    Fallback prompt formatting for models without chat templates.
+
+    Args:
+        messages (list): List of message dictionaries with role and content
+
+    Returns:
+        str: Formatted prompt string
+    """
+    prompt = ""
+
+    # Add system message if present
+    if messages and messages[0]['role'] == 'system':
+        prompt = messages[0]['content'].strip() + "\n"
+        messages = messages[1:]
+
+    # Add conversation history
+    for msg in messages:
+        if msg['role'] == 'user':
+            prompt += f"<|User|>: {msg['content'].strip()}\n"
+        elif msg['role'] == 'assistant':
+            prompt += f"<|Assistant|>: {msg['content'].strip()}\n"
+
+    # Add final assistant prompt if needed
+    if not prompt.strip().endswith("<|Assistant|>:"):
+        prompt += "<|Assistant|>:"
+
+    return prompt
+
+
+def prepare_generation_inputs(prompt, tokenizer, device):
+    """
+    Prepare tokenized inputs for model generation.
+
+    Args:
+        prompt (str): The formatted prompt
+        tokenizer: The model tokenizer
+        device: The device to place tensors on
+
+    Returns:
+        dict: Tokenized inputs ready for model generation
+    """
+    # Tokenize the prompt
+    tokenized_inputs = tokenizer(prompt, return_tensors="pt")
+
+    # Move tensors to the correct device
+    inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}
+
+    return inputs_on_device
+
+
+def postprocess_response(response):
+    """
+    Post-process the model's response.
+
+    Args:
+        response (str): The raw model response
+
+    Returns:
+        str: The processed response
+    """
+    # Currently just returns the response as-is
+    # This function can be expanded for additional post-processing steps
+    return response
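
To make the new helpers concrete, a small illustrative run with toy inputs (hypothetical, assuming the unchanged history-conversion loop builds the usual user/assistant dicts; expected output shown in comments):

from src.utils import preprocess_chat_input, format_prompt_fallback

messages = preprocess_chat_input(
    "How are you?",
    [("Hello", "Hi there!")],
    system_prompt="You are a helpful assistant."
)
# messages:
# [{"role": "system", "content": "You are a helpful assistant."},
#  {"role": "user", "content": "Hello"},
#  {"role": "assistant", "content": "Hi there!"},
#  {"role": "user", "content": "How are you?"}]

print(format_prompt_fallback(messages))
# You are a helpful assistant.
# <|User|>: Hello
# <|Assistant|>: Hi there!
# <|User|>: How are you?
# <|Assistant|>:
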