File size: 4,025 Bytes
1094cbb
 
34121ca
 
a2fa160
1094cbb
f5d4c87
 
1094cbb
 
 
 
 
 
 
 
 
 
 
 
 
a2fa160
 
 
 
 
1094cbb
 
a2fa160
34121ca
1094cbb
34121ca
a2fa160
 
 
 
 
 
 
0558a9f
 
34121ca
0558a9f
 
34121ca
 
0558a9f
 
34121ca
 
0558a9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf608be
 
a2fa160
23e18f2
 
 
a2fa160
d037a92
cf608be
d037a92
 
 
 
 
 
 
1094cbb
 
 
 
34121ca
1094cbb
 
 
 
 
7dbcdbe
f5d4c87
 
 
 
 
 
 
 
 
 
 
 
 
 
66d595e
1094cbb
 
a2fa160
1094cbb
a2fa160
1094cbb
de15d44
1094cbb
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import json
import os
import shutil
import subprocess

from huggingface_hub import HfApi, Repository, hf_hub_download, snapshot_download
from huggingface_hub.utils._errors import EntryNotFoundError
from loguru import logger

from competitions import utils
from competitions.compute_metrics import compute_metrics
from competitions.params import EvalParams


def parse_args(argv=None):
    """Parse the command-line arguments of the evaluation entry point.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the original behaviour; passing a list lets callers and tests
            parse arguments without touching the process command line.

    Returns:
        argparse.Namespace: namespace whose ``config`` attribute holds the
        string given to the required ``--config`` option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    return parser.parse_args(argv)


def upload_submission_file(params, file_path):
    """Placeholder for uploading a single submission file.

    The function currently only emits a log line; neither ``params`` nor
    ``file_path`` is used and nothing is uploaded.
    """
    logger.info("Uploading submission file")


def generate_submission_file(params):
    """Run a submission's ``script.py`` and upload the files it produces.

    Steps performed:

    1. Download a snapshot of the model repo ``params.submission_repo`` into
       ``params.output_path`` (authenticated with the ``USER_TOKEN``
       environment variable).
    2. Copy ``socket-kit.so`` from the current working directory into the
       snapshot and run ``python script.py`` there with ``LD_PRELOAD``
       pointing at that copy.
    3. Kill the script if it is still running after ``params.time_limit``
       seconds.
    4. Upload every file named in ``params.submission_filenames`` from the
       snapshot directory to the dataset repo ``params.competition_id`` as
       ``submissions/{team_id}-{submission_id}.{ext}``.

    Args:
        params: Evaluation parameters; the attributes read here are
            ``submission_repo``, ``output_path``, ``time_limit``, ``token``,
            ``submission_filenames``, ``team_id``, ``submission_id`` and
            ``competition_id``.

    Returns:
        None. A non-zero exit status of the script is only logged; the upload
        step is still attempted afterwards.
    """
    logger.info("Downloading submission dataset")
    submission_dir = snapshot_download(
        repo_id=params.submission_repo,
        local_dir=params.output_path,
        token=os.environ.get("USER_TOKEN"),
        repo_type="model",
    )
    logger.info("Generating submission file")

    # Place socket-kit.so next to the script and preload it into the child.
    # NOTE(review): presumably this library restricts the script's network
    # access -- confirm against the socket-kit sources.
    socket_kit_path = os.path.abspath(os.path.join(submission_dir, "socket-kit.so"))
    shutil.copyfile("socket-kit.so", socket_kit_path)

    # NOTE(review): the child inherits the full parent environment (including
    # USER_TOKEN when set) -- confirm that exposing it to the submitted
    # script is intended.
    env = os.environ.copy()
    env["LD_PRELOAD"] = socket_kit_path

    # Launch the interpreter directly instead of through a shell: with
    # shell=True the Popen handle refers to the shell, so kill() on timeout
    # could leave script.py itself running past the time limit.
    process = subprocess.Popen(["python", "script.py"], cwd=submission_dir, env=env)

    # Wait for the process to complete, enforcing the time limit.
    try:
        process.wait(timeout=params.time_limit)
    except subprocess.TimeoutExpired:
        logger.info(f"Process exceeded {params.time_limit} seconds time limit. Terminating...")
        process.kill()
        process.wait()  # reap the killed child so returncode is populated

    # Any non-zero status (a script error, or the negative signal number
    # after kill()) counts as a failure.
    if process.returncode:
        logger.error("Subprocess didn't terminate successfully")
    else:
        logger.info("Subprocess terminated successfully")

    logger.info("contents of submission_dir")
    logger.info(os.listdir(submission_dir))

    api = HfApi(token=params.token)
    for sub_file in params.submission_filenames:
        logger.info(f"Uploading {sub_file} to the repository")
        # Keep only the text after the last dot as the uploaded file's extension.
        sub_file_ext = sub_file.split(".")[-1]
        api.upload_file(
            path_or_fileobj=f"{submission_dir}/{sub_file}",
            path_in_repo=f"submissions/{params.team_id}-{params.submission_id}.{sub_file_ext}",
            repo_id=params.competition_id,
            repo_type="dataset",
        )


@utils.monitor
def run(params):
    """Evaluate one submission and record its scores.

    ``params`` may be an ``EvalParams`` instance or a plain dict of its
    fields. The submission is marked "processing"; for "script" competitions
    the competition's optional ``requirements.txt`` is (re)installed, the
    dataset repo is cloned to ``/tmp/data`` and the submission files are
    generated. Metrics are then computed, the public and private scores are
    stored, the status becomes "success" and the space is deleted.
    """
    logger.info(params)
    if isinstance(params, dict):
        params = EvalParams(**params)

    utils.update_submission_status(params, "processing")

    if params.competition_type == "script":
        try:
            requirements_path = hf_hub_download(
                repo_id=params.competition_id,
                filename="requirements.txt",
                token=params.token,
                repo_type="dataset",
            )
        except EntryNotFoundError:
            # The competition repo has no requirements.txt: nothing to install.
            pass
        else:
            if requirements_path:
                logger.info("Installing requirements")
                utils.uninstall_requirements(requirements_path)
                utils.install_requirements(requirements_path)

        # Instantiated only for its clone side effect; the object is not kept.
        Repository(local_dir="/tmp/data", clone_from=params.dataset, token=params.token)
        generate_submission_file(params)

    scores = compute_metrics(params)
    public_score = scores["public_score"]
    private_score = scores["private_score"]

    utils.update_submission_score(params, public_score, private_score)
    utils.update_submission_status(params, "success")
    utils.delete_space(params)


if __name__ == "__main__":
    # Script entry point: read the JSON file given via --config, build
    # EvalParams from its contents and run the evaluation.
    args = parse_args()
    # Open the config inside a context manager so the file handle is closed
    # as soon as it has been parsed instead of being left to the GC.
    with open(args.config, encoding="utf-8") as config_file:
        _params = json.load(config_file)
    _params = EvalParams(**_params)
    run(_params)