mjschock committed on
Commit 518aafe · unverified · 1 Parent(s): b21080c

Add serve.py for model deployment and API integration, update requirements.txt for smolagents with vllm support, and enhance .gitignore to exclude memory snapshot files. Additionally, implement testing configuration in config.yaml and modify train.py for memory tracking and model saving in VLLM format.

Files changed (5)
  1. .gitignore +1 -0
  2. conf/config.yaml +9 -0
  3. requirements.txt +1 -1
  4. serve.py +365 -0
  5. train.py +149 -7
.gitignore CHANGED
@@ -1,6 +1,7 @@
 .env
 logs
 lora_model
+memory_snapshot.pickle
 outputs
 __pycache__
 .pytest_cache
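The new ignore entry covers the CUDA memory snapshots that train.py now writes via torch.cuda.memory._dump_snapshot. A minimal sketch of inspecting such a snapshot offline; the snapshot is a pickled dict, but the exact key names ("segments", "total_size") are assumptions about the snapshot format, and the file can also be loaded into the interactive viewer at https://pytorch.org/memory_viz:

import pickle

# Load a snapshot written by torch.cuda.memory._dump_snapshot("memory_snapshot.pickle").
with open("memory_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Summarize allocator segments (key names assumed, not guaranteed by the API).
segments = snapshot.get("segments", [])
total_bytes = sum(seg.get("total_size", 0) for seg in segments)
print(f"{len(segments)} segments, ~{total_bytes / 1024**2:.1f} MB reserved")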
conf/config.yaml CHANGED
@@ -70,3 +70,12 @@ output:
 
 # Training control
 train: false
+
+# Testing configuration
+test: true # Whether to run testing after training
+test_dataset:
+  name: "gaia-benchmark/GAIA"
+  config: "2023_level1" # Use level 1 questions for testing
+  split: "test" # Use test split for testing
+  max_samples: 10 # Number of samples to test on
+  max_length: 2048 # Maximum sequence length for testing
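A minimal sketch of how this new testing block might be read in train.py through Hydra/OmegaConf; the decorator arguments (config_path="conf", config_name="config") mirror the repo layout but are assumptions:

import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Keys added by this commit.
    if cfg.test:
        print(cfg.test_dataset.name)         # "gaia-benchmark/GAIA"
        print(cfg.test_dataset.config)       # "2023_level1"
        print(cfg.test_dataset.split)        # "test"
        print(cfg.test_dataset.max_samples)  # 10


if __name__ == "__main__":
    main()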
requirements.txt CHANGED
@@ -27,7 +27,7 @@ pytest-cov>=6.1.1
 python-dotenv>=1.0.0
 requests>=2.32.3
 sentence-transformers>=4.1.0
-smolagents[litellm,telemetry]>=1.14.0
+smolagents[litellm,telemetry,vllm]>=1.14.0
 tensorboardX>=2.6.2.2
 trl>=0.17.0
 typing-extensions>=4.5.0
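The added vllm extra is what lets train.py import smolagents' VLLMModel. A minimal usage sketch; the model id is a placeholder and in this repo would be cfg.model.name or the merged checkpoint in cfg.output.dir:

from smolagents import CodeAgent, VLLMModel

# Placeholder model id (assumption), standing in for cfg.model.name.
model = VLLMModel(model_id="Qwen/Qwen2.5-0.5B-Instruct")
agent = CodeAgent(model=model, tools=[])
print(agent.run("What is 2 + 2?"))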
serve.py ADDED
@@ -0,0 +1,365 @@
+import asyncio
+import logging
+import os
+import time
+from pprint import pprint
+from threading import Thread
+from typing import Any, Dict, List
+
+from fastapi import FastAPI, Request
+from openai.types.chat.chat_completion import ChatCompletion
+from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
+from openai.types.chat.chat_completion_chunk import Choice as ChatCompletionChunkChoice
+from openai.types.chat.chat_completion_chunk import ChoiceDelta
+from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.completion_create_params import CompletionCreateParams
+from pydantic import TypeAdapter
+from ray import serve
+from sse_starlette import EventSourceResponse
+from starlette.responses import JSONResponse
+from transformers.generation.streamers import AsyncTextIteratorStreamer
+from transformers.image_utils import load_image
+from unsloth import FastVisionModel
+
+dtype = (
+    None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
+)
+load_in_4bit = True  # Use 4bit quantization to reduce memory usage. Can be False.
+max_seq_length = 2048  # Supports RoPE Scaling internally, so choose any!
+# max_seq_length = 4096  # Choose any! We auto support RoPE Scaling internally!
+
+
+logger = logging.getLogger("ray.serve")
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+app = FastAPI()
+
+# middlewares = [
+#     middleware
+#     for middleware in ConnexionMiddleware.default_middlewares
+#     if middleware is not SecurityMiddleware
+# ]
+
+# connexion_app = AsyncApp(import_name=__name__, middlewares=middlewares)
+
+# connexion_app.add_api(
+#     # "api/openai/v1/openapi/openapi.yaml",
+#     "api/v1/openapi/openapi.yaml",
+#     # base_path="/openai/v1",
+#     base_path="/v1",
+#     pythonic_params=True,
+#     resolver_error=501,
+# )
+
+# # fastapi_app.mount("/api", ConnexionMiddleware(app=connexion_app, import_name=__name__))
+# # app.mount("/api", ConnexionMiddleware(app=connexion_app, import_name=__name__))
+# app.mount(
+#     "/",
+#     ConnexionMiddleware(
+#         app=connexion_app,
+#         import_name=__name__,
+#         # middlewares=middlewares,
+#     ),
+# )
+
+
+@serve.deployment(
+    autoscaling_config={
+        # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#required-define-upper-and-lower-autoscaling-limits
+        "max_replicas": 1,
+        "min_replicas": 1,  # TODO: set to 0
+        "target_ongoing_requests": 2,  # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#target-ongoing-requests-default-2
+    },
+    max_ongoing_requests=5,  # https://docs.ray.io/en/latest/serve/advanced-guides/advanced-autoscaling.html#max-ongoing-requests-default-5
+    ray_actor_options={"num_gpus": 1},
+)
+@serve.ingress(app)
+class ModelDeployment:
+    def __init__(
+        self,
+        model_name: str,
+    ):
+        self.model_name = model_name
+
+        model, processor = FastVisionModel.from_pretrained(
+            load_in_4bit=load_in_4bit,
+            max_seq_length=max_seq_length,
+            model_name=self.model_name,
+        )
+
+        with open("chat_template.txt", "r") as f:
+            processor.chat_template = f.read()
+        processor.tokenizer.chat_template = processor.chat_template
+
+        FastVisionModel.for_inference(model)  # Enable native 2x faster inference
+
+        self.model = model
+        self.processor = processor
+
+    def reconfigure(self, config: Dict[str, Any]):
+        print("=== reconfigure ===")
+        print("config:")
+        print(config)
+        # https://docs.ray.io/en/latest/serve/production-guide/config.html#dynamically-change-parameters-without-restarting-replicas-user-config
+
+    @app.post("/v1/chat/completions")
+    async def create_chat_completion(self, body: dict, raw_request: Request):
+        """Creates a model response for the given chat conversation. Learn more in the [text generation](/docs/guides/text-generation), [vision](/docs/guides/vision), and [audio](/docs/guides/audio) guides. Parameter support can differ depending on the model used to generate the response, particularly for newer reasoning models. Parameters that are only supported for reasoning models are noted below. For the current state of unsupported parameters in reasoning models, [refer to the reasoning guide](/docs/guides/reasoning).
+
+        # noqa: E501
+
+        :param create_chat_completion_request:
+        :type create_chat_completion_request: dict | bytes
+
+        :rtype: Union[CreateChatCompletionResponse, Tuple[CreateChatCompletionResponse, int], Tuple[CreateChatCompletionResponse, int, Dict[str, str]]
+        """
+        print("=== create_chat_completion ===")
+
+        print("body:")
+        pprint(body)
+
+        ta = TypeAdapter(CompletionCreateParams)
+
+        print("ta.validate_python...")
+        pprint(ta.validate_python(body))
+
+        max_new_tokens = body.get("max_completion_tokens", body.get("max_tokens"))
+        messages = body.get("messages")
+        model_name = body.get("model")
+        stream = body.get("stream", False)
+        temperature = body.get("temperature")
+        tools = body.get("tools")
+
+        images = []
+
+        for message in messages:
+            for content in message["content"]:
+                if "type" in content and content["type"] == "image_url":
+                    image_url = content["image_url"]["url"]
+                    image = load_image(image_url)
+                    images.append(image)
+
+                    content["type"] = "image"
+                    del content["image_url"]
+
+        images = images if images else None
+
+        if model_name != self.model_name:
+            # adapter_path = model_name
+            # self.model.load_adapter(adapter_path)
+
+            return JSONResponse(content={"error": "Model not found"}, status_code=404)
+
+        prompt = self.processor.apply_chat_template(
+            add_generation_prompt=True,
+            conversation=messages,
+            # documents=documents,
+            tools=tools,
+        )
+
+        print("prompt:")
+        print(prompt)
+
+        inputs = self.processor(text=prompt, images=images, return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+        input_ids = inputs.input_ids
+
+        class GeneratorThread(Thread):
+            """Thread to generate completions in the background."""
+
+            def __init__(self, model, **generation_kwargs):
+                super().__init__()
+
+                self.chat_completion = None
+                self.generation_kwargs = generation_kwargs
+                self.model = model
+
+            def run(self):
+                import torch
+                import torch._dynamo.config
+
+                try:
+                    try:
+                        self.generated_ids = self.model.generate(
+                            **self.generation_kwargs
+                        )
+
+                    except torch._dynamo.exc.BackendCompilerFailed as e:
+                        print(e)
+                        print("Disabling dynamo...")
+
+                        torch._dynamo.config.disable = True
+
+                        self.generated_ids = self.model.generate(
+                            **self.generation_kwargs
+                        )
+
+                except Exception as e:
+                    print(e)
+                    print("Warning: Exception in GeneratorThread")
+                    self.generated_ids = []
+
+            def join(self, timeout=None):
+                super().join()
+
+                return self.generated_ids
+
+        decode_kwargs = dict(skip_special_tokens=True)
+
+        streamer = (
+            AsyncTextIteratorStreamer(
+                self.processor,
+                skip_prompt=True,
+                **decode_kwargs,
+            )
+            if stream
+            else None
+        )
+
+        generation_kwargs = dict(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            streamer=streamer,
+            temperature=temperature,
+            use_cache=True,
+        )
+
+        thread = GeneratorThread(self.model, **generation_kwargs)
+        thread.start()
+
+        if stream:
+
+            async def event_publisher():
+                i = 0
+
+                try:
+                    async for new_text in streamer:
+                        print("new_text:")
+                        print(new_text)
+
+                        choices: List[ChatCompletionChunkChoice] = [
+                            ChatCompletionChunkChoice(
+                                _request_id=None,
+                                delta=ChoiceDelta(
+                                    _request_id=None,
+                                    content=new_text,
+                                    function_call=None,
+                                    refusal=None,
+                                    role="assistant",
+                                    tool_calls=None,
+                                ),
+                                finish_reason=None,
+                                index=0,
+                                logprobs=None,
+                            )
+                        ]
+
+                        chat_completion_chunk = ChatCompletionChunk(
+                            _request_id=None,
+                            choices=choices,
+                            created=int(time.time()),
+                            id=str(i),
+                            model=model_name,
+                            object="chat.completion.chunk",
+                            service_tier=None,
+                            system_fingerprint=None,
+                            usage=None,
+                        )
+
+                        yield chat_completion_chunk.model_dump_json()
+
+                        i += 1
+
+                except asyncio.CancelledError as e:
+                    print("Disconnected from client (via refresh/close)")
+                    raise e
+
+                except Exception as e:
+                    print(f"Exception: {e}")
+                    raise e
+
+            return EventSourceResponse(event_publisher())
+
+        generated_ids = thread.join()
+        input_length = input_ids.shape[1]
+
+        batch_decoded_outputs = self.processor.batch_decode(
+            generated_ids[:, input_length:],
+            skip_special_tokens=True,
+        )
+
+        choices: List[ChatCompletionChoice] = []
+
+        for i, response in enumerate(batch_decoded_outputs):
+            print("response:")
+            print(response)
+
+            # try:
+            #     response = json.loads(response)
+
+            #     finish_reason: str = response.get("finish_reason")
+            #     tool_calls_json = response.get("tool_calls")
+            #     tool_calls: List[ToolCall] = []
+
+            #     for tool_call_json in tool_calls_json:
+            #         tool_call = ToolCall(
+            #             function=FunctionToolCallArguments(
+            #                 arguments=tool_call_json.get("arguments"),
+            #                 name=tool_call_json.get("name"),
+            #             ),
+            #             id=tool_call_json.get("id"),
+            #             type="function",
+            #         )
+
+            #         tool_calls.append(tool_call)
+
+            #     message: ChatMessage = ChatMessage(
+            #         role="assistant",
+            #         tool_calls=tool_calls,
+            #     )
+
+            # except json.JSONDecodeError:
+            #     finish_reason: str = "stop"
+            #     message: ChatMessage = ChatMessage(
+            #         role="assistant",
+            #         content=response,
+            #     )
+
+            message = ChatCompletionMessage(
+                audio=None,
+                content=response,
+                refusal=None,
+                role="assistant",
+                tool_calls=None,
+            )
+
+            choices.append(
+                ChatCompletionChoice(
+                    index=i,
+                    finish_reason="stop",
+                    logprobs=None,
+                    message=message,
+                )
+            )
+
+        chat_completion = ChatCompletion(
+            choices=choices,
+            created=int(time.time()),
+            id="1",
+            model=model_name,
+            object="chat.completion",
+            service_tier=None,
+            system_fingerprint=None,
+            usage=None,
+        )
+
+        return chat_completion.model_dump(mode="json")
+
+
+def build_app(cli_args: Dict[str, str]) -> serve.Application:
+    """Builds the Serve app based on CLI arguments."""
+    return ModelDeployment.options().bind(
+        cli_args.get("model_name"),
+    )
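A minimal sketch of calling the deployment's /v1/chat/completions route with the OpenAI Python client once the Serve app built by build_app is running; the base URL assumes Ray Serve's default local HTTP address on port 8000, and the model must match the model_name bound in build_app or the handler returns 404:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="<model_name bound in build_app>",  # placeholder; must equal the deployment's model_name
    messages=[{"role": "user", "content": [{"type": "text", "text": "Hello, who are you?"}]}],
    max_tokens=64,
)
print(response.choices[0].message.content)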
train.py CHANGED
@@ -40,8 +40,16 @@ from transformers import (
     DataCollatorForLanguageModeling,
     Trainer,
     TrainingArguments,
+    AutoModelForCausalLM,
 )
 from trl import SFTTrainer
+from peft import PeftModel
+from smolagents import CodeAgent, Model, TransformersModel, VLLMModel
+from tools.smart_search.tool import SmartSearchTool
+from smolagents.monitoring import LogLevel
+import torch
+import os
+
 
 # Setup logging
 def setup_logging():
@@ -130,7 +138,9 @@ def load_and_format_dataset(
     logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
 
     # Split into train and validation sets
-    dataset = dataset.train_test_split(test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed)
+    dataset = dataset.train_test_split(
+        test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed
+    )
     logger.info(
         f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
     )
@@ -188,10 +198,12 @@ def create_trainer(
     # Create TrainingArguments from config
     training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)
     # Add dynamic precision settings
-    training_args_dict.update({
-        "fp16": not is_bfloat16_supported(),
-        "bf16": is_bfloat16_supported(),
-    })
+    training_args_dict.update(
+        {
+            "fp16": not is_bfloat16_supported(),
+            "bf16": is_bfloat16_supported(),
+        }
+    )
     training_args = TrainingArguments(**training_args_dict)
 
     # Create data collator from config
@@ -202,7 +214,7 @@ def create_trainer(
 
     # Create SFT config without data_collator to avoid duplication
     sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
-    sft_config.pop('data_collator', None)  # Remove data_collator from config
+    sft_config.pop("data_collator", None)  # Remove data_collator from config
 
     trainer = SFTTrainer(
         model=model,
@@ -247,15 +259,145 @@ def main(cfg: DictConfig) -> None:
             # Save model
             logger.info(f"Saving final model to {cfg.output.dir}...")
             trainer.save_model(cfg.output.dir)
+
+            # Save model in VLLM format
+            logger.info("Saving model in VLLM format...")
+            model.save_pretrained_merged(
+                cfg.output.dir,
+                tokenizer,
+                save_method="merged_16bit"
+            )
 
             # Print final metrics
             final_metrics = trainer.state.log_history[-1]
             logger.info("\nTraining completed!")
             logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
-            logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
+            logger.info(
+                f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}"
+            )
         else:
             logger.info("Training skipped as train=False")
 
+        # Test if requested
+        if cfg.test:
+            logger.info("\nStarting testing...")
+            try:
+                # Enable memory history tracking
+                torch.cuda.memory._record_memory_history()
+
+                # Set memory allocation configuration
+                os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True,max_split_size_mb:128'
+
+                # Load test dataset
+                test_dataset = load_dataset(
+                    cfg.test_dataset.name,
+                    cfg.test_dataset.config,
+                    split=cfg.test_dataset.split,
+                    trust_remote_code=True,
+                )
+                logger.info(f"Loaded test dataset with {len(test_dataset)} examples")
+                logger.info(f"Dataset features: {test_dataset.features}")
+
+                # Clear CUDA cache before loading model
+                torch.cuda.empty_cache()
+
+                # Initialize model
+                model: Model = Model(
+                    model_id=cfg.model.name,
+                    # model_id=cfg.output.dir,
+                )
+
+                # model: Model = TransformersModel(
+                #     model_id=cfg.model.name,
+                #     # model_id=cfg.output.dir,
+                # )
+
+                # model: Model = VLLMModel(
+                #     model_id=cfg.model.name,
+                #     # model_id=cfg.output.dir,
+                # )
+
+                # Create CodeAgent with SmartSearchTool
+                agent = CodeAgent(
+                    model=model,
+                    tools=[SmartSearchTool()],
+                    verbosity_level=LogLevel.ERROR,
+                )
+
+                # Format task to get succinct answer
+                def format_task(question):
+                    return f"""Please provide two answers to the following question:
+
+1. A succinct answer that follows these rules:
+- Contains ONLY the answer, nothing else
+- Does not repeat the question
+- Does not include explanations, reasoning, or context
+- Does not include source attribution or references
+- Does not use phrases like "The answer is" or "I found that"
+- Does not include formatting, bullet points, or line breaks
+- If the answer is a number, return only the number
+- If the answer requires multiple items, separate them with commas
+- If the answer requires ordering, maintain the specified order
+- Uses the most direct and succinct form possible
+
+2. A verbose answer that includes:
+- The complete answer with all relevant details
+- Explanations and reasoning
+- Context and background information
+- Source attribution where appropriate
+
+Question: {question}
+
+Please format your response as a JSON object with two keys:
+- "succinct_answer": The concise answer following the rules above
+- "verbose_answer": The detailed explanation with context"""
+
+                # Run inference on test samples
+                logger.info("Running inference on test samples...")
+                for i, example in enumerate(test_dataset):
+                    try:
+                        # Clear CUDA cache before each sample
+                        torch.cuda.empty_cache()
+
+                        # Format the task
+                        task = format_task(example['Question'])
+
+                        # Run the agent
+                        result = agent.run(
+                            task=task,
+                            max_steps=3,
+                            reset=True,
+                            stream=False,
+                        )
+
+                        # Parse the result
+                        import json
+                        json_str = result[result.find("{"):result.rfind("}")+1]
+                        parsed_result = json.loads(json_str)
+                        answer = parsed_result["succinct_answer"]
+
+                        logger.info(f"\nTest Sample {i+1}:")
+                        logger.info(f"Question: {example['Question']}")
+                        logger.info(f"Model Response: {answer}")
+                        logger.info("-" * 80)
+
+                        # Log memory usage after each sample
+                        logger.info(f"Memory usage after sample {i+1}:")
+                        logger.info(f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
+                        logger.info(f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
+
+                    except Exception as e:
+                        logger.error(f"Error processing test sample {i+1}: {str(e)}")
+                        continue
+
+                # Dump memory snapshot for analysis
+                torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
+                logger.info("Memory snapshot saved to memory_snapshot.pickle")
+
+            except Exception as e:
+                logger.error(f"Error during testing: {e}")
+                raise
+
     except Exception as e:
         logger.error(f"Error in main training process: {e}")
         raise
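Since the trainer now exports a merged 16-bit checkpoint via save_pretrained_merged, that directory can be loaded directly by vLLM. A minimal sketch, assuming cfg.output.dir resolves to "outputs"; the path and prompt are placeholders, not values from this repo:

from vllm import LLM, SamplingParams

# Hypothetical path: the merged checkpoint written by model.save_pretrained_merged(cfg.output.dir, ...).
llm = LLM(model="outputs")
outputs = llm.generate(
    ["Question: What is 2 + 2?\nAnswer:"],
    SamplingParams(max_tokens=16),
)
print(outputs[0].outputs[0].text)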