File size: 2,619 Bytes
4562a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import signal
import subprocess
import sys
import time


def submitit_job_watcher(jobs, check_period: int = 15):
    """Poll a mapping of jobs until every one of them reports ``done()``.

    On each pass the watcher announces jobs that finished since the previous
    pass (printing the stderr of any job whose ``exception()`` is not
    ``None``), prints a per-state summary, and then sleeps ``check_period``
    seconds. A ``KeyboardInterrupt`` cancels every job that is not done yet.

    Args:
        jobs: Mapping of job name to job object. Each job must expose
            ``state``, ``done()``, ``stdout()``, ``stderr()``,
            ``exception()`` and ``cancel(check=...)``.
        check_period: Seconds to sleep between two polls.

    Returns:
        None, both when all jobs finished and after an interrupt.
    """
    # Output captured for jobs that were already announced; membership in
    # this dict is what prevents a finished job from being reported twice.
    job_out = {}

    try:
        while True:
            # Tally states in a single pass over the jobs.
            state_counts = {}
            for job in jobs.values():
                state = job.state
                state_counts[state] = state_counts.get(state, 0) + 1

            n_done = 0
            for job_name, job in jobs.items():
                if not job.done():
                    continue
                n_done += 1
                if job_name in job_out:
                    continue

                job_out[job_name] = {
                    "stderr": job.stderr(),
                    "stdout": job.stdout(),
                }

                if job.exception() is not None:
                    print(f"{job_name} crashed!!!")
                    if job_out[job_name]["stderr"] is not None:
                        print("===== STDERR =====")
                        print(job_out[job_name]["stderr"])
                else:
                    print(f"{job_name} done!")

            print("Job states:")
            # Sorted so the summary lines keep a stable order between polls
            # and between runs (set/dict hash order is not stable for str).
            for state, count in sorted(state_counts.items()):
                print(f"  {state:15s} {count:6d} ({100.*count/len(jobs):.1f}%)")

            if n_done == len(jobs):
                print("All done!")
                return

            time.sleep(check_period)

    except KeyboardInterrupt:
        # Best-effort cleanup: cancel whatever is still pending or running.
        for job_name, job in jobs.items():
            if not job.done():
                print(f"Killing {job_name}")
                job.cancel(check=False)


def get_jid():
    """Return the SLURM job identifier read from the environment.

    When ``SLURM_ARRAY_TASK_ID`` is set the result is
    ``"<SLURM_ARRAY_JOB_ID>_<SLURM_ARRAY_TASK_ID>"``; otherwise it is the
    value of ``SLURM_JOB_ID``.

    Raises:
        KeyError: If a required environment variable is missing.
    """
    env = os.environ
    if "SLURM_ARRAY_TASK_ID" not in env:
        return env["SLURM_JOB_ID"]

    array_job = env["SLURM_ARRAY_JOB_ID"]
    array_task = env["SLURM_ARRAY_TASK_ID"]
    return "_".join((array_job, array_task))


def signal_helper(signum, frame):
    """Signal handler that asks SLURM to requeue this job, then exits.

    Runs ``scontrol requeue <job id>`` and terminates the process with exit
    status 10 whether or not the requeue command succeeded.

    Args:
        signum: Number of the signal being handled.
        frame: Current stack frame supplied by the ``signal`` module (unused).

    Raises:
        SystemExit: Always, with code 10.
    """
    print(f"Caught signal {signal.Signals(signum).name} on for the this job")
    jid = get_jid()
    cmd = ["scontrol", "requeue", jid]
    try:
        print("calling", cmd)
        rtn = subprocess.check_call(cmd)
        print("subprocc", rtn)
    except Exception:
        # Requeueing is best-effort (scontrol missing, non-zero exit, ...),
        # but KeyboardInterrupt/SystemExit must not be swallowed here.
        print("subproc call failed")
    sys.exit(10)


def bypass(signum, frame):
    """Signal handler that only logs the signal name and otherwise ignores it.

    Args:
        signum: Number of the signal being handled.
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    sig_name = signal.Signals(signum).name
    print("Ignoring signal " + sig_name + " on for the this job")


def init_slurm_signals():
    """Install this module's signal handlers for the current process.

    SIGCONT, SIGCHLD and SIGTERM are routed to ``bypass``; SIGUSR2 is routed
    to ``signal_helper``. A confirmation line is printed and flushed once the
    handlers are registered.
    """
    ignored_signals = (signal.SIGCONT, signal.SIGCHLD, signal.SIGTERM)
    for sig in ignored_signals:
        signal.signal(sig, bypass)

    signal.signal(signal.SIGUSR2, signal_helper)
    print("SLURM signal installed", flush=True)


def init_slurm_signals_if_slurm():
    """Call ``init_slurm_signals`` only when ``SLURM_JOB_ID`` is in the environment."""
    if "SLURM_JOB_ID" not in os.environ:
        return
    init_slurm_signals()