Spaces:
Configuration error
Configuration error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# | |
# Copyright (c) 2022 Intel Corporation | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
import click | |
import inspect | |
import os | |
import sys | |
from tlt.utils.types import FrameworkType | |
def eval(model_dir, model_name, dataset_dir, dataset_file, delimiter, class_names, dataset_name, dataset_catalog):
    """
    Evaluates a model that has already been trained.

    The framework is inferred from the files in ``model_dir``: a ``saved_model.pb`` file selects TensorFlow, otherwise
    a ``model.pt`` file selects PyTorch. The model object is created through the model factory, its weights are loaded
    from ``model_dir``, the dataset is loaded (from a catalog or from ``dataset_dir``), preprocessed with a batch size
    of 32, split with seed 10, and finally passed to ``model.evaluate``.

    Args:
        model_dir (str): Directory containing the trained model (``saved_model.pb`` or ``model.pt``).
        model_name (str): Name of the model for the model factory. If empty, the name of the parent directory of
            ``model_dir`` is used (e.g. ``/output/resnet_v1_50/1`` -> ``resnet_v1_50``).
        dataset_dir (str): Directory that the dataset is loaded from.
        dataset_file (str): Name of the .csv file in ``dataset_dir``. Required when loading a custom text
            classification dataset.
        delimiter (str): Delimiter forwarded to the dataset loader for the .csv file.
        class_names (str): Comma-separated list of class labels. Required when loading a custom text classification
            dataset. Whitespace around each label is ignored.
        dataset_name (str): Name of a dataset to get from ``dataset_catalog`` (optional).
        dataset_catalog (str): Dataset catalog to get ``dataset_name`` from (optional).

    Raises:
        SystemExit: If the model directory cannot be verified, no supported model file is found in it, or any error
            occurs while loading the model or dataset or while evaluating.
    """
    print("Model directory:", model_dir)
    print("Dataset directory:", dataset_dir)
    if dataset_file:
        print("Dataset file:", dataset_file)
    if class_names:
        # Split the comma-separated value and trim each label so "cat, dog" yields ["cat", "dog"], not ["cat", " dog"]
        class_names = [name.strip() for name in class_names.split(",") if name.strip()]
        print("Class names:", class_names)
    if dataset_name:
        print("Dataset name:", dataset_name)
    if dataset_catalog:
        print("Dataset catalog:", dataset_catalog)

    try:
        from tlt.utils.file_utils import verify_directory
        verify_directory(model_dir, require_directory_exists=True)
    except Exception as e:
        # sys.exit() takes a single argument, so the message has to be formatted before it is passed in
        sys.exit("Error while verifying the model directory: {}".format(str(e)))

    # Detect the framework from the saved model file; the TensorFlow file takes precedence if both exist
    saved_model_path = os.path.join(model_dir, "saved_model.pb")
    pytorch_model_path = os.path.join(model_dir, "model.pt")
    if os.path.isfile(saved_model_path):
        framework = FrameworkType.TENSORFLOW
        model_path = saved_model_path
    elif os.path.isfile(pytorch_model_path):
        framework = FrameworkType.PYTORCH
        model_path = pytorch_model_path
    else:
        sys.exit("Evaluation is currently only implemented for TensorFlow saved models and PyTorch .pt models. No such "
                 "files found in the model directory ({}).".format(model_dir))

    if not model_name:
        # Assumes a <output dir>/<model name>/<version> layout (TODO confirm against the train command).
        # normpath() drops a trailing separator so ".../resnet/1/" still resolves to "resnet" rather than "1".
        model_name = os.path.basename(os.path.dirname(os.path.normpath(model_dir)))

    print("Model name:", model_name)
    print("Framework:", framework)

    try:
        from tlt.models.model_factory import get_model

        print("Loading model object for {} using {}".format(model_name, str(framework)), flush=True)
        model = get_model(model_name, framework)

        print("Loading saved model from:", model_path)
        model.load_from_directory(model_dir)

        from tlt.datasets import dataset_factory

        if not dataset_catalog and not dataset_name:
            # No catalog dataset was requested, so load a custom dataset from dataset_dir
            if str(model.use_case) == 'text_classification':
                if not dataset_file:
                    raise ValueError("Loading a text classification dataset requires --dataset-file to specify the "
                                     "file name of the .csv file to load from the --dataset-dir.")
                if not class_names:
                    raise ValueError("Loading a text classification dataset requires --class-names to specify a list "
                                     "of the class labels for the dataset.")
                dataset = dataset_factory.load_dataset(dataset_dir, model.use_case, model.framework, dataset_name,
                                                       class_names=class_names, csv_file_name=dataset_file,
                                                       delimiter=delimiter)
            else:
                dataset = dataset_factory.load_dataset(dataset_dir, model.use_case, model.framework)
        else:
            dataset = dataset_factory.get_dataset(dataset_dir, model.use_case, model.framework, dataset_name,
                                                  dataset_catalog)

        # Only pass image_size when this dataset's preprocess() signature accepts it
        if 'image_size' in inspect.getfullargspec(dataset.preprocess).args:
            dataset.preprocess(image_size=model.image_size, batch_size=32)
        else:
            dataset.preprocess(batch_size=32)
        dataset.shuffle_split(seed=10)
        model.evaluate(dataset)
    except Exception as e:
        sys.exit("An error occurred during evaluation: {}".format(str(e)))