"""FastAPI service that launches an evaluation job and monitors it.

On startup the app calls ``run_evaluation`` (project helper; presumably it
spawns the evaluation process and returns its PID -- confirm against
``competitions.utils``), records the returned PID in a local SQLite file and
starts a background task that polls the recorded PIDs. Once no tracked PID
remains, the server sends itself SIGINT to shut down.
"""

import asyncio
import os
import signal
import sqlite3
from contextlib import asynccontextmanager

import psutil
from fastapi import FastAPI
from loguru import logger

from competitions.utils import run_evaluation


# Lower-cased statuses after which a tracked PID is dropped from the database.
_FINISHED_STATUSES = ("completed", "error", "zombie")


def get_process_status(pid):
    """Return the psutil status string for ``pid``.

    Returns the literal ``"Completed"`` when psutil reports that no process
    with this PID exists.
    """
    try:
        return psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        logger.info(f"No process found with PID: {pid}")
        return "Completed"


def kill_process_by_pid(pid):
    """Kill process by PID.

    Sends SIGTERM via ``os.kill``; any error raised by ``os.kill`` (for
    example when the PID no longer exists) propagates to the caller.
    """
    os.kill(pid, signal.SIGTERM)


class JobDB:
    """Tiny SQLite-backed registry of the PIDs of launched jobs."""

    def __init__(self, db_path):
        """Open (or create) the database at ``db_path`` and ensure the table exists."""
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        """Create the ``jobs`` table if it does not exist yet."""
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        """Record ``pid`` as a tracked job."""
        # Bound parameter instead of string interpolation: no SQL injection
        # and no breakage on values that are not plain integer literals.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        """Return a list with the PID of every row in the ``jobs`` table."""
        self.c.execute("""SELECT pid FROM jobs""")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        """Remove every row whose PID equals ``pid``."""
        self.c.execute("DELETE FROM jobs WHERE pid=?", (pid,))
        self.conn.commit()


PARAMS = os.environ.get("PARAMS")
DB = JobDB("job.db")


class BackgroundRunner:
    """Periodic monitor that prunes finished jobs and stops the server when idle."""

    async def run_main(self):
        """Poll tracked PIDs every 30 seconds, forever.

        For each tracked PID whose status is in ``_FINISHED_STATUSES`` a
        best-effort SIGTERM is sent and the PID is removed from the database.
        When no PID is left, SIGINT is sent to the current process.
        """
        while True:
            for _pid in DB.get_running_jobs():
                proc_status = get_process_status(_pid).strip().lower()
                if proc_status in _FINISHED_STATUSES:
                    logger.info(f"Process {_pid} is already completed. Skipping...")
                    # Best-effort cleanup: the process is usually gone already,
                    # so a failure here is logged and must not stop the monitor.
                    try:
                        kill_process_by_pid(_pid)
                    except Exception as e:
                        logger.info(f"Error while killing process: {e}")
                    DB.delete_job(_pid)

            # Re-read after pruning so the shutdown decision uses fresh state.
            running_jobs = DB.get_running_jobs()
            if not running_jobs:
                logger.info("No running jobs found. Shutting down the server.")
                os.kill(os.getpid(), signal.SIGINT)
            await asyncio.sleep(30)


runner = BackgroundRunner()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the evaluation and its monitor on startup; stop the monitor on shutdown."""
    process_pid = run_evaluation(params=PARAMS)
    logger.info(f"Started training with PID {process_pid}")
    DB.add_job(process_pid)
    # The event loop keeps only a weak reference to tasks; hold a strong one
    # so the monitor cannot be garbage-collected while the app is running.
    monitor_task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # The monitor loops forever; cancel it so shutdown is not left with a
        # pending task.
        monitor_task.cancel()


api = FastAPI(lifespan=lifespan)


@api.get("/")
async def root():
    """Return a static message saying the evaluation is in progress."""
    return "Your model is being evaluated..."


@api.get("/health")
async def health():
    """Liveness probe; always returns ``"OK"``."""
    return "OK"