Sachi Wagaarachchi committed
Commit 0b7ba67 · Parent(s): 191c0de

debug: updated the pipeline

Files changed:
- src/app.py +64 -23
- src/chat_logic.py +98 -19
- src/models.py +45 -13
- src/utils.py +115 -17
src/app.py
CHANGED
@@ -5,15 +5,15 @@ from src.vector_db import VectorDBHandler
 import logging
 import spaces
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 # Initialize components
 model_manager = ModelManager()
 vector_db = VectorDBHandler()
 chat_processor = ChatProcessor(model_manager, vector_db)
 
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
 @spaces.GPU
 def respond(
     message,
@@ -24,28 +24,68 @@ def respond(
     temperature: float = 0.7,
     top_p: float = 0.9,
     top_k: int = 50,
-    repetition_penalty: float = 1.2
+    repetition_penalty: float = 1.2,
+    use_direct_pipeline: bool = False
 ):
-    """
+    """
+    Process chat using the ChatProcessor with streaming support.
+
+    Args:
+        message: The user message
+        history: Chat history as list of (user, assistant) message pairs
+        model_name: Name of the model to use
+        system_message: System prompt to guide the model's behavior
+        max_new_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature
+        top_p: Nucleus sampling parameter
+        top_k: Top-k sampling parameter
+        repetition_penalty: Penalty for token repetition
+        use_direct_pipeline: Whether to use the direct pipeline method
+
+    Yields:
+        Generated response tokens for streaming UI
+    """
     try:
+        if use_direct_pipeline:
+            # Use the direct pipeline method (non-streaming)
+            generation_config = {
+                "max_new_tokens": max_new_tokens,
+                "temperature": temperature,
+                "top_p": top_p,
+                "top_k": top_k,
+                "repetition_penalty": repetition_penalty,
+                "do_sample": True
+            }
+
+            response = chat_processor.generate_with_pipeline(
+                message=message,
+                history=history,
+                model_name=model_name,
+                generation_config=generation_config,
+                system_prompt=system_message
+            )
+
             yield response
+        else:
+            # Use the streaming method
+            response_generator = chat_processor.process_chat(
+                message=message,
+                history=history,
+                model_name=model_name,
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                system_prompt=system_message
+            )
 
+            # Stream response tokens
+            response = ""
+            for token in response_generator:
+                response += token
+                yield response
+
     except Exception as e:
         logger.error(f"Chat response error: {str(e)}")
         yield f"Error: {str(e)}"
@@ -65,7 +105,8 @@ demo = gr.ChatInterface(
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"),
         gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
-        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty")
+        gr.Slider(minimum=1.0, maximum=2.0, value=1.2, step=0.1, label="Repetition penalty"),
+        gr.Checkbox(value=False, label="Use direct pipeline (non-streaming)")
     ],
 )
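
For orientation, a minimal sketch (not part of the commit) of driving the updated respond generator outside the Gradio UI. The parameters between message and temperature are not visible in this hunk, so the positional arguments below (history, model_name, system_message, max_new_tokens) are assumed from the docstring order, and "Qwen3-8B" is one of the model keys defined in src/models.py.

# Sketch only; assumes it runs inside src/app.py, where respond is defined.
for partial in respond(
    "What is attention?",            # message
    [],                              # history
    "Qwen3-8B",                      # model_name (assumed key from src/models.py)
    "You are a helpful assistant.",  # system_message
    256,                             # max_new_tokens
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
    use_direct_pipeline=False,       # True yields one complete response instead
):
    print(partial)                   # streaming branch yields the accumulated text so far
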
src/chat_logic.py
CHANGED
@@ -1,6 +1,11 @@
 from transformers import TextIteratorStreamer
 import threading
-from src.utils import
+from src.utils import (
+    preprocess_chat_input,
+    format_prompt,
+    prepare_generation_inputs,
+    postprocess_response
+)
 import logging
 
 class ChatProcessor:
@@ -9,35 +14,56 @@ class ChatProcessor:
         self.model_manager = model_manager
         self.vector_db = vector_db
         self.logger = logging.getLogger(__name__)
 
     def process_chat(self, message, history, model_name, temperature=0.7,
                      max_new_tokens=512, top_p=0.9, top_k=50, repetition_penalty=1.2,
                      system_prompt=""):
-        """
+        """
+        Process chat input and generate streaming response.
+
+        This method handles the complete chat processing pipeline:
+        1. Pre-processing: Format the input with history and system prompt
+        2. Model inference: Generate a response using the specified model
+        3. Post-processing: Stream the response tokens
+
+        Args:
+            message (str): The current user message
+            history (list): List of tuples containing (user_message, assistant_message) pairs
+            model_name (str): Name of the model to use
+            temperature (float): Sampling temperature
+            max_new_tokens (int): Maximum number of tokens to generate
+            top_p (float): Nucleus sampling parameter
+            top_k (int): Top-k sampling parameter
+            repetition_penalty (float): Penalty for token repetition
+            system_prompt (str): Optional system prompt to guide the model's behavior
+
+        Yields:
+            str: Response tokens as they are generated
+        """
         try:
+            # 1. PRE-PROCESSING
             # Get model pipeline
             pipe = self.model_manager.get_pipeline(model_name)
 
             # Format prompt with history and tokenizer
             prompt = format_prompt(message, history, pipe.tokenizer, system_prompt)
 
-            # Set up streamer
+            # Set up streamer for token-by-token generation
             streamer = TextIteratorStreamer(
                 pipe.tokenizer,
                 skip_prompt=True,
                 skip_special_tokens=True
             )
 
-            # Get full tokenizer output
-            tokenized_inputs = pipe.tokenizer(prompt, return_tensors="pt")
-
-            # Determine model device
-            device = pipe.model.device
-            inputs_on_device =
+            # Prepare tokenized inputs
+            inputs_on_device = prepare_generation_inputs(
+                prompt,
+                pipe.tokenizer,
+                pipe.model.device
+            )
 
+            # 2. MODEL INFERENCE
+            # Prepare generation parameters
             generate_kwargs = {
                 "input_ids": inputs_on_device["input_ids"],
                 "attention_mask": inputs_on_device["attention_mask"],
@@ -49,19 +75,72 @@ class ChatProcessor:
                 "streamer": streamer
             }
 
-            # Start generation thread
+            # Start generation in a separate thread
             thread = threading.Thread(target=pipe.model.generate, kwargs=generate_kwargs)
             thread.start()
 
+            # 3. POST-PROCESSING
+            # Stream response tokens
             response = ""
             for token in streamer:
+                # Accumulate tokens for the complete response
                 response += token
+                # Yield each token for streaming UI
                 yield token
 
-            return response
+            # Return the complete response
+            return postprocess_response(response)
 
         except Exception as e:
             self.logger.error(f"Chat processing error: {str(e)}")
             yield f"Error: {str(e)}"
+
+    def generate_with_pipeline(self, message, history, model_name, generation_config=None, system_prompt=""):
+        """
+        Alternative method that uses the Hugging Face pipeline directly.
+
+        This method demonstrates a more direct use of the pipeline API.
+
+        Args:
+            message (str): The current user message
+            history (list): List of tuples containing (user_message, assistant_message) pairs
+            model_name (str): Name of the model to use
+            generation_config (dict): Configuration for text generation
+            system_prompt (str): Optional system prompt to guide the model's behavior
+
+        Returns:
+            str: The generated response
+        """
+        try:
+            # Get model pipeline
+            pipe = self.model_manager.get_pipeline(model_name)
+
+            # Pre-process: Format messages for the pipeline
+            messages = preprocess_chat_input(message, history, system_prompt)
+
+            # Set default generation config if not provided
+            if generation_config is None:
+                generation_config = {
+                    "max_new_tokens": 512,
+                    "temperature": 0.7,
+                    "top_p": 0.9,
+                    "top_k": 50,
+                    "repetition_penalty": 1.2,
+                    "do_sample": True
+                }
+
+            # Direct pipeline inference
+            response = pipe(
+                messages,
+                **generation_config
+            )
+
+            # Post-process the response
+            if isinstance(response, list):
+                return postprocess_response(response[0]["generated_text"])
+            else:
+                return postprocess_response(response["generated_text"])
+
+        except Exception as e:
+            self.logger.error(f"Pipeline generation error: {str(e)}")
+            return f"Error: {str(e)}"
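
As a usage sketch (not part of the commit), the two ChatProcessor entry points added here differ in how results come back; construction and imports are assumed to mirror src/app.py.

# Sketch only; wiring assumed to follow src/app.py.
from src.models import ModelManager
from src.vector_db import VectorDBHandler
from src.chat_logic import ChatProcessor

processor = ChatProcessor(ModelManager(), VectorDBHandler())

# Streaming path: process_chat is a generator that yields individual tokens.
reply = ""
for token in processor.process_chat(
    message="Hello", history=[], model_name="Qwen3-8B", system_prompt=""
):
    reply += token

# Direct path: generate_with_pipeline returns the full string in one call.
reply = processor.generate_with_pipeline(
    message="Hello", history=[], model_name="Qwen3-8B"
)

Because process_chat is a generator, its final return postprocess_response(response) only surfaces as the value attached to StopIteration; plain iteration, as in respond() in src/app.py, never sees it and relies on the yielded tokens instead.
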
src/models.py
CHANGED
@@ -1,5 +1,41 @@
 from transformers import pipeline
 import logging
+from functools import lru_cache
+
+# Global cache for pipelines to ensure they're initialized only once
+_PIPELINE_CACHE = {}
+
+@lru_cache(maxsize=5)
+def get_pipeline(model_id, task="text-generation"):
+    """
+    Get or create a model pipeline with caching.
+
+    This function is cached using lru_cache to ensure efficient reuse.
+
+    Args:
+        model_id (str): The Hugging Face model ID
+        task (str): The pipeline task (default: "text-generation")
+
+    Returns:
+        The pipeline object
+    """
+    cache_key = f"{model_id}_{task}"
+
+    if cache_key in _PIPELINE_CACHE:
+        return _PIPELINE_CACHE[cache_key]
+
+    logger = logging.getLogger(__name__)
+    logger.info(f"Loading model: {model_id} for task: {task}")
+
+    pipe = pipeline(
+        task,
+        model=model_id,
+        device_map="auto"
+    )
+
+    _PIPELINE_CACHE[cache_key] = pipe
+    return pipe
+
 
 class ModelManager:
     """Manages loading and caching of Qwen models"""
@@ -8,23 +44,19 @@ class ModelManager:
             "Qwen3-14B": "Qwen/Qwen3-14B",
             "Qwen3-8B": "Qwen/Qwen3-8B"
         }
-        self._pipelines = {}
         self.logger = logging.getLogger(__name__)
 
     def get_pipeline(self, model_name):
         """Get or create a model pipeline"""
-        if model_name in self._pipelines:
-            return self._pipelines[model_name]
-
         try:
             model_id = self.models[model_name]
-            return
+            return get_pipeline(model_id)
+        except KeyError:
+            raise ValueError(f"Model {model_name} not found in available models")
+
+    def get_model_id(self, model_name):
+        """Get the model ID for a given model name"""
+        try:
+            return self.models[model_name]
         except KeyError:
             raise ValueError(f"Model {model_name} not found in available models")
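
A short sketch (not part of the commit) of how the new caching behaves for repeated lookups:

# Sketch only: repeated lookups of the same model reuse one pipeline object.
from src.models import ModelManager

manager = ModelManager()
pipe_a = manager.get_pipeline("Qwen3-8B")   # first call loads the model via the module-level get_pipeline
pipe_b = manager.get_pipeline("Qwen3-8B")   # second call is served from the cache
assert pipe_a is pipe_b

Since the module-level get_pipeline is already memoised by lru_cache on (model_id, task), the manual _PIPELINE_CACHE lookup acts as a second layer of the same cache; either mechanism alone would keep each pipeline loaded only once.
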
src/utils.py
CHANGED
@@ -1,7 +1,18 @@
+"""
+Utility functions for pre-processing and post-processing in the chat application.
+"""
+
+def preprocess_chat_input(message, history, system_prompt=""):
+    """
+    Pre-process chat input to prepare it for the model.
 
+    Args:
+        message (str): The current user message
+        history (list): List of tuples containing (user_message, assistant_message) pairs
+        system_prompt (str): Optional system prompt to guide the model's behavior
+
+    Returns:
+        dict: Formatted messages in the format expected by the tokenizer
     """
     # Convert history from tuples to dict format expected by apply_chat_template
     formatted_history = []
@@ -12,21 +23,108 @@ def format_prompt(message, history, tokenizer, system_prompt=""):
     # Add current message
     formatted_history.append({"role": "user", "content": message})
 
+    # Add system message if provided
+    if system_prompt.strip():
         messages = [{"role": "system", "content": system_prompt.strip()}] + formatted_history
+    else:
+        messages = formatted_history
+
+    return messages
+
+
+def format_prompt(message, history, tokenizer, system_prompt=""):
+    """
+    Format message and history into a prompt for Qwen models.
+
+    Uses tokenizer.apply_chat_template if available, otherwise falls back to manual formatting.
+
+    Args:
+        message (str): The current user message
+        history (list): List of tuples containing (user_message, assistant_message) pairs
+        tokenizer: The model tokenizer
+        system_prompt (str): Optional system prompt to guide the model's behavior
+
+    Returns:
+        str: Formatted prompt ready for the model
+    """
+    # Get pre-processed messages
+    messages = preprocess_chat_input(message, history, system_prompt)
+
+    # Apply chat template if available
+    if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+        return tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True,
+            enable_thinking=True
+        )
     else:
         # Fallback for base LMs without chat template
+        return format_prompt_fallback(messages)
+
+
+def format_prompt_fallback(messages):
+    """
+    Fallback prompt formatting for models without chat templates.
+
+    Args:
+        messages (list): List of message dictionaries with role and content
+
+    Returns:
+        str: Formatted prompt string
+    """
+    prompt = ""
+
+    # Add system message if present
+    if messages and messages[0]['role'] == 'system':
+        prompt = messages[0]['content'].strip() + "\n"
+        messages = messages[1:]
+
+    # Add conversation history
+    for msg in messages:
+        if msg['role'] == 'user':
+            prompt += f"<|User|>: {msg['content'].strip()}\n"
+        elif msg['role'] == 'assistant':
+            prompt += f"<|Assistant|>: {msg['content'].strip()}\n"
+
+    # Add final assistant prompt if needed
+    if not prompt.strip().endswith("<|Assistant|>:"):
+        prompt += "<|Assistant|>:"
+
+    return prompt
+
+
+def prepare_generation_inputs(prompt, tokenizer, device):
+    """
+    Prepare tokenized inputs for model generation.
+
+    Args:
+        prompt (str): The formatted prompt
+        tokenizer: The model tokenizer
+        device: The device to place tensors on
+
+    Returns:
+        dict: Tokenized inputs ready for model generation
+    """
+    # Tokenize the prompt
+    tokenized_inputs = tokenizer(prompt, return_tensors="pt")
+
+    # Move tensors to the correct device
+    inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}
+
+    return inputs_on_device
+
+
+def postprocess_response(response):
+    """
+    Post-process the model's response.
+
+    Args:
+        response (str): The raw model response
 
+    Returns:
+        str: The processed response
+    """
+    # Currently just returns the response as-is
+    # This function can be expanded for additional post-processing steps
+    return response
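
To illustrate the new helpers, a sketch (not part of the commit) of the template-free fallback path, which needs no tokenizer. It assumes the unchanged history-conversion loop, not shown in this hunk, maps each (user, assistant) pair to user/assistant role dicts.

# Sketch only: exercising the fallback path without a model or tokenizer.
from src.utils import preprocess_chat_input, format_prompt_fallback

messages = preprocess_chat_input(
    "How do transformers work?",
    history=[("Hi", "Hello! How can I help?")],
    system_prompt="You are a concise assistant.",
)
print(format_prompt_fallback(messages))
# You are a concise assistant.
# <|User|>: Hi
# <|Assistant|>: Hello! How can I help?
# <|User|>: How do transformers work?
# <|Assistant|>: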