Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding=utf-8 | |
""" | |
Script to install requirements in the correct order for the Phi-4 training project. | |
This ensures base requirements are installed first, followed by additional requirements. | |
""" | |
import os | |
import sys | |
import subprocess | |
import argparse | |
import logging | |
from pathlib import Path | |
# Route INFO-and-above records to stdout with a timestamped, level-tagged layout.
# basicConfig is a no-op if the root logger already has handlers.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by install_requirements() and main().
logger = logging.getLogger(__name__)
def _pip_install(*pip_args):
    """Run ``pip install`` with *pip_args* under the current interpreter.

    Raises:
        subprocess.CalledProcessError: pip exited with a non-zero status.
        OSError: the interpreter process could not be launched.
    """
    # Invoking pip through sys.executable targets the same interpreter that
    # is running this script.
    subprocess.run([sys.executable, "-m", "pip", "install", *pip_args], check=True)


def install_requirements(include_flash=False, *, flash_attn_spec="flash-attn==2.5.2"):
    """Install the project's Python dependencies with pip.

    Installs everything in the ``requirements.txt`` located next to this
    script, then optionally flash-attention.

    Args:
        include_flash: When True, also install flash-attention once the main
            requirements have installed successfully.
        flash_attn_spec: pip requirement specifier used for flash-attention
            (only consulted when ``include_flash`` is True).

    Returns:
        True if every requested install succeeded. False otherwise: the
        requirements file is missing, the interpreter path is unknown, pip
        failed, or pip could not be launched. Failures are logged, not raised.
    """
    req_path = Path(__file__).parent / "requirements.txt"
    # is_file() rather than exists(): a directory with this name would otherwise
    # be handed to ``pip -r`` and fail there with a less helpful error.
    if not req_path.is_file():
        logger.error("Requirements file not found: %s", req_path)
        return False

    # sys.executable may be empty or None when Python cannot determine its own
    # path. Fail clearly here instead of with an opaque subprocess error.
    if not sys.executable:
        logger.error("Cannot locate the Python interpreter (sys.executable is empty)")
        return False

    logger.info("Installing dependencies from consolidated requirements file...")
    try:
        logger.info("Installing requirements from %s", req_path)
        _pip_install("-r", str(req_path))
        logger.info("Main requirements installed successfully")

        if include_flash:
            logger.info("Installing flash-attention...")
            # --no-build-isolation makes the flash-attn build use packages that
            # are already installed above (presumably torch; confirm against
            # requirements.txt).
            _pip_install(flash_attn_spec, "--no-build-isolation")
            logger.info("Flash-attention installed successfully")
    except subprocess.CalledProcessError as e:
        logger.error("Error installing dependencies: %s", e)
        return False
    except OSError as e:
        # Raised when the interpreter process cannot be started at all.
        logger.error("Could not launch pip: %s", e)
        return False

    logger.info("All required packages installed successfully!")
    return True
def main():
    """Command-line entry point.

    Parses the ``--flash`` flag, runs the installer, and terminates the
    process with exit status 1 if installation reports failure.
    """
    cli = argparse.ArgumentParser(description="Install requirements for Phi-4 training")
    cli.add_argument(
        "--flash",
        action="store_true",
        help="Also install flash-attention (optional)",
    )
    options = cli.parse_args()

    # Guard clause: report and exit on failure, fall through on success.
    if not install_requirements(include_flash=options.flash):
        logger.error("Installation failed. Please check the logs for details.")
        sys.exit(1)
    logger.info("Installation completed successfully!")
if __name__ == "__main__":
    # Run only when executed as a script, not on import. main() returns None
    # on success, so SystemExit(None) yields exit status 0. main() itself
    # exits with status 1 on failure.
    raise SystemExit(main())