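"""FastAPI service that launches a model evaluation, tracks its PID in SQLite, and shuts itself down when the job finishes."""
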
import asyncio
import os
import signal
import sqlite3
from contextlib import asynccontextmanager

import psutil
from fastapi import FastAPI
from loguru import logger

from competitions.utils import run_evaluation


def get_process_status(pid):
    """Return the psutil status string for the process, or "Completed" if it no longer exists."""
    try:
        process = psutil.Process(pid)
        proc_status = process.status()
        return proc_status
    except psutil.NoSuchProcess:
        logger.info(f"No process found with PID: {pid}")
        return "Completed"


def kill_process_by_pid(pid):
    """Terminate the process with the given PID by sending SIGTERM."""
    os.kill(pid, signal.SIGTERM)


class JobDB:
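    """Small SQLite-backed table tracking the PIDs of spawned evaluation jobs."""
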
    def __init__(self, db_path):
        self.db_path = db_path
        self.conn = sqlite3.connect(db_path)
        self.c = self.conn.cursor()
        self.create_jobs_table()

    def create_jobs_table(self):
        self.c.execute(
            """CREATE TABLE IF NOT EXISTS jobs
            (id INTEGER PRIMARY KEY, pid INTEGER)"""
        )
        self.conn.commit()

    def add_job(self, pid):
        # Bind the PID as a query parameter rather than formatting it into the SQL string.
        self.c.execute("INSERT INTO jobs (pid) VALUES (?)", (pid,))
        self.conn.commit()

    def get_running_jobs(self):
        # fetchall() returns one-element tuples; flatten them into a list of PIDs.
        self.c.execute("SELECT pid FROM jobs")
        return [row[0] for row in self.c.fetchall()]

    def delete_job(self, pid):
        self.c.execute("DELETE FROM jobs WHERE pid = ?", (pid,))
        self.conn.commit()


# Evaluation parameters come from the environment; job.db stores the PIDs of spawned processes.
PARAMS = os.environ.get("PARAMS")
DB = JobDB("job.db")


class BackgroundRunner:
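    """Reaps finished evaluation processes and shuts the server down once none remain."""
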
    async def run_main(self):
        # Check on the tracked processes every 30 seconds.
        while True:
            running_jobs = DB.get_running_jobs()
            if running_jobs:
                for _pid in running_jobs:
                    proc_status = get_process_status(_pid)
                    proc_status = proc_status.strip().lower()
                    if proc_status in ("completed", "error", "zombie"):
                        logger.info(f"Process {_pid} has finished. Cleaning up...")
                        # Best-effort termination; the process may already be gone.
                        try:
                            kill_process_by_pid(_pid)
                        except Exception as e:
                            logger.info(f"Error while killing process: {e}")
                        DB.delete_job(_pid)

            # Once the jobs table is empty, stop the API server itself.
            running_jobs = DB.get_running_jobs()
            if not running_jobs:
                logger.info("No running jobs found. Shutting down the server.")
                os.kill(os.getpid(), signal.SIGINT)
            await asyncio.sleep(30)


runner = BackgroundRunner()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Launch the evaluation process and register its PID with the job database.
    process_pid = run_evaluation(params=PARAMS)
    logger.info(f"Started evaluation with PID {process_pid}")
    DB.add_job(process_pid)
    # Keep a reference to the monitor task so it is not garbage-collected while running.
    app.state.monitor_task = asyncio.create_task(runner.run_main())
    yield


api = FastAPI(lifespan=lifespan)


@api.get("/")
async def root():
    return "Your model is being evaluated..."


@api.get("/health")
async def health():
    return "OK"