File size: 2,983 Bytes
33d4721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import asyncio
import os
import signal
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI

from autotrain import logger
from autotrain.app.db import AutoTrainDB
from autotrain.app.utils import get_running_jobs, kill_process_by_pid
from autotrain.utils import run_training


# Run configuration, read once from the process environment at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")
AUTOTRAIN_USERNAME = os.environ.get("AUTOTRAIN_USERNAME")
PROJECT_NAME = os.environ.get("PROJECT_NAME")
# TASK_ID is mandatory (it is converted to int and passed to run_training).
# Indexing os.environ makes a missing value fail fast with KeyError('TASK_ID')
# rather than the opaque TypeError that int(None) would raise.
TASK_ID = int(os.environ["TASK_ID"])
PARAMS = os.environ.get("PARAMS")  # forwarded unchanged to run_training()
DATA_PATH = os.environ.get("DATA_PATH")
MODEL = os.environ.get("MODEL")
# Job-tracking database handle, shared by the lifespan hook (DB.add_job) and the
# background runner (get_running_jobs(DB)).
DB = AutoTrainDB("autotrain.db")


def graceful_exit(signum, frame):
    """
    Handler registered below for SIGTERM.

    Args:
        signum (int): Number of the signal being delivered.
        frame (FrameType): Stack frame that was interrupted, or None.

    Emits a log line announcing the SIGTERM, then terminates the program by
    calling ``sys.exit`` with exit status 0 (which raises SystemExit).
    """
    del signum, frame  # both unused; required by the signal-handler signature
    exit_status = 0
    logger.info("SIGTERM received. Performing cleanup...")
    sys.exit(exit_status)


# Install the handler at import time for the whole process.
signal.signal(signal.SIGTERM, graceful_exit)


class BackgroundRunner:
    """
    Watchdog that shuts the server down once no jobs remain.

    Methods
    -------
    run_main(poll_interval=30):
        Periodically calls ``get_running_jobs(DB)`` and, when it reports no
        jobs, calls ``kill_process_by_pid`` on this server's own PID.
    """

    async def run_main(self, poll_interval: float = 30):
        """
        Poll for running jobs forever, shutting the server down when none are left.

        A failure inside one polling cycle is logged and the loop continues. This
        way a transient error cannot silently end the watchdog. An ended watchdog
        would leave the server running indefinitely and make the shutdown path's
        ``await task`` re-raise that error.

        Args:
            poll_interval: Seconds to wait between checks. Defaults to 30, the
                previously hard-coded interval.
        """
        while True:
            try:
                running_jobs = get_running_jobs(DB)
                if not running_jobs:
                    logger.info("No running jobs found. Shutting down the server.")
                    kill_process_by_pid(os.getpid())
            except Exception:
                # Keep the watchdog alive; BaseException (SystemExit, CancelledError)
                # is deliberately not caught here.
                logger.exception("Background runner check failed; retrying on next poll.")
            # Outside the try so that cancellation, raised at this await, always propagates.
            await asyncio.sleep(poll_interval)


runner = BackgroundRunner()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Manages the lifespan of the FastAPI application.

    On startup it launches training via ``run_training`` and logs the returned
    process ID. It then records that PID with ``DB.add_job`` and starts
    ``runner.run_main()`` as a background asyncio task.

    On shutdown the background task is cancelled and awaited. This teardown
    runs in a ``finally`` so it also happens if an exception is thrown in at
    ``yield``. If the task had already died with an error, that error is
    logged rather than allowed to fail application shutdown.

    Args:
        app (FastAPI): The FastAPI application instance (unused here).

    Yields:
        None: Control is handed back to the FastAPI application lifecycle
        while the server runs.
    """
    process_pid = run_training(params=PARAMS, task_id=TASK_ID)
    logger.info(f"Started training with PID {process_pid}")
    DB.add_job(process_pid)
    task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # cancel() is a no-op if the task already finished; awaiting then
        # re-raises whatever exception ended it.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            logger.info("Background runner task cancelled.")
        except Exception:
            logger.exception("Background runner task failed before shutdown.")


api = FastAPI(lifespan=lifespan)

# Echo the run configuration at import time, one "NAME: value" line each.
# HF_TOKEN and PARAMS are not included in this list.
for _setting_name, _setting_value in (
    ("AUTOTRAIN_USERNAME", AUTOTRAIN_USERNAME),
    ("PROJECT_NAME", PROJECT_NAME),
    ("TASK_ID", TASK_ID),
    ("DATA_PATH", DATA_PATH),
    ("MODEL", MODEL),
):
    logger.info(f"{_setting_name}: {_setting_value}")
del _setting_name, _setting_value  # don't leave loop temporaries in the module namespace


@api.get("/")
async def root():
    """Landing endpoint: returns a fixed message saying training is in progress."""
    status_message = "Your model is being trained..."
    return status_message


@api.get("/health")
async def health():
    """Health-check endpoint; always responds with the literal string "OK"."""
    health_status = "OK"
    return health_status