File size: 2,430 Bytes
32e9f89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbd5ba9
32e9f89
bbd5ba9
32e9f89
bbd5ba9
 
32e9f89
 
bbd5ba9
32e9f89
 
bbd5ba9
 
 
32e9f89
bbd5ba9
32e9f89
bbd5ba9
 
 
 
32e9f89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python
# coding=utf-8

"""

Script to install the Python requirements for the Phi-4 training project.

It installs everything listed in the consolidated requirements.txt (next to this script) first. When run with --flash, it then installs flash-attention as a separate step.

"""

import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path

# Root logging: INFO and above, timestamped, written to stdout so installer
# progress interleaves with pip's own stdout output.
# (basicConfig builds a StreamHandler for `stream` itself.)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    stream=sys.stdout,
)
# Module logger used by the functions below.
logger = logging.getLogger(__name__)

def install_requirements(include_flash=False, requirements_path=None,
                         flash_attn_spec="flash-attn==2.5.2"):
    """Install project dependencies using the current interpreter's pip.

    Runs ``<sys.executable> -m pip install -r <requirements file>`` so that
    packages land in the same environment running this script. Optionally
    installs flash-attention afterwards.

    Args:
        include_flash: When True, also install flash-attention, after the
            main requirements have installed successfully.
        requirements_path: Requirements file to install from. Defaults to
            ``requirements.txt`` in the same directory as this script.
        flash_attn_spec: pip requirement specifier used for flash-attention.
            Defaults to ``flash-attn==2.5.2``.

    Returns:
        True if every requested install succeeded. False if the
        requirements file is missing or not a regular file, or if pip could
        not be launched or exited with a non-zero status.
    """
    if requirements_path is None:
        req_path = Path(__file__).parent / "requirements.txt"
    else:
        req_path = Path(requirements_path)

    # is_file() rather than exists(): a directory at this path would pass
    # exists() and only fail later inside pip with a less clear error.
    if not req_path.is_file():
        logger.error("Requirements file not found: %s", req_path)
        return False

    pip_install = [sys.executable, "-m", "pip", "install"]

    logger.info("Installing dependencies from consolidated requirements file...")
    try:
        logger.info("Installing requirements from %s", req_path)
        subprocess.run([*pip_install, "-r", str(req_path)], check=True)
        logger.info("Main requirements installed successfully")

        if include_flash:
            logger.info("Installing flash-attention...")
            # --no-build-isolation builds against packages already installed
            # in this environment. Presumably flash-attn's build needs them
            # (e.g. torch) -- confirm against the requirements file.
            subprocess.run(
                [*pip_install, flash_attn_spec, "--no-build-isolation"],
                check=True,
            )
            logger.info("Flash-attention installed successfully")
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: pip ran but exited non-zero.
        # OSError: the interpreter could not be launched at all.
        logger.error("Error installing dependencies: %s", e)
        return False

    logger.info("All required packages installed successfully!")
    return True

def main(argv=None):
    """Command-line entry point: parse arguments and install requirements.

    Args:
        argv: Argument list to parse, excluding the program name. When None
            (the default), argparse reads ``sys.argv[1:]``. Pass a list to
            drive the CLI programmatically.

    Exits with status 1 if installation fails. argparse itself exits with
    status 2 on invalid arguments and status 0 after printing ``--help``.
    """
    parser = argparse.ArgumentParser(description="Install requirements for Phi-4 training")
    parser.add_argument("--flash", action="store_true", help="Also install flash-attention (optional)")
    args = parser.parse_args(argv)

    if not install_requirements(include_flash=args.flash):
        logger.error("Installation failed. Please check the logs for details.")
        sys.exit(1)
    logger.info("Installation completed successfully!")


if __name__ == "__main__":
    main()