# Hosting-page header (not code): uploaded by hardiktiwari,
# commit "Upload 244 files" (33d4721, verified).
import asyncio
import os
import signal
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from autotrain import logger
from autotrain.app.db import AutoTrainDB
from autotrain.app.utils import get_running_jobs, kill_process_by_pid
from autotrain.utils import run_training
def _require_int_env(name: str) -> int:
    """
    Read environment variable ``name`` and parse it as an integer.

    Args:
        name (str): Name of the environment variable to read.

    Returns:
        int: The parsed value.

    Raises:
        RuntimeError: If the variable is not set.
        ValueError: If the variable is set but is not a valid integer.
    """
    raw_value = os.environ.get(name)
    if raw_value is None:
        # Without this check, int(None) fails with an opaque TypeError that
        # never mentions which variable is missing.
        raise RuntimeError(f"Required environment variable {name} is not set.")
    try:
        return int(raw_value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw_value!r}."
        ) from err


# Run configuration, read once from the environment at import time.
# Optional values are None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")  # presumably a Hugging Face access token — treat as a secret, do not log
AUTOTRAIN_USERNAME = os.environ.get("AUTOTRAIN_USERNAME")
PROJECT_NAME = os.environ.get("PROJECT_NAME")
TASK_ID = _require_int_env("TASK_ID")  # mandatory; passed to run_training as task_id
PARAMS = os.environ.get("PARAMS")  # passed verbatim to run_training (format not visible here — confirm)
DATA_PATH = os.environ.get("DATA_PATH")
MODEL = os.environ.get("MODEL")

# Shared job database handle, opened with the path "autotrain.db". The
# lifespan hook registers the training PID here, and the background
# runner polls it for running jobs.
DB = AutoTrainDB("autotrain.db")
def graceful_exit(signum, frame):
    """
    SIGTERM handler: log that the signal arrived, then exit with status 0.

    No cleanup is performed here beyond the log line. The exit is requested by
    raising SystemExit, which is what ``sys.exit(0)`` does.

    Args:
        signum (int): Number of the signal being handled.
        frame (FrameType | None): Stack frame that was interrupted, if any.
    """
    logger.info("SIGTERM received. Performing cleanup...")
    raise SystemExit(0)


# Install the handler at import time.
# NOTE(review): an ASGI server started afterwards (e.g. uvicorn) may install its
# own SIGTERM handler and replace this one — confirm which handler is active.
signal.signal(signal.SIGTERM, graceful_exit)
class BackgroundRunner:
    """
    Watchdog that asks this server process to stop once no jobs remain.

    Methods
    -------
    run_main():
        Loop forever. On each pass, query ``get_running_jobs(DB)``. If it
        returns nothing, log a message and call ``kill_process_by_pid`` on
        this process's own PID. Then sleep 30 seconds before the next pass.
    """

    async def run_main(self):
        poll_interval_seconds = 30
        while True:
            # get_running_jobs is called synchronously, so it runs directly on
            # the event loop thread.
            if not get_running_jobs(DB):
                logger.info("No running jobs found. Shutting down the server.")
                kill_process_by_pid(os.getpid())
            # Polling continues even after the kill request; the loop only ends
            # when the task is cancelled (see lifespan) or the process exits.
            await asyncio.sleep(poll_interval_seconds)


# Module-level watchdog instance used by the lifespan hook.
runner = BackgroundRunner()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    FastAPI lifespan hook: start training on startup, stop the watchdog on shutdown.

    Startup:
        * launch training via ``run_training(params=PARAMS, task_id=TASK_ID)``
        * log the returned PID and record it with ``DB.add_job``
        * schedule ``runner.run_main()`` as a background asyncio task

    Shutdown (after control returns from the ``yield``):
        * cancel the background task and wait for it to finish, logging when
          it reports cancellation

    Args:
        app (FastAPI): The application instance (not used directly).

    Yields:
        None: Control passes to the application while it serves requests.
    """
    # --- startup ---
    training_pid = run_training(params=PARAMS, task_id=TASK_ID)
    logger.info(f"Started training with PID {training_pid}")
    DB.add_job(training_pid)
    watchdog = asyncio.create_task(runner.run_main())

    yield

    # --- shutdown ---
    watchdog.cancel()
    try:
        await watchdog
    except asyncio.CancelledError:
        logger.info("Background runner task cancelled.")
api = FastAPI(lifespan=lifespan)

# Log the run configuration once at import time, one "NAME: value" line per
# setting. HF_TOKEN and PARAMS are not included in this output.
for _setting_name, _setting_value in (
    ("AUTOTRAIN_USERNAME", AUTOTRAIN_USERNAME),
    ("PROJECT_NAME", PROJECT_NAME),
    ("TASK_ID", TASK_ID),
    ("DATA_PATH", DATA_PATH),
    ("MODEL", MODEL),
):
    logger.info(f"{_setting_name}: {_setting_value}")
@api.get("/")
async def root():
    """Landing endpoint: return a fixed status message about training."""
    status_message = "Your model is being trained..."
    return status_message
@api.get("/health")
async def health():
    """Health-check endpoint: always returns the string "OK" when reached."""
    health_status = "OK"
    return health_status