Spaces:
Build error
Build error
Add serve.py for model deployment and API integration, update requirements.txt for smolagents with vllm support, and enhance .gitignore to exclude memory snapshot files. Additionally, implement testing configuration in config.yaml and modify train.py for memory tracking and model saving in VLLM format.
Browse files- .gitignore +1 -0
- conf/config.yaml +9 -0
- requirements.txt +1 -1
- serve.py +365 -0
- train.py +149 -7
.gitignore
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
.env
|
2 |
logs
|
3 |
lora_model
|
|
|
4 |
outputs
|
5 |
__pycache__
|
6 |
.pytest_cache
|
|
|
1 |
.env
|
2 |
logs
|
3 |
lora_model
|
4 |
+
memory_snapshot.pickle
|
5 |
outputs
|
6 |
__pycache__
|
7 |
.pytest_cache
|
conf/config.yaml
CHANGED
@@ -70,3 +70,12 @@ output:
|
|
70 |
|
71 |
# Training control
|
72 |
train: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
# Training control
|
72 |
train: false
|
73 |
+
|
74 |
+
# Testing configuration
|
75 |
+
test: true # Whether to run testing after training
|
76 |
+
test_dataset:
|
77 |
+
name: "gaia-benchmark/GAIA"
|
78 |
+
config: "2023_level1" # Use level 1 questions for testing
|
79 |
+
split: "test" # Use test split for testing
|
80 |
+
max_samples: 10 # Number of samples to test on
|
81 |
+
max_length: 2048 # Maximum sequence length for testing
|
requirements.txt
CHANGED
@@ -27,7 +27,7 @@ pytest-cov>=6.1.1
|
|
27 |
python-dotenv>=1.0.0
|
28 |
requests>=2.32.3
|
29 |
sentence-transformers>=4.1.0
|
30 |
-
smolagents[litellm,telemetry]>=1.14.0
|
31 |
tensorboardX>=2.6.2.2
|
32 |
trl>=0.17.0
|
33 |
typing-extensions>=4.5.0
|
|
|
27 |
python-dotenv>=1.0.0
|
28 |
requests>=2.32.3
|
29 |
sentence-transformers>=4.1.0
|
30 |
+
smolagents[litellm,telemetry,vllm]>=1.14.0
|
31 |
tensorboardX>=2.6.2.2
|
32 |
trl>=0.17.0
|
33 |
typing-extensions>=4.5.0
|
serve.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
from pprint import pprint
|
6 |
+
from threading import Thread
|
7 |
+
from typing import Any, Dict, List
|
8 |
+
|
9 |
+
from fastapi import FastAPI, Request
|
10 |
+
from openai.types.chat.chat_completion import ChatCompletion
|
11 |
+
from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
|
12 |
+
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
13 |
+
from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunkChoice
|
14 |
+
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
15 |
+
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
16 |
+
from openai.types.chat.completion_create_params import CompletionCreateParams
|
17 |
+
from pydantic import TypeAdapter
|
18 |
+
from ray import serve
|
19 |
+
from sse_starlette import EventSourceResponse
|
20 |
+
from starlette.responses import JSONResponse
|
21 |
+
from transformers.generation.streamers import AsyncTextIteratorStreamer
|
22 |
+
from transformers.image_utils import load_image
|
23 |
+
from unsloth import FastVisionModel
|
24 |
+
|
25 |
+
dtype = (
|
26 |
+
None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
27 |
+
)
|
28 |
+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
|
29 |
+
max_seq_length = 2048 # Supports RoPE Scaling interally, so choose any!
|
30 |
+
# max_seq_length = 4096 # Choose any! We auto support RoPE Scaling internally!
|
31 |
+
|
32 |
+
|
33 |
+
logger = logging.getLogger("ray.serve")
|
34 |
+
|
35 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
36 |
+
|
37 |
+
app = FastAPI()
|
38 |
+
|
39 |
+
# middlewares = [
|
40 |
+
# middleware
|
41 |
+
# for middleware in ConnexionMiddleware.default_middlewares
|
42 |
+
# if middleware is not SecurityMiddleware
|
43 |
+
# ]
|
44 |
+
|
45 |
+
# connexion_app = AsyncApp(import_name=__name__, middlewares=middlewares)
|
46 |
+
|
47 |
+
# connexion_app.add_api(
|
48 |
+
# # "api/openai/v1/openapi/openapi.yaml",
|
49 |
+
# "api/v1/openapi/openapi.yaml",
|
50 |
+
# # base_path="/openai/v1",
|
51 |
+
# base_path="/v1",
|
52 |
+
# pythonic_params=True,
|
53 |
+
# resolver_error=501,
|
54 |
+
# )
|
55 |
+
|
56 |
+
# # fastapi_app.mount("/api", ConnexionMiddleware(app=connexion_app, import_name=__name__))
|
57 |
+
# # app.mount("/api", ConnexionMiddleware(app=connexion_app, import_name=__name__))
|
58 |
+
# app.mount(
|
59 |
+
# "/",
|
60 |
+
# ConnexionMiddleware(
|
61 |
+
# app=connexion_app,
|
62 |
+
# import_name=__name__,
|
63 |
+
# # middlewares=middlewares,
|
64 |
+
# ),
|
65 |
+
# )
|
66 |
+
|
67 |
+
|
68 |
+
@serve.deployment(
|
69 |
+
autoscaling_config={
|
70 |
+
# https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#required-define-upper-and-lower-autoscaling-limits
|
71 |
+
"max_replicas": 1,
|
72 |
+
"min_replicas": 1, # TOOD: set to 0
|
73 |
+
"target_ongoing_requests": 2, # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#target-ongoing-requests-default-2
|
74 |
+
},
|
75 |
+
max_ongoing_requests=5, # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#max-ongoing-requests-default-5
|
76 |
+
ray_actor_options={"num_gpus": 1},
|
77 |
+
)
|
78 |
+
@serve.ingress(app)
|
79 |
+
class ModelDeployment:
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
model_name: str,
|
83 |
+
):
|
84 |
+
self.model_name = model_name
|
85 |
+
|
86 |
+
model, processor = FastVisionModel.from_pretrained(
|
87 |
+
load_in_4bit=load_in_4bit,
|
88 |
+
max_seq_length=max_seq_length,
|
89 |
+
model_name=self.model_name,
|
90 |
+
)
|
91 |
+
|
92 |
+
with open("chat_template.txt", "r") as f:
|
93 |
+
processor.chat_template = f.read()
|
94 |
+
processor.tokenizer.chat_template = processor.chat_template
|
95 |
+
|
96 |
+
FastVisionModel.for_inference(model) # Enable native 2x faster inference
|
97 |
+
|
98 |
+
self.model = model
|
99 |
+
self.processor = processor
|
100 |
+
|
101 |
+
def reconfigure(self, config: Dict[str, Any]):
|
102 |
+
print("=== reconfigure ===")
|
103 |
+
print("config:")
|
104 |
+
print(config)
|
105 |
+
# https://docs.ray.io/en/latest/serve/production-guide/config.html#dynamically-change-parameters-without-restarting-replicas-user-config
|
106 |
+
|
107 |
+
@app.post("/v1/chat/completions")
|
108 |
+
async def create_chat_completion(self, body: dict, raw_request: Request):
|
109 |
+
"""Creates a model response for the given chat conversation. Learn more in the [text generation](/docs/guides/text-generation), [vision](/docs/guides/vision), and [audio](/docs/guides/audio) guides. Parameter support can differ depending on the model used to generate the response, particularly for newer reasoning models. Parameters that are only supported for reasoning models are noted below. For the current state of unsupported parameters in reasoning models, [refer to the reasoning guide](/docs/guides/reasoning).
|
110 |
+
|
111 |
+
# noqa: E501
|
112 |
+
|
113 |
+
:param create_chat_completion_request:
|
114 |
+
:type create_chat_completion_request: dict | bytes
|
115 |
+
|
116 |
+
:rtype: Union[CreateChatCompletionResponse, Tuple[CreateChatCompletionResponse, int], Tuple[CreateChatCompletionResponse, int, Dict[str, str]]
|
117 |
+
"""
|
118 |
+
print("=== create_chat_completion ===")
|
119 |
+
|
120 |
+
print("body:")
|
121 |
+
pprint(body)
|
122 |
+
|
123 |
+
ta = TypeAdapter(CompletionCreateParams)
|
124 |
+
|
125 |
+
print("ta.validate_python...")
|
126 |
+
pprint(ta.validate_python(body))
|
127 |
+
|
128 |
+
max_new_tokens = body.get("max_completion_tokens", body.get("max_tokens"))
|
129 |
+
messages = body.get("messages")
|
130 |
+
model_name = body.get("model")
|
131 |
+
stream = body.get("stream", False)
|
132 |
+
temperature = body.get("temperature")
|
133 |
+
tools = body.get("tools")
|
134 |
+
|
135 |
+
images = []
|
136 |
+
|
137 |
+
for message in messages:
|
138 |
+
for content in message["content"]:
|
139 |
+
if "type" in content and content["type"] == "image_url":
|
140 |
+
image_url = content["image_url"]["url"]
|
141 |
+
image = load_image(image_url)
|
142 |
+
images.append(image)
|
143 |
+
|
144 |
+
content["type"] = "image"
|
145 |
+
del content["image_url"]
|
146 |
+
|
147 |
+
images = images if images else None
|
148 |
+
|
149 |
+
if model_name != self.model_name:
|
150 |
+
# adapter_path = model_name
|
151 |
+
# self.model.load_adapter(adapter_path)
|
152 |
+
|
153 |
+
return JSONResponse(content={"error": "Model not found"}, status_code=404)
|
154 |
+
|
155 |
+
prompt = self.processor.apply_chat_template(
|
156 |
+
add_generation_prompt=True,
|
157 |
+
conversation=messages,
|
158 |
+
# documents=documents,
|
159 |
+
tools=tools,
|
160 |
+
)
|
161 |
+
|
162 |
+
print("prompt:")
|
163 |
+
print(prompt)
|
164 |
+
|
165 |
+
inputs = self.processor(text=prompt, images=images, return_tensors="pt")
|
166 |
+
inputs = inputs.to(self.model.device)
|
167 |
+
input_ids = inputs.input_ids
|
168 |
+
|
169 |
+
class GeneratorThread(Thread):
|
170 |
+
"""Thread to generate completions in the background."""
|
171 |
+
|
172 |
+
def __init__(self, model, **generation_kwargs):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.chat_completion = None
|
176 |
+
self.generation_kwargs = generation_kwargs
|
177 |
+
self.model = model
|
178 |
+
|
179 |
+
def run(self):
|
180 |
+
import torch
|
181 |
+
import torch._dynamo.config
|
182 |
+
|
183 |
+
try:
|
184 |
+
try:
|
185 |
+
self.generated_ids = self.model.generate(
|
186 |
+
**self.generation_kwargs
|
187 |
+
)
|
188 |
+
|
189 |
+
except torch._dynamo.exc.BackendCompilerFailed as e:
|
190 |
+
print(e)
|
191 |
+
print("Disabling dynamo...")
|
192 |
+
|
193 |
+
torch._dynamo.config.disable = True
|
194 |
+
|
195 |
+
self.generated_ids = self.model.generate(
|
196 |
+
**self.generation_kwargs
|
197 |
+
)
|
198 |
+
|
199 |
+
except Exception as e:
|
200 |
+
print(e)
|
201 |
+
print("Warning: Exception in GeneratorThread")
|
202 |
+
self.generated_ids = []
|
203 |
+
|
204 |
+
def join(self, timeout=None):
|
205 |
+
super().join()
|
206 |
+
|
207 |
+
return self.generated_ids
|
208 |
+
|
209 |
+
decode_kwargs = dict(skip_special_tokens=True)
|
210 |
+
|
211 |
+
streamer = (
|
212 |
+
AsyncTextIteratorStreamer(
|
213 |
+
self.processor,
|
214 |
+
skip_prompt=True,
|
215 |
+
**decode_kwargs,
|
216 |
+
)
|
217 |
+
if stream
|
218 |
+
else None
|
219 |
+
)
|
220 |
+
|
221 |
+
generation_kwargs = dict(
|
222 |
+
**inputs,
|
223 |
+
max_new_tokens=max_new_tokens,
|
224 |
+
streamer=streamer,
|
225 |
+
temperature=temperature,
|
226 |
+
use_cache=True,
|
227 |
+
)
|
228 |
+
|
229 |
+
thread = GeneratorThread(self.model, **generation_kwargs)
|
230 |
+
thread.start()
|
231 |
+
|
232 |
+
if stream:
|
233 |
+
|
234 |
+
async def event_publisher():
|
235 |
+
i = 0
|
236 |
+
|
237 |
+
try:
|
238 |
+
async for new_text in streamer:
|
239 |
+
print("new_text:")
|
240 |
+
print(new_text)
|
241 |
+
|
242 |
+
choices: List[ChatCompletionChunkChoice] = [
|
243 |
+
ChatCompletionChunkChoice(
|
244 |
+
_request_id=None,
|
245 |
+
delta=ChoiceDelta(
|
246 |
+
_request_id=None,
|
247 |
+
content=new_text,
|
248 |
+
function_call=None,
|
249 |
+
refusal=None,
|
250 |
+
role="assistant",
|
251 |
+
tool_calls=None,
|
252 |
+
),
|
253 |
+
finish_reason=None,
|
254 |
+
index=0,
|
255 |
+
logprobs=None,
|
256 |
+
)
|
257 |
+
]
|
258 |
+
|
259 |
+
chat_completion_chunk = ChatCompletionChunk(
|
260 |
+
_request_id=None,
|
261 |
+
choices=choices,
|
262 |
+
created=int(time.time()),
|
263 |
+
id=str(i),
|
264 |
+
model=model_name,
|
265 |
+
object="chat.completion.chunk",
|
266 |
+
service_tier=None,
|
267 |
+
system_fingerprint=None,
|
268 |
+
usage=None,
|
269 |
+
)
|
270 |
+
|
271 |
+
yield chat_completion_chunk.model_dump_json()
|
272 |
+
|
273 |
+
i += 1
|
274 |
+
|
275 |
+
except asyncio.CancelledError as e:
|
276 |
+
print("Disconnected from client (via refresh/close)")
|
277 |
+
raise e
|
278 |
+
|
279 |
+
except Exception as e:
|
280 |
+
print(f"Exception: {e}")
|
281 |
+
raise e
|
282 |
+
|
283 |
+
return EventSourceResponse(event_publisher())
|
284 |
+
|
285 |
+
generated_ids = thread.join()
|
286 |
+
input_length = input_ids.shape[1]
|
287 |
+
|
288 |
+
batch_decoded_outputs = self.processor.batch_decode(
|
289 |
+
generated_ids[:, input_length:],
|
290 |
+
skip_special_tokens=True,
|
291 |
+
)
|
292 |
+
|
293 |
+
choices: List[ChatCompletionChoice] = []
|
294 |
+
|
295 |
+
for i, response in enumerate(batch_decoded_outputs):
|
296 |
+
print("response:")
|
297 |
+
print(response)
|
298 |
+
|
299 |
+
# try:
|
300 |
+
# response = json.loads(response)
|
301 |
+
|
302 |
+
# finish_reason: str = response.get("finish_reason")
|
303 |
+
# tool_calls_json = response.get("tool_calls")
|
304 |
+
# tool_calls: List[ToolCall] = []
|
305 |
+
|
306 |
+
# for tool_call_json in tool_calls_json:
|
307 |
+
# tool_call = ToolCall(
|
308 |
+
# function=FunctionToolCallArguments(
|
309 |
+
# arguments=tool_call_json.get("arguments"),
|
310 |
+
# name=tool_call_json.get("name"),
|
311 |
+
# ),
|
312 |
+
# id=tool_call_json.get("id"),
|
313 |
+
# type="function",
|
314 |
+
# )
|
315 |
+
|
316 |
+
# tool_calls.append(tool_call)
|
317 |
+
|
318 |
+
# message: ChatMessage = ChatMessage(
|
319 |
+
# role="assistant",
|
320 |
+
# tool_calls=tool_calls,
|
321 |
+
# )
|
322 |
+
|
323 |
+
# except json.JSONDecodeError:
|
324 |
+
# finish_reason: str = "stop"
|
325 |
+
# message: ChatMessage = ChatMessage(
|
326 |
+
# role="assistant",
|
327 |
+
# content=response,
|
328 |
+
# )
|
329 |
+
|
330 |
+
message = ChatCompletionMessage(
|
331 |
+
audio=None,
|
332 |
+
content=response,
|
333 |
+
refusal=None,
|
334 |
+
role="assistant",
|
335 |
+
tool_calls=None,
|
336 |
+
)
|
337 |
+
|
338 |
+
choices.append(
|
339 |
+
ChatCompletionChoice(
|
340 |
+
index=i,
|
341 |
+
finish_reason="stop",
|
342 |
+
logprobs=None,
|
343 |
+
message=message,
|
344 |
+
)
|
345 |
+
)
|
346 |
+
|
347 |
+
chat_completion = ChatCompletion(
|
348 |
+
choices=choices,
|
349 |
+
created=int(time.time()),
|
350 |
+
id="1",
|
351 |
+
model=model_name,
|
352 |
+
object="chat.completion",
|
353 |
+
service_tier=None,
|
354 |
+
system_fingerprint=None,
|
355 |
+
usage=None,
|
356 |
+
)
|
357 |
+
|
358 |
+
return chat_completion.model_dump(mode="json")
|
359 |
+
|
360 |
+
|
361 |
+
def build_app(cli_args: Dict[str, str]) -> serve.Application:
|
362 |
+
"""Builds the Serve app based on CLI arguments."""
|
363 |
+
return ModelDeployment.options().bind(
|
364 |
+
cli_args.get("model_name"),
|
365 |
+
)
|
train.py
CHANGED
@@ -40,8 +40,16 @@ from transformers import (
|
|
40 |
DataCollatorForLanguageModeling,
|
41 |
Trainer,
|
42 |
TrainingArguments,
|
|
|
43 |
)
|
44 |
from trl import SFTTrainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
# Setup logging
|
47 |
def setup_logging():
|
@@ -130,7 +138,9 @@ def load_and_format_dataset(
|
|
130 |
logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
|
131 |
|
132 |
# Split into train and validation sets
|
133 |
-
dataset = dataset.train_test_split(
|
|
|
|
|
134 |
logger.info(
|
135 |
f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
|
136 |
)
|
@@ -188,10 +198,12 @@ def create_trainer(
|
|
188 |
# Create TrainingArguments from config
|
189 |
training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)
|
190 |
# Add dynamic precision settings
|
191 |
-
training_args_dict.update(
|
192 |
-
|
193 |
-
|
194 |
-
|
|
|
|
|
195 |
training_args = TrainingArguments(**training_args_dict)
|
196 |
|
197 |
# Create data collator from config
|
@@ -202,7 +214,7 @@ def create_trainer(
|
|
202 |
|
203 |
# Create SFT config without data_collator to avoid duplication
|
204 |
sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
|
205 |
-
sft_config.pop(
|
206 |
|
207 |
trainer = SFTTrainer(
|
208 |
model=model,
|
@@ -247,15 +259,145 @@ def main(cfg: DictConfig) -> None:
|
|
247 |
# Save model
|
248 |
logger.info(f"Saving final model to {cfg.output.dir}...")
|
249 |
trainer.save_model(cfg.output.dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
# Print final metrics
|
252 |
final_metrics = trainer.state.log_history[-1]
|
253 |
logger.info("\nTraining completed!")
|
254 |
logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
|
255 |
-
logger.info(
|
|
|
|
|
256 |
else:
|
257 |
logger.info("Training skipped as train=False")
|
258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
except Exception as e:
|
260 |
logger.error(f"Error in main training process: {e}")
|
261 |
raise
|
|
|
40 |
DataCollatorForLanguageModeling,
|
41 |
Trainer,
|
42 |
TrainingArguments,
|
43 |
+
AutoModelForCausalLM,
|
44 |
)
|
45 |
from trl import SFTTrainer
|
46 |
+
from peft import PeftModel
|
47 |
+
from smolagents import CodeAgent, Model, TransformersModel, VLLMModel
|
48 |
+
from tools.smart_search.tool import SmartSearchTool
|
49 |
+
from smolagents.monitoring import LogLevel
|
50 |
+
import torch
|
51 |
+
import os
|
52 |
+
|
53 |
|
54 |
# Setup logging
|
55 |
def setup_logging():
|
|
|
138 |
logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
|
139 |
|
140 |
# Split into train and validation sets
|
141 |
+
dataset = dataset.train_test_split(
|
142 |
+
test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed
|
143 |
+
)
|
144 |
logger.info(
|
145 |
f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
|
146 |
)
|
|
|
198 |
# Create TrainingArguments from config
|
199 |
training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)
|
200 |
# Add dynamic precision settings
|
201 |
+
training_args_dict.update(
|
202 |
+
{
|
203 |
+
"fp16": not is_bfloat16_supported(),
|
204 |
+
"bf16": is_bfloat16_supported(),
|
205 |
+
}
|
206 |
+
)
|
207 |
training_args = TrainingArguments(**training_args_dict)
|
208 |
|
209 |
# Create data collator from config
|
|
|
214 |
|
215 |
# Create SFT config without data_collator to avoid duplication
|
216 |
sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
|
217 |
+
sft_config.pop("data_collator", None) # Remove data_collator from config
|
218 |
|
219 |
trainer = SFTTrainer(
|
220 |
model=model,
|
|
|
259 |
# Save model
|
260 |
logger.info(f"Saving final model to {cfg.output.dir}...")
|
261 |
trainer.save_model(cfg.output.dir)
|
262 |
+
|
263 |
+
# Save model in VLLM format
|
264 |
+
logger.info("Saving model in VLLM format...")
|
265 |
+
model.save_pretrained_merged(
|
266 |
+
cfg.output.dir,
|
267 |
+
tokenizer,
|
268 |
+
save_method="merged_16bit"
|
269 |
+
)
|
270 |
|
271 |
# Print final metrics
|
272 |
final_metrics = trainer.state.log_history[-1]
|
273 |
logger.info("\nTraining completed!")
|
274 |
logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
|
275 |
+
logger.info(
|
276 |
+
f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}"
|
277 |
+
)
|
278 |
else:
|
279 |
logger.info("Training skipped as train=False")
|
280 |
|
281 |
+
# Test if requested
|
282 |
+
if cfg.test:
|
283 |
+
logger.info("\nStarting testing...")
|
284 |
+
try:
|
285 |
+
# Enable memory history tracking
|
286 |
+
torch.cuda.memory._record_memory_history()
|
287 |
+
|
288 |
+
# Set memory allocation configuration
|
289 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:128'
|
290 |
+
|
291 |
+
# Load test dataset
|
292 |
+
test_dataset = load_dataset(
|
293 |
+
cfg.test_dataset.name,
|
294 |
+
cfg.test_dataset.config,
|
295 |
+
split=cfg.test_dataset.split,
|
296 |
+
trust_remote_code=True,
|
297 |
+
)
|
298 |
+
logger.info(f"Loaded test dataset with {len(test_dataset)} examples")
|
299 |
+
logger.info(f"Dataset features: {test_dataset.features}")
|
300 |
+
|
301 |
+
# Clear CUDA cache before loading model
|
302 |
+
torch.cuda.empty_cache()
|
303 |
+
|
304 |
+
# Initialize model
|
305 |
+
model: Model = Model(
|
306 |
+
model_id=cfg.model.name,
|
307 |
+
# model_id=cfg.output.dir,
|
308 |
+
)
|
309 |
+
|
310 |
+
# model: Model = TransformersModel(
|
311 |
+
# model_id=cfg.model.name,
|
312 |
+
# # model_id=cfg.output.dir,
|
313 |
+
# )
|
314 |
+
|
315 |
+
# model: Model = VLLMModel(
|
316 |
+
# model_id=cfg.model.name,
|
317 |
+
# # model_id=cfg.output.dir,
|
318 |
+
# )
|
319 |
+
|
320 |
+
# Create CodeAgent with SmartSearchTool
|
321 |
+
agent = CodeAgent(
|
322 |
+
model=model,
|
323 |
+
tools=[SmartSearchTool()],
|
324 |
+
verbosity_level=LogLevel.ERROR,
|
325 |
+
)
|
326 |
+
|
327 |
+
# Format task to get succinct answer
|
328 |
+
def format_task(question):
|
329 |
+
return f"""Please provide two answers to the following question:
|
330 |
+
|
331 |
+
1. A succinct answer that follows these rules:
|
332 |
+
- Contains ONLY the answer, nothing else
|
333 |
+
- Does not repeat the question
|
334 |
+
- Does not include explanations, reasoning, or context
|
335 |
+
- Does not include source attribution or references
|
336 |
+
- Does not use phrases like "The answer is" or "I found that"
|
337 |
+
- Does not include formatting, bullet points, or line breaks
|
338 |
+
- If the answer is a number, return only the number
|
339 |
+
- If the answer requires multiple items, separate them with commas
|
340 |
+
- If the answer requires ordering, maintain the specified order
|
341 |
+
- Uses the most direct and succinct form possible
|
342 |
+
|
343 |
+
2. A verbose answer that includes:
|
344 |
+
- The complete answer with all relevant details
|
345 |
+
- Explanations and reasoning
|
346 |
+
- Context and background information
|
347 |
+
- Source attribution where appropriate
|
348 |
+
|
349 |
+
Question: {question}
|
350 |
+
|
351 |
+
Please format your response as a JSON object with two keys:
|
352 |
+
- "succinct_answer": The concise answer following the rules above
|
353 |
+
- "verbose_answer": The detailed explanation with context"""
|
354 |
+
|
355 |
+
# Run inference on test samples
|
356 |
+
logger.info("Running inference on test samples...")
|
357 |
+
for i, example in enumerate(test_dataset):
|
358 |
+
try:
|
359 |
+
# Clear CUDA cache before each sample
|
360 |
+
torch.cuda.empty_cache()
|
361 |
+
|
362 |
+
# Format the task
|
363 |
+
task = format_task(example['Question'])
|
364 |
+
|
365 |
+
# Run the agent
|
366 |
+
result = agent.run(
|
367 |
+
task=task,
|
368 |
+
max_steps=3,
|
369 |
+
reset=True,
|
370 |
+
stream=False,
|
371 |
+
)
|
372 |
+
|
373 |
+
# Parse the result
|
374 |
+
import json
|
375 |
+
json_str = result[result.find("{"):result.rfind("}")+1]
|
376 |
+
parsed_result = json.loads(json_str)
|
377 |
+
answer = parsed_result["succinct_answer"]
|
378 |
+
|
379 |
+
logger.info(f"\nTest Sample {i+1}:")
|
380 |
+
logger.info(f"Question: {example['Question']}")
|
381 |
+
logger.info(f"Model Response: {answer}")
|
382 |
+
logger.info("-" * 80)
|
383 |
+
|
384 |
+
# Log memory usage after each sample
|
385 |
+
logger.info(f"Memory usage after sample {i+1}:")
|
386 |
+
logger.info(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
|
387 |
+
logger.info(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
|
388 |
+
|
389 |
+
except Exception as e:
|
390 |
+
logger.error(f"Error processing test sample {i+1}: {str(e)}")
|
391 |
+
continue
|
392 |
+
|
393 |
+
# Dump memory snapshot for analysis
|
394 |
+
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
|
395 |
+
logger.info("Memory snapshot saved to memory_snapshot.pickle")
|
396 |
+
|
397 |
+
except Exception as e:
|
398 |
+
logger.error(f"Error during testing: {e}")
|
399 |
+
raise
|
400 |
+
|
401 |
except Exception as e:
|
402 |
logger.error(f"Error in main training process: {e}")
|
403 |
raise
|