File size: 3,273 Bytes
4c2f8ad
1768d92
13e89d2
 
1768d92
892d25a
1768d92
 
5291e0b
4591bfb
1768d92
 
4c2f8ad
d7d8b33
4c2f8ad
 
d7d8b33
 
 
 
 
 
 
 
 
 
9a6adf1
4c2f8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
1768d92
 
892d25a
4c2f8ad
892d25a
 
 
 
 
 
 
 
 
 
 
 
 
d7d8b33
4c2f8ad
d7d8b33
 
 
 
 
678531b
0f3dbb0
d7d8b33
 
 
 
 
 
 
678531b
 
 
 
 
 
 
892d25a
678531b
50018bb
0f3dbb0
50018bb
 
 
678531b
892d25a
 
 
 
 
 
 
efcff68
1768d92
 
 
 
 
 
 
 
 
 
efcff68
13e89d2
4591bfb
13e89d2
 
 
4591bfb
efcff68
4591bfb
5291e0b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
import os
import sys

import uvicorn
import yaml  # type: ignore
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from zeno import ZenoParameters, get_server, zeno  # type: ignore
from zeno_evals import ZenoEvals  # type: ignore


def prepare_spec(file_path):
    """Summarize an evals JSONL results log for the frontend.

    Reads ``file_path`` line by line, skipping blank lines, and parses every
    remaining line as a JSON record.

    Returns a dict with:
      * ``models``   - first element of ``spec.completion_fns`` of the first
        record;
      * ``accuracy`` - ``final_report.accuracy`` of the last record carrying a
        ``final_report``, multiplied by 100 (0 if no such record exists);
      * ``events``   - ``(number_of_records - 2) / 2`` (a float).
        NOTE(review): this presumably assumes one spec record, one
        final-report record and two records per event - confirm against the
        evals log format.

    Raises:
        ValueError: if the file contains no JSON records.
    """
    records = []
    accuracy = 0
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(file_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. concatenated or hand-edited logs)
                # instead of failing inside json.loads.
                continue
            record = json.loads(line)
            if "final_report" in record:
                accuracy = record["final_report"]["accuracy"]
            records.append(record)

    if not records:
        raise ValueError(f"no JSON records found in results file {file_path!r}")

    return {
        "models": records[0]["spec"]["completion_fns"][0],
        "accuracy": accuracy * 100,
        "events": (len(records) - 2) / 2,
    }


def prepare_zeno_params(config: ZenoParameters):
    """Copy the frontend-relevant fields of a Zeno configuration into a dict.

    The returned dict maps each of ``models``, ``view``, ``data_column``,
    ``id_column``, ``batch_size`` and ``samples`` (in that order) to the
    attribute of the same name on ``config``.
    """
    exported_fields = (
        "models",
        "view",
        "data_column",
        "id_column",
        "batch_size",
        "samples",
    )
    return {field: getattr(config, field) for field in exported_fields}


def _add_result_summaries(params):
    """Augment one evaluation's config entry in place for the frontend.

    ``models``, ``accuracy`` and ``events`` become lists holding the
    ``prepare_spec`` summary of ``results-file`` followed, when present, by
    that of ``second-results-file``. ``link`` and ``description`` are wrapped
    in single-element lists.
    """
    results_files = [params["results-file"]]
    if "second-results-file" in params:
        results_files.append(params["second-results-file"])
    specs = [prepare_spec(path) for path in results_files]

    params["models"] = [spec["models"] for spec in specs]
    params["accuracy"] = [spec["accuracy"] for spec in specs]
    params["events"] = [spec["events"] for spec in specs]
    params["link"] = [params["link"]]
    params["description"] = [params["description"]]


def command_line():
    """Serve one Zeno instance per evaluation listed in a YAML config file.

    Usage: ``<script> CONFIG.yaml``. The config is a list of single-key
    mappings ``{name: params}``; each ``params`` must provide
    ``results-file``, ``link`` and ``description`` and may provide
    ``second-results-file`` and ``functions-file``. The working directory is
    switched to the config file's directory before those files are opened.

    Each evaluation's Zeno server is mounted at ``/<name>``, the bundled
    static frontend at ``/``, and the augmented config is served at
    ``/args``. If the ``PORT`` environment variable is set, the app binds to
    ``0.0.0.0:$PORT`` and the Zeno instances are made non-editable;
    otherwise it binds to ``localhost:8000``. Exits with status 1 if Zeno
    fails to initialise an evaluation.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {os.path.basename(sys.argv[0])} CONFIG.yaml")
    # Absolute path so dirname() below is never "" for a bare filename
    # (os.chdir("") raises FileNotFoundError).
    config_path = os.path.abspath(sys.argv[1])
    # Resolve the frontend location before chdir, in case __file__ is relative.
    frontend_dir = os.path.dirname(os.path.realpath(__file__)) + "/frontend"

    app = FastAPI(title="Frontend API")

    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML document loads as None; treat it as "no evaluations".
        args = yaml.safe_load(f) or []

    @app.get("/args")
    def get_args():
        # Returns the entries as augmented in place by the loop below.
        return args

    os.chdir(os.path.dirname(config_path))

    port_arg = os.getenv("PORT")
    deployed = port_arg is not None

    # Keep a reference to every Zeno instance for the server's lifetime.
    zeno_objs = []
    for entry in args:
        # Each entry is a single-key mapping {name: params}.
        name, params = next(iter(entry.items()))
        _add_result_summaries(params)

        zeno_eval = ZenoEvals(
            params.get("results-file"),
            params.get("second-results-file"),
            params.get("functions-file"),
        )
        config = zeno_eval.generate_zeno_config()
        config.serve = False
        config.cache_path = "./.zeno_cache_" + name
        config.multiprocessing = False
        config.batch_size = 2000
        if deployed:
            config.editable = False

        zeno_obj = zeno(config)
        if zeno_obj is None:
            sys.exit(1)
        server = get_server(zeno_obj)
        zeno_obj.start_processing()
        zeno_objs.append(zeno_obj)
        app.mount("/" + name, server)

    app.mount(
        "/",
        StaticFiles(
            directory=frontend_dir,
            html=True,
        ),
        name="base",
    )

    print("Running server")

    if deployed:
        host, port = "0.0.0.0", int(port_arg)
    else:
        host, port = "localhost", 8000
    uvicorn.run(app, host=host, port=port)


if "__main__" == __name__:
    # Run as a script (not imported): load the config and start the server.
    command_line()