Abhishek Thakur
migrate to teams
a2fa160
import asyncio
import os
import signal
import sqlite3
from contextlib import asynccontextmanager
import psutil
from fastapi import FastAPI
from loguru import logger
from competitions.utils import run_evaluation
def get_process_status(pid):
    """Look up the status of the process with the given PID.

    Args:
        pid: Process ID to inspect.

    Returns:
        The status string reported by ``psutil`` for the process, or the
        literal ``"Completed"`` when ``psutil`` reports no such process.
    """
    try:
        current_status = psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        # A vanished process is reported as finished rather than as an error.
        logger.info(f"No process found with PID: {pid}")
        return "Completed"
    return current_status
def kill_process_by_pid(pid):
    """Send SIGTERM to the process identified by ``pid``.

    Args:
        pid: Process ID of the target process.

    Raises:
        Whatever ``os.kill`` raises (e.g. ``ProcessLookupError`` when the
        process does not exist); nothing is caught here.
    """
    termination_signal = signal.SIGTERM
    os.kill(pid, termination_signal)
class JobDB:
    """Small SQLite-backed registry of job process IDs.

    Each row of the ``jobs`` table holds an integer primary key ``id`` and
    the ``pid`` of one tracked process.
    """

    def __init__(self, db_path):
        """Open (or create) the database and make sure the jobs table exists.

        Args:
            db_path: Path handed to ``sqlite3.connect`` (``":memory:"`` works too).
        """
        self.db_path = db_path
        # One connection and one cursor are reused by every method.
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        """Create the ``jobs`` table if it is not already present."""
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        """Record ``pid`` as a tracked job.

        Args:
            pid: Process ID to store.
        """
        # Bind the value with a placeholder instead of interpolating it into
        # the SQL text, so it can never be parsed as SQL.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        """Return the list of every PID currently stored in the table."""
        self.c.execute("""SELECT pid FROM jobs""")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        """Remove every row whose ``pid`` column equals ``pid``.

        Args:
            pid: Process ID to forget.
        """
        # Placeholder binding: a value such as "0 OR 1=1" is compared as data
        # and no longer widens the WHERE clause.
        self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
        self.conn.commit()
# Raw value of the PARAMS environment variable (None when it is not set).
PARAMS = os.getenv("PARAMS")
# Module-wide job registry backed by the "job.db" SQLite file.
DB = JobDB("job.db")
class BackgroundRunner:
    """Polls the tracked jobs and stops the server once none are left."""

    async def run_main(self):
        """Loop forever, pruning finished jobs every 30 seconds.

        On each pass, every PID stored in ``DB`` whose status is
        "completed", "error" or "zombie" is sent a termination signal
        (failures are only logged) and removed from ``DB``. If the registry
        is empty afterwards, SIGINT is sent to the current process.
        """
        finished_states = ("completed", "error", "zombie")
        while True:
            for job_pid in DB.get_running_jobs():
                status = get_process_status(job_pid).strip().lower()
                if status not in finished_states:
                    continue
                logger.info(f"Process {job_pid} is already completed. Skipping...")
                try:
                    kill_process_by_pid(job_pid)
                except Exception as e:
                    # Best effort: the process may already be gone.
                    logger.info(f"Error while killing process: {e}")
                DB.delete_job(job_pid)

            # Re-read the registry so jobs removed above are taken into account.
            if not DB.get_running_jobs():
                logger.info("No running jobs found. Shutting down the server.")
                os.kill(os.getpid(), signal.SIGINT)
            await asyncio.sleep(30)
# Shared runner instance; its run_main() coroutine is scheduled from lifespan().
runner = BackgroundRunner()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the evaluation and the job monitor for the lifetime of the app.

    Before yielding, calls ``run_evaluation`` with ``PARAMS``, records the
    returned value in ``DB`` as a job PID, and schedules
    ``runner.run_main()`` as a background task. When the context exits, the
    monitor task is cancelled.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    process_pid = run_evaluation(params=PARAMS)
    logger.info(f"Started training with PID {process_pid}")
    DB.add_job(process_pid)
    # The event loop keeps only weak references to tasks, so hold a strong
    # reference here; otherwise the monitor could be garbage-collected while
    # it is still running.
    monitor_task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # Stop polling once the application is shutting down.
        monitor_task.cancel()
# FastAPI application, wired to the lifespan handler defined in this module.
api = FastAPI(lifespan=lifespan)
@api.get("/")
async def root():
    """Return a fixed message telling the caller an evaluation is in progress."""
    message = "Your model is being evaluated..."
    return message
@api.get("/health")
async def health():
    """Health-check endpoint: always answers with the string "OK"."""
    status = "OK"
    return status