ParamDev's picture
Upload folder using huggingface_hub
a01ef8c verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#
import click
import inspect
import os
import sys
from tlt.utils.types import FrameworkType
@click.command()
@click.option("--model-dir", "--model_dir",
required=True,
type=str,
help="Model directory to reload and evaluate a previously exported model.")
@click.option("--model-name", "--model_name",
required=False,
type=str,
help="Name of the model to evaluate. If a model name is not provided, the CLI will try to get the model "
"name from the model directory path. For example, if the model directory is /tmp/efficientnet_b0/10,"
" it will use 'efficientnet_b0' as the model name.")
@click.option("--dataset-dir", "--dataset_dir",
required=True,
type=str,
help="Dataset directory for a custom dataset, or if a dataset name "
"and catalog are being provided, the dataset directory is the "
"location where the dataset will be downloaded.")
@click.option("--dataset-file", "--dataset_file",
required=False,
type=str,
help="Name of a file in the dataset directory to load. Used for loading a .csv file for text "
"classification evaluation.")
@click.option("--delimiter",
required=False,
type=str,
default=",",
help="Delimiter used when loading a dataset from a csv file. [default: ,]")
@click.option("--class-names", "--class_names",
required=False,
type=str,
help="Comma separated string of class names for a text classification dataset being loaded from .csv")
@click.option("--dataset-name", "--dataset_name",
required=False,
type=str,
help="Name of the dataset to use from a dataset catalog.")
@click.option("--dataset-catalog", "--dataset_catalog",
required=False,
type=click.Choice(['tf_datasets', 'torchvision', 'huggingface']),
help="Name of a dataset catalog for a named dataset (Options: tf_datasets, torchvision, huggingface). "
"If a dataset name is provided and no dataset catalog is given, it will default to use "
"tf_datasets for a TensorFlow model, torchvision for PyTorch CV models, and huggingface datasets "
"for HuggingFace models.")
def eval(model_dir, model_name, dataset_dir, dataset_file, delimiter, class_names, dataset_name, dataset_catalog):
"""
Evaluates a model that has already been trained
"""
print("Model directory:", model_dir)
print("Dataset directory:", dataset_dir)
if dataset_file:
print("Dataset file:", dataset_file)
if class_names:
class_names = class_names.split(",")
print("Class names:", class_names)
if dataset_name:
print("Dataset name:", dataset_name)
if dataset_catalog:
print("Dataset catalog:", dataset_catalog)
try:
from tlt.utils.file_utils import verify_directory
verify_directory(model_dir, require_directory_exists=True)
except Exception as e:
sys.exit("Error while verifying the model directory: {}", str(e))
saved_model_path = os.path.join(model_dir, "saved_model.pb")
pytorch_model_path = os.path.join(model_dir, "model.pt")
if os.path.isfile(saved_model_path):
framework = FrameworkType.TENSORFLOW
model_path = saved_model_path
elif os.path.isfile(pytorch_model_path):
framework = FrameworkType.PYTORCH
model_path = pytorch_model_path
else:
sys.exit("Evaluation is currently only implemented for TensorFlow saved models and PyTorch .pt models. No such "
"files found in the model directory ({}).".format(model_dir))
if not model_name:
model_name = os.path.basename(os.path.dirname(model_dir))
print("Model name:", model_name)
print("Framework:", framework)
try:
from tlt.models.model_factory import get_model
print("Loading model object for {} using {}".format(model_name, str(framework)), flush=True)
model = get_model(model_name, framework)
print("Loading saved model from:", model_path)
model.load_from_directory(model_dir)
from tlt.datasets import dataset_factory
if not dataset_catalog and not dataset_name:
if str(model.use_case) == 'text_classification':
if not dataset_file:
raise ValueError("Loading a text classification dataset requires --dataset-file to specify the "
"file name of the .csv file to load from the --dataset-dir.")
if not class_names:
raise ValueError("Loading a text classification dataset requires --class-names to specify a list "
"of the class labels for the dataset.")
dataset = dataset_factory.load_dataset(dataset_dir, model.use_case, model.framework, dataset_name,
class_names=class_names, csv_file_name=dataset_file,
delimiter=delimiter)
else:
dataset = dataset_factory.load_dataset(dataset_dir, model.use_case, model.framework)
else:
dataset = dataset_factory.get_dataset(dataset_dir, model.use_case, model.framework, dataset_name,
dataset_catalog)
if 'image_size' in inspect.getfullargspec(dataset.preprocess).args:
dataset.preprocess(image_size=model.image_size, batch_size=32)
else:
dataset.preprocess(batch_size=32)
dataset.shuffle_split(seed=10)
model.evaluate(dataset)
except Exception as e:
sys.exit("An error occurred during evaluation: {}".format(str(e)))