Spaces:
Configuration error
Configuration error
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
# Standard library.
import argparse
import os
import sys
from os import path

# Third-party.
import yaml

# Put the directory three levels above this file's folder at the front of the
# module search path (presumably where ``vision_wl`` lives -- verify against
# the repository layout), then echo the resulting search path to stdout.
root_folder = path.dirname(path.abspath(__file__))
sys.path.insert(0, path.join(root_folder, "../../../"))
print(sys.path)

# Must come after the sys.path tweak above.
from vision_wl import (
    train_vision_wl,
    run_inference,
    collect_class_labels,
    load_model,
    run_inference_per_patient,
)
def main():
    """Run the vision workflow stages selected by a YAML configuration file.

    Command line:
        --config_file PATH  (required) YAML file that is parsed with
                            ``yaml.safe_load``.

    Configuration keys read:
        args.dataset_dir            root folder; ``train`` and ``test``
                                    sub-folders are derived from it
        args.model                  model name passed to the workload helpers
        args.finetune               truthy -> train, then run inference on
                                    the training set
        args.inference              truthy -> run inference on the test set
        args.inference_per_patient  truthy -> run the per-patient sample
        args.saved_model_dir        model location used for inference
        args.finetune_output        output target for training-set inference
        args.inference_output       output target for test-set inference
        training_args.output_dir    passed to ``train_vision_wl``
        training_args.batch_size    passed to ``train_vision_wl``
        training_args.epochs        passed to ``train_vision_wl``
        training_args.bf16          passed to ``train_vision_wl`` as ``bf16``

    Raises:
        OSError: if the configuration file cannot be opened.
        KeyError: if one of the keys listed above is missing (assuming the
            file parses to nested mappings).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_file", type=str, required=True)
    args = parser.parse_args()

    # Read the file as UTF-8 explicitly instead of relying on the platform's
    # locale encoding, so the same config parses identically everywhere.
    with open(args.config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    dataset_dir = config['args']['dataset_dir']
    train_dataset_dir = os.path.join(dataset_dir, "train")
    test_dataset_dir = os.path.join(dataset_dir, "test")

    output_dir = config['training_args']['output_dir']
    batch_size = config['training_args']['batch_size']
    epochs = config['training_args']['epochs']
    bf16 = config['training_args']['bf16']
    model_name = config['args']['model']

    # Stage switches.
    do_predict = config['args']['inference']
    do_predict_per_patient = config['args']['inference_per_patient']
    do_train = config['args']['finetune']

    saved_model_dir = config['args']['saved_model_dir']
    output_file_test = config['args']['inference_output']
    output_file_train = config['args']['finetune_output']

    # Placeholder value: the INT8 inference config name is hard-coded here
    # rather than read from the configuration file.
    vision_int8_inference = 'vision_int_8.yaml'

    # Labels are collected from the training folder and shared by every stage.
    class_labels = collect_class_labels(train_dataset_dir)

    if do_train:
        # The values returned by training are not used by this script.
        train_vision_wl(train_dataset_dir,
                        output_dir, model_name,
                        batch_size, epochs, bf16=bf16)
        # NOTE(review): training is given ``output_dir`` while the inference
        # below reads ``saved_model_dir`` -- confirm both point at the same
        # model location.
        run_inference(train_dataset_dir, saved_model_dir, class_labels,
                      model_name, vision_int8_inference, output_file_train)

    if do_predict:
        run_inference(test_dataset_dir, saved_model_dir, class_labels,
                      model_name, vision_int8_inference, output_file_test)

    if do_predict_per_patient:
        model = load_model(model_name, saved_model_dir)
        # Sample mapping of patient/side id -> list of image paths.
        patient_dict = {
            '106L': [
                os.path.join(train_dataset_dir, "Malignant/P106_L_CM_MLO1.jpg"),
            ],
            '106R': [
                os.path.join(train_dataset_dir, "Benign/P106_R_CM_CC1.jpg"),
                os.path.join(train_dataset_dir, "Benign/P106_R_CM_CC2.jpg"),
            ],
        }
        # The return value is not consumed by this script.
        run_inference_per_patient(model, patient_dict, class_labels)
# Run the workflow only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()