File size: 1,345 Bytes
230c9a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import yaml
import warnings
from pdf_extract_kit.registry.registry import TASK_REGISTRY, MODEL_REGISTRY


def load_config(config_path):
    """Load and parse a YAML configuration file.

    Args:
        config_path: Path to the YAML file, or None.

    Returns:
        The document parsed by ``yaml.safe_load`` (None if the file is
        empty), or None after emitting a warning when ``config_path`` is None.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        yaml.YAMLError: If the file content is not valid YAML.
    """
    if config_path is None:
        # stacklevel=2 attributes the warning to the caller that passed None,
        # not to this helper.
        warnings.warn(
            "Configuration path is None. Please provide a valid configuration file path. ",
            stacklevel=2,
        )
        return None

    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    return config


# def initialize_task_and_model(config):
#     task_name = config['task']
#     model_name = config['model']
#     model_config = config['model_config']

#     TaskClass = TASK_REGISTRY.get(task_name)
#     ModelClass = MODEL_REGISTRY.get(model_name)

#     model_instance = ModelClass(model_config)
#     task_instance = TaskClass(model_instance)

#     return task_instance

def initialize_tasks_and_models(config):
    """Instantiate every task listed in ``config['tasks']`` with its model.

    For each entry, the task name is looked up in ``TASK_REGISTRY`` and the
    entry's ``'model'`` value in ``MODEL_REGISTRY``. The model class is
    constructed with the entry's ``'model_config'``, and the task class is
    constructed with that model instance.

    Args:
        config: Mapping with a ``'tasks'`` key whose value maps task names to
            mappings containing ``'model'`` and ``'model_config'`` (for
            example the result of ``load_config``).

    Returns:
        dict mapping each task name to its constructed task instance, in the
        iteration order of ``config['tasks']``. An empty or null ``tasks``
        section yields an empty dict.

    Raises:
        TypeError: If ``config`` is None.
        KeyError: If ``'tasks'``, ``'model'`` or ``'model_config'`` is missing.
        ValueError: If a registry lookup returns None for a task or model name.
    """
    if config is None:
        raise TypeError(
            "config is None; pass the mapping returned by load_config() "
            "for a valid configuration file."
        )

    tasks = config['tasks']
    if tasks is None:
        # A YAML "tasks:" key with no entries parses as None; treat it as
        # "no tasks" rather than failing on iteration.
        tasks = {}

    task_instances = {}
    for task_name, task_cfg in tasks.items():
        model_name = task_cfg['model']
        model_config = task_cfg['model_config']

        TaskClass = TASK_REGISTRY.get(task_name)
        ModelClass = MODEL_REGISTRY.get(model_name)
        # The registry's behavior for unknown names is defined elsewhere. If
        # a lookup returns None, fail now with the offending name, before
        # constructing the model, instead of with an opaque
        # "'NoneType' object is not callable".
        if TaskClass is None:
            raise ValueError(
                f"Task '{task_name}' is not registered in TASK_REGISTRY."
            )
        if ModelClass is None:
            raise ValueError(
                f"Model '{model_name}' (for task '{task_name}') is not "
                f"registered in MODEL_REGISTRY."
            )

        model_instance = ModelClass(model_config)
        task_instances[task_name] = TaskClass(model_instance)

    return task_instances