Update requirements.txt to include sse-starlette dependency, enhance serve.py with additional imports for FastLanguageModel and FastVisionModel, and refactor train.py for improved organization and memory tracking during model training.
- requirements.txt +1 -0
- serve.py +10 -1
- train.py +29 -22
requirements.txt
CHANGED
@@ -28,6 +28,7 @@ python-dotenv>=1.0.0
 requests>=2.32.3
 sentence-transformers>=4.1.0
 smolagents[litellm,telemetry,vllm]>=1.14.0
+sse-starlette>=2.3.4
 tensorboardX>=2.6.2.2
 trl>=0.17.0
 typing-extensions>=4.5.0
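The pinned sse-starlette package provides the EventSourceResponse that serve.py imports below. As a minimal sketch (the /stream route and the token list are illustrative, not the repo's actual endpoint), this is the usual way it streams server-sent events from a FastAPI app:

import asyncio

from fastapi import FastAPI
from sse_starlette import EventSourceResponse

app = FastAPI()


@app.get("/stream")  # illustrative route, not serve.py's real endpoint
async def stream():
    async def event_source():
        # Each yielded dict becomes one SSE "data:" frame on the wire.
        for token in ["streamed", "one", "token", "at", "a", "time"]:
            yield {"data": token}
            await asyncio.sleep(0.05)

    return EventSourceResponse(event_source())

This matches how chat-completion servers stream deltas: the generator yields chunks as the model produces them, and the client reads them as SSE messages.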
serve.py
CHANGED
@@ -6,6 +6,16 @@ from pprint import pprint
 from threading import Thread
 from typing import Any, Dict, List
 
+# isort: off
+from unsloth import (
+    FastLanguageModel,
+    FastVisionModel,
+    is_bfloat16_supported,
+)  # noqa: E402
+from unsloth.chat_templates import get_chat_template  # noqa: E402
+
+# isort: on
+
 from fastapi import FastAPI, Request
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.chat.chat_completion import Choice as ChatCompletionChoice
@@ -20,7 +30,6 @@ from sse_starlette import EventSourceResponse
 from starlette.responses import JSONResponse
 from transformers.generation.streamers import AsyncTextIteratorStreamer
 from transformers.image_utils import load_image
-from unsloth import FastVisionModel
 
 dtype = (
     None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
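The # isort: off / # isort: on fence exists because Unsloth patches transformers at import time and warns when it is imported afterwards; the fence keeps isort from reordering it below the transformers imports. A minimal sketch of the loading pattern these imports enable, assuming a hypothetical checkpoint name (FastVisionModel follows the same pattern for vision models):

# isort: off
# Unsloth must be imported before transformers so its runtime patches apply.
from unsloth import FastLanguageModel

# isort: on

# Hypothetical checkpoint; from_pretrained returns the patched model
# together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-Instruct",
    max_seq_length=2048,
    dtype=None,  # auto-detect: float16 on T4/V100, bfloat16 on Ampere+
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # enable Unsloth's fast inference path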
train.py
CHANGED
@@ -28,6 +28,9 @@ from unsloth.chat_templates import get_chat_template  # noqa: E402
 
 # isort: on
 
+import os
+
+import torch
 from datasets import (
     Dataset,
     DatasetDict,
@@ -35,20 +38,19 @@ from datasets import (
     IterableDatasetDict,
     load_dataset,
 )
+from peft import PeftModel
+from smolagents import CodeAgent, Model, TransformersModel, VLLMModel
+from smolagents.monitoring import LogLevel
 from transformers import (
+    AutoModelForCausalLM,
     AutoTokenizer,
     DataCollatorForLanguageModeling,
     Trainer,
     TrainingArguments,
-    AutoModelForCausalLM,
 )
 from trl import SFTTrainer
 
-from smolagents import CodeAgent, Model, TransformersModel, VLLMModel
 from tools.smart_search.tool import SmartSearchTool
-from smolagents.monitoring import LogLevel
-import torch
-import os
 
 
 # Setup logging
@@ -259,13 +261,11 @@ def main(cfg: DictConfig) -> None:
     # Save model
     logger.info(f"Saving final model to {cfg.output.dir}...")
     trainer.save_model(cfg.output.dir)
 
     # Save model in VLLM format
     logger.info("Saving model in VLLM format...")
     model.save_pretrained_merged(
-        cfg.output.dir,
-        tokenizer,
-        save_method="merged_16bit"
+        cfg.output.dir, tokenizer, save_method="merged_16bit"
     )
 
     # Print final metrics
@@ -284,10 +284,12 @@ def main(cfg: DictConfig) -> None:
     try:
         # Enable memory history tracking
         torch.cuda.memory._record_memory_history()
 
         # Set memory allocation configuration
-        os.environ[
-            …
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
+            "expandable_segments:True,max_split_size_mb:128"
+        )
+
         # Load test dataset
         test_dataset = load_dataset(
             cfg.test_dataset.name,
@@ -358,10 +360,10 @@ Please format your response as a JSON object with two keys:
         try:
            # Clear CUDA cache before each sample
            torch.cuda.empty_cache()
 
            # Format the task
-            task = format_task(example["Question"])
+            task = format_task(example["Question"])
 
            # Run the agent
            result = agent.run(
                task=task,
@@ -372,20 +374,25 @@ Please format your response as a JSON object with two keys:
 
            # Parse the result
            import json
 
+            json_str = result[result.find("{") : result.rfind("}") + 1]
            parsed_result = json.loads(json_str)
            answer = parsed_result["succinct_answer"]
 
            logger.info(f"\nTest Sample {i+1}:")
            logger.info(f"Question: {example['Question']}")
            logger.info(f"Model Response: {answer}")
            logger.info("-" * 80)
 
            # Log memory usage after each sample
            logger.info(f"Memory usage after sample {i+1}:")
-            logger.info(
-                …
-            )
+            logger.info(
+                f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
+            )
+            logger.info(
+                f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB"
+            )
+
        except Exception as e:
            logger.error(f"Error processing test sample {i+1}: {str(e)}")
            continue
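The reformatted save_pretrained_merged call is Unsloth's export path: it folds the trained LoRA adapter back into the base weights and writes a plain 16-bit checkpoint, which is what lets vLLM load the result directly (hence the "VLLM format" comment). A sketch with an illustrative output directory:

# Merge LoRA adapters into the base weights and write 16-bit safetensors.
# save_method="merged_16bit" produces a checkpoint vLLM can load directly;
# "lora" would instead save only the adapter weights.
model.save_pretrained_merged(
    "outputs/merged",  # illustrative path; train.py uses cfg.output.dir
    tokenizer,
    save_method="merged_16bit",
)

# The merged directory then loads like any ordinary HF checkpoint:
# from vllm import LLM
# llm = LLM(model="outputs/merged")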
|