#!/usr/bin/env python3
"""
Fine-tuning script for SmolLM2-135M model using Unsloth.
This script demonstrates how to:
1. Install and configure Unsloth
2. Prepare and format training data
3. Configure and run the training process
4. Save and evaluate the model
To run this script:
1. Install dependencies: pip install -r requirements.txt
2. Run: python train.py
"""
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Union
import hydra
from omegaconf import DictConfig, OmegaConf
# isort: off
from unsloth import FastLanguageModel, is_bfloat16_supported # noqa: E402
from unsloth.chat_templates import get_chat_template # noqa: E402
# isort: on
import torch
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from peft import PeftModel
from smolagents import CodeAgent, Model, TransformersModel, VLLMModel
from smolagents.monitoring import LogLevel
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from trl import SFTTrainer
from tools.smart_search.tool import SmartSearchTool
# Setup logging
def setup_logging():
"""Configure logging for the training process."""
# Create logs directory if it doesn't exist
log_dir = Path("logs")
log_dir.mkdir(exist_ok=True)
# Create a unique log file name with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = log_dir / f"training_{timestamp}.log"
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
logger.info(f"Logging initialized. Log file: {log_file}")
return logger
logger = setup_logging()
def install_dependencies():
"""Install required dependencies."""
logger.info("Installing dependencies...")
try:
os.system(
'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"'
)
os.system("pip install --no-deps xformers trl peft accelerate bitsandbytes")
logger.info("Dependencies installed successfully")
except Exception as e:
logger.error(f"Error installing dependencies: {e}")
raise
def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
"""Load and configure the model."""
logger.info("Loading model and tokenizer...")
try:
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=cfg.model.name,
max_seq_length=cfg.model.max_seq_length,
dtype=cfg.model.dtype,
load_in_4bit=cfg.model.load_in_4bit,
)
logger.info("Base model loaded successfully")
# Configure LoRA
model = FastLanguageModel.get_peft_model(
model,
r=cfg.peft.r,
target_modules=cfg.peft.target_modules,
lora_alpha=cfg.peft.lora_alpha,
lora_dropout=cfg.peft.lora_dropout,
bias=cfg.peft.bias,
use_gradient_checkpointing=cfg.peft.use_gradient_checkpointing,
random_state=cfg.peft.random_state,
use_rslora=cfg.peft.use_rslora,
loftq_config=cfg.peft.loftq_config,
)
logger.info("LoRA configuration applied successfully")
return model, tokenizer
except Exception as e:
logger.error(f"Error loading model: {e}")
raise
def load_and_format_dataset(
tokenizer: AutoTokenizer,
cfg: DictConfig,
) -> tuple[
Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
]:
"""Load and format the training dataset."""
logger.info("Loading and formatting dataset...")
try:
# Load the code-act dataset
dataset = load_dataset("xingyaoww/code-act", split="codeact")
logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")
# Split into train and validation sets
dataset = dataset.train_test_split(
test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed
)
logger.info(
f"Dataset split into train ({len(dataset['train'])} examples) and validation ({len(dataset['test'])} examples) sets"
)
# Configure chat template
tokenizer = get_chat_template(
tokenizer,
chat_template="chatml", # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
mapping={
"role": "from",
"content": "value",
"user": "human",
"assistant": "gpt",
}, # ShareGPT style
map_eos_token=True, # Maps <|im_end|> to </s> instead
)
logger.info("Chat template configured successfully")
def formatting_prompts_func(examples):
convos = examples["conversations"]
texts = [
tokenizer.apply_chat_template(
convo, tokenize=False, add_generation_prompt=False
)
for convo in convos
]
return {"text": texts}
# Apply formatting to both train and validation sets
dataset = DatasetDict(
{
"train": dataset["train"].map(formatting_prompts_func, batched=True),
"validation": dataset["test"].map(
formatting_prompts_func, batched=True
),
}
)
logger.info("Dataset formatting completed successfully")
return dataset, tokenizer
except Exception as e:
logger.error(f"Error loading/formatting dataset: {e}")
raise
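# For reference, each "conversations" entry in the code-act dataset is a
# ShareGPT-style list such as
#   [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}],
# which the chatml template above renders roughly as
#   <|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n...<|im_end|>\n
# (illustrative; the exact special tokens come from get_chat_template).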
def create_trainer(
model: FastLanguageModel,
tokenizer: AutoTokenizer,
dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
cfg: DictConfig,
) -> Trainer:
"""Create and configure the SFTTrainer."""
logger.info("Creating trainer...")
try:
# Create TrainingArguments from config
training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)
# Add dynamic precision settings
training_args_dict.update(
{
"fp16": not is_bfloat16_supported(),
"bf16": is_bfloat16_supported(),
}
)
training_args = TrainingArguments(**training_args_dict)
# Create data collator from config
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
**cfg.training.sft.data_collator,
)
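# NOTE: for causal-LM fine-tuning the collator config is expected to set
# mlm: false; masked-language-modeling collation would not match the SFT
# objective used here.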
# Create SFT config without data_collator to avoid duplication
sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
sft_config.pop("data_collator", None) # Remove data_collator from config
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["validation"],
args=training_args,
data_collator=data_collator,
**sft_config,
)
logger.info("Trainer created successfully")
return trainer
except Exception as e:
logger.error(f"Error creating trainer: {e}")
raise
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
"""Main training function."""
try:
logger.info("Starting training process...")
logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")
# Install dependencies
install_dependencies()
# Load model and tokenizer
model, tokenizer = load_model(cfg)
# Load and prepare dataset
dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)
# Create trainer
trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)
# Train if requested
if cfg.train:
logger.info("Starting training...")
trainer.train()
# Save model
logger.info(f"Saving final model to {cfg.output.dir}...")
trainer.save_model(cfg.output.dir)
# Save model in VLLM format
logger.info("Saving model in VLLM format...")
model.save_pretrained_merged(
cfg.output.dir, tokenizer, save_method="merged_16bit"
)
# Print final metrics
final_metrics = trainer.state.log_history[-1]
logger.info("\nTraining completed!")
logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
logger.info(
f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}"
)
else:
logger.info("Training skipped as train=False")
# Test if requested
if cfg.test:
logger.info("\nStarting testing...")
try:
# Set memory allocation configuration (note: this only takes effect if it
# is set before the first CUDA allocation in the process)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,max_split_size_mb:128"
)
# Enable memory history tracking
torch.cuda.memory._record_memory_history()
# Load test dataset
test_dataset = load_dataset(
cfg.test_dataset.name,
cfg.test_dataset.config,
split=cfg.test_dataset.split,
trust_remote_code=True,
)
logger.info(f"Loaded test dataset with {len(test_dataset)} examples")
logger.info(f"Dataset features: {test_dataset.features}")
# Clear CUDA cache before loading model
torch.cuda.empty_cache()
# Initialize model. The smolagents base `Model` class is only an interface
# and does not implement generation itself, so a concrete backend
# (TransformersModel here, VLLMModel as an alternative) is used instead.
model: Model = TransformersModel(
model_id=cfg.model.name,
# model_id=cfg.output.dir,
)
# model: Model = VLLMModel(
# model_id=cfg.model.name,
# # model_id=cfg.output.dir,
# )
# Create CodeAgent with SmartSearchTool
agent = CodeAgent(
model=model,
tools=[SmartSearchTool()],
verbosity_level=LogLevel.ERROR,
)
# Format task to get succinct answer
def format_task(question):
return f"""Please provide two answers to the following question:
1. A succinct answer that follows these rules:
- Contains ONLY the answer, nothing else
- Does not repeat the question
- Does not include explanations, reasoning, or context
- Does not include source attribution or references
- Does not use phrases like "The answer is" or "I found that"
- Does not include formatting, bullet points, or line breaks
- If the answer is a number, return only the number
- If the answer requires multiple items, separate them with commas
- If the answer requires ordering, maintain the specified order
- Uses the most direct and succinct form possible
2. A verbose answer that includes:
- The complete answer with all relevant details
- Explanations and reasoning
- Context and background information
- Source attribution where appropriate
Question: {question}
Please format your response as a JSON object with two keys:
- "succinct_answer": The concise answer following the rules above
- "verbose_answer": The detailed explanation with context"""
# Run inference on test samples
logger.info("Running inference on test samples...")
for i, example in enumerate(test_dataset):
try:
# Clear CUDA cache before each sample
torch.cuda.empty_cache()
# Format the task
task = format_task(example["Question"])
# Run the agent
result = agent.run(
task=task,
max_steps=3,
reset=True,
stream=False,
)
# Parse the result
json_str = result[result.find("{") : result.rfind("}") + 1]
parsed_result = json.loads(json_str)
answer = parsed_result["succinct_answer"]
logger.info(f"\nTest Sample {i+1}:")
logger.info(f"Question: {example['Question']}")
logger.info(f"Model Response: {answer}")
logger.info("-" * 80)
# Log memory usage after each sample
logger.info(f"Memory usage after sample {i+1}:")
logger.info(
f"Allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB"
)
logger.info(
f"Reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB"
)
except Exception as e:
logger.error(f"Error processing test sample {i+1}: {str(e)}")
continue
# Dump memory snapshot for analysis
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
logger.info("Memory snapshot saved to memory_snapshot.pickle")
except Exception as e:
logger.error(f"Error during testing: {e}")
raise
except Exception as e:
logger.error(f"Error in main training process: {e}")
raise
if __name__ == "__main__":
main()