# File size: 2,929 Bytes | a2fa160
import asyncio
import os
import signal
import sqlite3
from contextlib import asynccontextmanager
import psutil
from fastapi import FastAPI
from loguru import logger
from competitions.utils import run_evaluation
def get_process_status(pid):
    """Look up the current status of the process with the given PID.

    Args:
        pid: Process ID to inspect.

    Returns:
        The value of ``psutil.Process(pid).status()``, or the literal
        ``"Completed"`` when psutil raises ``NoSuchProcess`` for that PID.
    """
    try:
        current_status = psutil.Process(pid).status()
    except psutil.NoSuchProcess:
        # A PID psutil cannot find is reported as a finished job.
        logger.info(f"No process found with PID: {pid}")
        return "Completed"
    return current_status
def kill_process_by_pid(pid):
    """Send SIGTERM to the process identified by ``pid``.

    No exception is handled here; anything ``os.kill`` raises reaches the caller.
    """
    termination_signal = signal.SIGTERM
    os.kill(pid, termination_signal)
class JobDB:
    """Small SQLite-backed registry of the PIDs of launched jobs.

    One connection and one cursor are opened when the object is created and
    reused by every method; each write is committed immediately.
    """

    def __init__(self, db_path):
        """Open (or create) the database and make sure the jobs table exists.

        Args:
            db_path: Database location passed straight to ``sqlite3.connect``.
        """
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        """Create the ``jobs`` table (``id``, ``pid``) if it does not exist yet."""
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        """Record ``pid`` as a tracked job.

        Args:
            pid: Process ID to store.
        """
        # Bind the value through a placeholder instead of formatting it into
        # the statement, so it can never be interpreted as SQL.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        """Return the stored PIDs as a flat list (empty when no job is tracked)."""
        self.c.execute("SELECT pid FROM jobs")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        """Remove every row whose ``pid`` column equals ``pid``.

        Args:
            pid: Process ID to forget.
        """
        # Placeholder binding: a value such as "1 OR 1=1" is compared as data
        # and therefore matches nothing, rather than deleting every row.
        self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
        self.conn.commit()
# Evaluation parameters taken from the PARAMS environment variable (None when unset).
PARAMS = os.getenv("PARAMS")
# Module-wide job registry stored in the SQLite file "job.db".
DB = JobDB(db_path="job.db")
class BackgroundRunner:
    """Periodic monitor that prunes finished jobs and stops the server when none remain."""

    async def run_main(self):
        """Poll the job registry forever, sleeping 30 seconds between passes.

        On each pass, every tracked PID whose lower-cased status is
        "completed", "error" or "zombie" is sent a kill request (failures are
        only logged) and removed from the registry. If the registry is empty
        afterwards, SIGINT is sent to this very process.
        """
        finished_states = ("completed", "error", "zombie")
        while True:
            for _pid in DB.get_running_jobs():
                state = get_process_status(_pid).strip().lower()
                if state not in finished_states:
                    continue
                logger.info(f"Process {_pid} is already completed. Skipping...")
                try:
                    kill_process_by_pid(_pid)
                except Exception as e:
                    # Best effort: the process may already be gone.
                    logger.info(f"Error while killing process: {e}")
                DB.delete_job(_pid)

            # Query again so that rows removed just above are taken into account.
            if not DB.get_running_jobs():
                logger.info("No running jobs found. Shutting down the server.")
                os.kill(os.getpid(), signal.SIGINT)
            await asyncio.sleep(30)
# Shared runner instance; lifespan() schedules its run_main() coroutine as a task.
runner = BackgroundRunner()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Start the evaluation and its monitor for the lifetime of the application.

    Before yielding: launches the evaluation via ``run_evaluation``, records
    the returned PID in the job registry, and schedules ``runner.run_main()``
    as a background task. After the application is done: cancels that task.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    process_pid = run_evaluation(params=PARAMS)
    logger.info(f"Started training with PID {process_pid}")
    DB.add_job(process_pid)
    # The event loop keeps only weak references to tasks; hold a strong one so
    # the monitor cannot be garbage-collected while it is still running.
    monitor_task = asyncio.create_task(runner.run_main())
    try:
        yield
    finally:
        # The monitor loops forever, so stop it explicitly on shutdown.
        monitor_task.cancel()
# ASGI application; the lifespan handler launches the evaluation and the job monitor at startup.
api = FastAPI(lifespan=lifespan)
@api.get("/")
async def root():
    """Serve the landing route with a fixed progress message."""
    message = "Your model is being evaluated..."
    return message
@api.get("/health")
async def health():
    """Serve the health-check route with a constant reply."""
    reply = "OK"
    return reply
|