hf-train-frontend / install_requirements.py
George-API's picture
Upload folder using huggingface_hub
bbd5ba9 verified
#!/usr/bin/env python
# coding=utf-8
"""
Script to install the dependencies for the Phi-4 training project.
Everything listed in the consolidated requirements.txt is installed first;
flash-attention is installed afterwards, and only when --flash is passed.
"""
import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path
# Route INFO-and-above records to stdout, each prefixed with a timestamp
# and the record's level name.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
def install_requirements(include_flash=False):
    """Install the project's dependencies with pip.

    Runs ``pip install -r requirements.txt`` through the current interpreter
    (``sys.executable -m pip``), reading the requirements.txt located in the
    same directory as this script. When *include_flash* is true,
    ``flash-attn==2.5.2`` is installed afterwards with
    ``--no-build-isolation``.

    Args:
        include_flash: Also install flash-attention once the main
            requirements have been installed.

    Returns:
        True if every pip invocation succeeded. False if the requirements
        file is missing, pip exited with a non-zero status, or the pip
        process could not be started; the cause is logged in each case.
    """
    req_path = Path(__file__).parent / "requirements.txt"
    # is_file() rather than exists(): a directory that happens to carry this
    # name cannot be consumed by ``pip install -r``.
    if not req_path.is_file():
        logger.error(f"Requirements file not found: {req_path}")
        return False

    # Both installs share this prefix; pip is run as a module of the
    # interpreter that is executing this script.
    pip_install = [sys.executable, "-m", "pip", "install"]

    logger.info("Installing dependencies from consolidated requirements file...")
    try:
        logger.info(f"Installing requirements from {req_path}")
        subprocess.run([*pip_install, "-r", str(req_path)], check=True)
        logger.info("Main requirements installed successfully")

        if include_flash:
            # NOTE(review): kept after the main install; with
            # --no-build-isolation the build presumably relies on packages
            # already installed from requirements.txt - confirm.
            logger.info("Installing flash-attention...")
            subprocess.run(
                [*pip_install, "flash-attn==2.5.2", "--no-build-isolation"],
                check=True,
            )
            logger.info("Flash-attention installed successfully")
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: pip ran but exited non-zero (check=True).
        # OSError: the interpreter/pip process could not be launched at all.
        logger.error(f"Error installing dependencies: {e}")
        return False

    logger.info("All required packages installed successfully!")
    return True
def main():
    """Command-line entry point.

    Parses the optional ``--flash`` switch, runs install_requirements with
    it, and logs the outcome. Exits the process with status 1 when the
    installation reports failure.
    """
    cli = argparse.ArgumentParser(description="Install requirements for Phi-4 training")
    cli.add_argument(
        "--flash",
        action="store_true",
        help="Also install flash-attention (optional)",
    )
    options = cli.parse_args()

    # Handle the failure path first so the success path reads straight through.
    if not install_requirements(include_flash=options.flash):
        logger.error("Installation failed. Please check the logs for details.")
        sys.exit(1)

    logger.info("Installation completed successfully!")


# Run the installer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()