|
|
|
import io |
|
import json |
|
import logging |
|
import os |
|
from urllib.parse import urlparse |
|
|
|
import boto3 |
|
from botocore.exceptions import ClientError |
|
from label_studio_ml.model import LabelStudioMLBase |
|
from label_studio_ml.utils import (DATA_UNDEFINED_NAME, get_image_size, |
|
get_single_tag_keys) |
|
from label_studio_tools.core.utils.io import get_data_dir |
|
|
|
from mmdet.apis import inference_detector, init_detector |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class MMDetection(LabelStudioMLBase): |
|
"""Object detector based on https://github.com/open-mmlab/mmdetection.""" |
|
|
|
def __init__(self, |
|
config_file=None, |
|
checkpoint_file=None, |
|
image_dir=None, |
|
labels_file=None, |
|
score_threshold=0.5, |
|
device='cpu', |
|
**kwargs): |
|
|
|
super(MMDetection, self).__init__(**kwargs) |
|
config_file = config_file or os.environ['config_file'] |
|
checkpoint_file = checkpoint_file or os.environ['checkpoint_file'] |
|
self.config_file = config_file |
|
self.checkpoint_file = checkpoint_file |
|
self.labels_file = labels_file |
|
|
|
upload_dir = os.path.join(get_data_dir(), 'media', 'upload') |
|
self.image_dir = image_dir or upload_dir |
|
logger.debug( |
|
f'{self.__class__.__name__} reads images from {self.image_dir}') |
|
if self.labels_file and os.path.exists(self.labels_file): |
|
self.label_map = json_load(self.labels_file) |
|
else: |
|
self.label_map = {} |
|
|
|
self.from_name, self.to_name, self.value, self.labels_in_config = get_single_tag_keys( |
|
self.parsed_label_config, 'RectangleLabels', 'Image') |
|
schema = list(self.parsed_label_config.values())[0] |
|
self.labels_in_config = set(self.labels_in_config) |
|
|
|
|
|
self.labels_attrs = schema.get('labels_attrs') |
|
if self.labels_attrs: |
|
for label_name, label_attrs in self.labels_attrs.items(): |
|
for predicted_value in label_attrs.get('predicted_values', |
|
'').split(','): |
|
self.label_map[predicted_value] = label_name |
|
|
|
print('Load new model from: ', config_file, checkpoint_file) |
|
self.model = init_detector(config_file, checkpoint_file, device=device) |
|
self.score_thresh = score_threshold |
|
|
|
def _get_image_url(self, task): |
|
image_url = task['data'].get( |
|
self.value) or task['data'].get(DATA_UNDEFINED_NAME) |
|
if image_url.startswith('s3://'): |
|
|
|
r = urlparse(image_url, allow_fragments=False) |
|
bucket_name = r.netloc |
|
key = r.path.lstrip('/') |
|
client = boto3.client('s3') |
|
try: |
|
image_url = client.generate_presigned_url( |
|
ClientMethod='get_object', |
|
Params={ |
|
'Bucket': bucket_name, |
|
'Key': key |
|
}) |
|
except ClientError as exc: |
|
logger.warning( |
|
f'Can\'t generate presigned URL for {image_url}. Reason: {exc}' |
|
) |
|
return image_url |
|
|
|
def predict(self, tasks, **kwargs): |
|
assert len(tasks) == 1 |
|
task = tasks[0] |
|
image_url = self._get_image_url(task) |
|
image_path = self.get_local_path(image_url) |
|
model_results = inference_detector(self.model, |
|
image_path).pred_instances |
|
results = [] |
|
all_scores = [] |
|
img_width, img_height = get_image_size(image_path) |
|
print(f'>>> model_results: {model_results}') |
|
print(f'>>> label_map {self.label_map}') |
|
print(f'>>> self.model.dataset_meta: {self.model.dataset_meta}') |
|
classes = self.model.dataset_meta.get('classes') |
|
print(f'Classes >>> {classes}') |
|
for item in model_results: |
|
print(f'item >>>>> {item}') |
|
bboxes, label, scores = item['bboxes'], item['labels'], item[ |
|
'scores'] |
|
score = float(scores[-1]) |
|
if score < self.score_thresh: |
|
continue |
|
print(f'bboxes >>>>> {bboxes}') |
|
print(f'label >>>>> {label}') |
|
output_label = classes[list(self.label_map.get(label, label))[0]] |
|
print(f'>>> output_label: {output_label}') |
|
if output_label not in self.labels_in_config: |
|
print(output_label + ' label not found in project config.') |
|
continue |
|
|
|
for bbox in bboxes: |
|
bbox = list(bbox) |
|
if not bbox: |
|
continue |
|
|
|
x, y, xmax, ymax = bbox[:4] |
|
results.append({ |
|
'from_name': self.from_name, |
|
'to_name': self.to_name, |
|
'type': 'rectanglelabels', |
|
'value': { |
|
'rectanglelabels': [output_label], |
|
'x': float(x) / img_width * 100, |
|
'y': float(y) / img_height * 100, |
|
'width': (float(xmax) - float(x)) / img_width * 100, |
|
'height': (float(ymax) - float(y)) / img_height * 100 |
|
}, |
|
'score': score |
|
}) |
|
all_scores.append(score) |
|
avg_score = sum(all_scores) / max(len(all_scores), 1) |
|
print(f'>>> RESULTS: {results}') |
|
return [{'result': results, 'score': avg_score}] |
|
|
|
|
|
def json_load(file, int_keys=False): |
|
with io.open(file, encoding='utf8') as f: |
|
data = json.load(f) |
|
if int_keys: |
|
return {int(k): v for k, v in data.items()} |
|
else: |
|
return data |
|
|