# Page-header residue from the hosting site (not part of the module): "Spaces: Sleeping"
import json | |
import os.path as osp | |
import numpy as np | |
from mmcv.transforms import BaseTransform | |
from mmcv.transforms import LoadImageFromFile | |
from mmdet.registry import DATASETS, TRANSFORMS | |
from mmdet.datasets.transforms import PackDetInputs | |
from mmdet.datasets.base_det_dataset import BaseDetDataset | |
import warnings | |
# --- Enhanced robust image loader for real images ---
@TRANSFORMS.register_module()
class RobustLoadImageFromFile(LoadImageFromFile):
    """Image loader that can substitute a blank image when loading fails.

    The transform first delegates to the parent ``LoadImageFromFile``.  If
    that raises, the failure is counted and (optionally) a black dummy image
    with the expected size is written into ``results`` instead.

    Args:
        try_real_images: When False the parent loader is skipped entirely and
            a dummy image is always produced.
        fallback_to_dummy: When False a loading failure is re-raised instead
            of being replaced by a dummy image.
        **kwargs: Forwarded unchanged to the parent loader's constructor.
    """

    # Number of failed loads so far.  It is always updated through this class
    # (not ``self``/``cls``), so subclasses share the same counter.
    missing_count = 0

    def __init__(self, try_real_images=True, fallback_to_dummy=True, **kwargs):
        super().__init__(**kwargs)
        self.try_real_images = try_real_images
        self.fallback_to_dummy = fallback_to_dummy

    def transform(self, results):
        """Load the real image, or fill ``results`` with a zero image.

        Args:
            results: Pipeline dict.  ``img_shape`` (or ``height``/``width``)
                is read to size the dummy image; ``img_path`` is only used in
                the warning text.

        Returns:
            Whatever the parent loader returns on success, otherwise the same
            ``results`` dict with ``img``, ``img_shape`` and ``ori_shape`` set.

        Raises:
            Exception: The parent loader's error, when loading fails and
                ``fallback_to_dummy`` is False.
        """
        if self.try_real_images:
            try:
                return super().transform(results)
            except Exception:
                # Deliberately broad: any loader failure counts as "missing".
                RobustLoadImageFromFile.missing_count += 1
                count = RobustLoadImageFromFile.missing_count
                # Warn on the 1st, 11th, 21st, ... failure to avoid log spam.
                if count % 10 == 1:
                    warnings.warn(
                        f"Missing image #{count}: {results.get('img_path', 'unknown')}. "
                        f"Total missing so far: {count}",
                        UserWarning)
                if not self.fallback_to_dummy:
                    # Bare raise keeps the original traceback intact.
                    raise

        # Dummy image: either requested explicitly or the real load failed.
        if 'img_shape' in results:
            h, w = results['img_shape'][:2]
        else:
            h = results.get('height', 800)
            w = results.get('width', 600)
        # Annotation files may store sizes as floats; np.zeros needs ints.
        h, w = int(h), int(w)

        results['img'] = np.zeros((h, w, 3), dtype=np.uint8)
        results['img_shape'] = (h, w, 3)
        results['ori_shape'] = (h, w, 3)
        return results

    @classmethod
    def get_missing_count(cls):
        """Return the number of images that failed to load so far."""
        return cls.missing_count

    @classmethod
    def reset_missing_count(cls):
        """Reset the failed-load counter to zero."""
        cls.missing_count = 0
# --- Legacy support for old transform name ---
@TRANSFORMS.register_module()
class CreateDummyImg(RobustLoadImageFromFile):
    """Legacy alias for RobustLoadImageFromFile.

    Adds no behaviour of its own; it exists so that pipelines written against
    the old transform name keep resolving to the robust loader.
    """
@TRANSFORMS.register_module()
class ClampBBoxes(BaseTransform):
    """Clamp ground-truth boxes to the image bounds without dropping any.

    Args:
        min_size: Stored on the instance for configuration compatibility;
            this transform does not read it.
    """

    def __init__(self, min_size=1):
        self.min_size = min_size

    def transform(self, results):
        """Clip ``results['gt_bboxes']`` in place to ``[0, w] x [0, h]``.

        Args:
            results: Pipeline dict.  ``img_shape`` supplies ``(h, w, ...)``.
                ``gt_bboxes`` is either an object exposing a ``tensor``
                attribute or an array indexed as ``[:, col]`` with columns
                x1, y1, x2, y2.

        Returns:
            The same ``results`` dict; it is returned untouched when it has no
            ``gt_bboxes`` entry.
        """
        if 'gt_bboxes' not in results:
            return results

        h, w = results['img_shape'][:2]
        gt_bboxes = results['gt_bboxes']
        # (column, upper bound) pairs for x1, y1, x2, y2.
        column_limits = ((0, w), (1, h), (2, w), (3, h))

        if hasattr(gt_bboxes, 'tensor'):
            # Box wrapper object: clamp the column views of its tensor in place.
            for col, upper in column_limits:
                gt_bboxes.tensor[:, col].clamp_(0, upper)
        elif len(gt_bboxes) > 0:
            # Plain array; skipped when empty so column indexing is never
            # attempted on an empty container.
            for col, upper in column_limits:
                gt_bboxes[:, col] = np.clip(gt_bboxes[:, col], 0, upper)

        # Nothing is removed here; dropping empty/invalid GT is left to the
        # dataset's filter configuration.
        results['gt_bboxes'] = gt_bboxes
        return results
@TRANSFORMS.register_module()
class SetScaleFactor(BaseTransform):
    """Derive ``scale_factor`` from ``data_series`` and ``plot_bb``.

    The factor relates data units to plot-box pixels: plot width divided by
    the x data range, and negated plot height divided by the y data range.

    Args:
        default_scale: ``(x_scale, y_scale)`` used for an axis whose scale
            cannot be computed.
    """

    def __init__(self, default_scale=(1.0, 1.0)):
        self.default_scale = default_scale

    def calculate_scale_factor(self, results):
        """Return ``(x_scale, y_scale)`` for one sample.

        Args:
            results: Pipeline dict; reads ``plot_bb`` (``width``/``height``)
                and ``data_series`` (list of dicts with a ``data`` list of
                points carrying ``x``/``y``).

        Returns:
            Tuple ``(x_scale, y_scale)``.  An axis falls back to its default
            when it has no numeric values, a zero data range, or a
            missing/zero plot-box size.
        """
        bb = results.get('plot_bb') or {}
        w, h = bb.get('width', 0), bb.get('height', 0)

        xs, ys = [], []
        for series in results.get('data_series') or []:
            for pt in series.get('data', []):
                x, y = pt.get('x'), pt.get('y')
                # Non-numeric values (e.g. category strings) are ignored.
                if isinstance(x, (int, float)):
                    xs.append(x)
                if isinstance(y, (int, float)):
                    ys.append(y)

        x_scale, y_scale = self.default_scale[0], self.default_scale[1]
        # A zero plot size would give a meaningless 0 scale, so keep the default.
        if w and xs and max(xs) != min(xs):
            x_scale = w / (max(xs) - min(xs))
        if h and ys and max(ys) != min(ys):
            # Negative sign as in the original formula.
            # NOTE(review): presumably because image y grows downward - confirm.
            y_scale = -h / (max(ys) - min(ys))
        return (x_scale, y_scale)

    def transform(self, results):
        """Write ``scale_factor`` and ``img_shape`` into ``results``.

        Any error while computing the factor falls back to ``default_scale``.
        ``img_shape`` is rebuilt from ``height``/``width``; a dimension that is
        missing keeps the value from an existing ``img_shape`` (or 0 if there
        is none).

        Returns:
            The same ``results`` dict.
        """
        try:
            sf = self.calculate_scale_factor(results)
            results['scale_factor'] = np.array(sf, dtype=np.float32)
        except Exception:
            # Best effort by design: malformed metadata must not break loading.
            results['scale_factor'] = np.array(self.default_scale, dtype=np.float32)

        current_shape = results.get('img_shape')
        if current_shape is None:
            current_shape = (0, 0)
        H = results.get('height', current_shape[0])
        W = results.get('width', current_shape[1])
        results['img_shape'] = (H, W, 3)
        return results
@TRANSFORMS.register_module()
class EnsureScaleFactor(BaseTransform):
    """Fallback: provide a unit ``scale_factor`` if none has been set yet."""

    def transform(self, results):
        """Set ``scale_factor`` to ``[1.0, 1.0]`` only when it is absent.

        An existing (non-None) ``scale_factor`` is left untouched so that this
        transform really acts as a fallback and does not discard a value
        computed earlier.

        Returns:
            The same ``results`` dict.
        """
        if results.get('scale_factor') is None:
            results['scale_factor'] = np.array([1.0, 1.0], dtype=np.float32)
        return results
@TRANSFORMS.register_module()
class SetInputs(BaseTransform):
    """Copy the (possibly dummy) image into ``inputs``."""

    def transform(self, results):
        """Store a copy of ``results['img']`` under ``results['inputs']``.

        Does nothing when ``results`` has no ``img`` entry.

        Returns:
            The same ``results`` dict.
        """
        if 'img' in results:
            # Copy so later in-place edits of ``img`` do not leak into ``inputs``.
            results['inputs'] = results['img'].copy()
        return results
@TRANSFORMS.register_module()
class CustomPackDetInputs(PackDetInputs):
    """Final packing step that makes sure ``inputs`` is present first."""

    def transform(self, results):
        """Copy ``img`` into ``inputs`` (when present), then pack as the parent does.

        Returns:
            The value returned by ``PackDetInputs.transform``.
        """
        if 'img' in results:
            results['inputs'] = results['img'].copy()
        return super().transform(results)
@DATASETS.register_module()
class ChartDataset(BaseDetDataset):
    """Chart-element detection dataset backed by a COCO-style JSON file.

    ``load_data_list`` reads the annotation file into per-image dicts, and
    ``parse_data_info`` turns each of those into ground-truth arrays
    (``gt_bboxes``, ``gt_bboxes_labels``), a data-to-pixel ``scale_factor`` and
    shape metadata, applying chart-type specific filtering of data elements.
    """

    # 21 chart-element categories; a label id is used as an index into this list.
    METAINFO = {
        'classes': [
            'title', 'subtitle', 'x-axis', 'y-axis', 'x-axis-label', 'y-axis-label',
            'x-tick-label', 'y-tick-label', 'legend', 'legend-title', 'legend-item',
            'data-point', 'data-line', 'data-bar', 'data-area', 'grid-line',
            'axis-title', 'tick-label', 'data-label', 'legend-text', 'plot-area'
        ]
    }

    # Chart-type specific element filtering based on the dataset distribution
    # (figures from analyze_chart_types.py):
    #   - line           (41.9%): 1710 images -> data-line only
    #   - scatter        (18.2%):  742 images -> data-point only
    #   - vertical_bar   (30.5%): 1246 images -> data-bar only
    #   - dot             (9.2%):  374 images -> data-point only
    #   - horizontal_bar  (0.2%):    9 images -> data-bar only
    CHART_TYPE_ELEMENT_MAPPING = {
        'line': {
            'allowed_data_elements': {'data-line'},
            'forbidden_data_elements': {'data-point', 'data-bar', 'data-area'}
        },
        'scatter': {
            'allowed_data_elements': {'data-point'},
            'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}
        },
        'vertical_bar': {
            'allowed_data_elements': {'data-bar'},
            'forbidden_data_elements': {'data-point', 'data-line', 'data-area'}
        },
        'dot': {
            'allowed_data_elements': {'data-point'},
            'forbidden_data_elements': {'data-line', 'data-bar', 'data-area'}
        },
        'horizontal_bar': {
            'allowed_data_elements': {'data-bar'},
            'forbidden_data_elements': {'data-point', 'data-line', 'data-area'}
        }
    }

    # File-name patterns probed, in priority order, when auto-detecting the
    # annotation file for a split.
    _ANN_FILE_PATTERNS = (
        '{split}_enriched_with_info.json',
        '{split}_enriched.json',
        '{split}_with_info.json',  # e.g. val_with_info.json
        '{split}.json',
        '{split}_cleaned.json',
    )

    # Optional per-annotation keys copied verbatim into each instance.
    _EXTRA_ANN_KEYS = ('text', 'role', 'data_point', 'chart_type', 'total_data_points')

    # Only the first few samples produce debug output.
    _MAX_DEBUG_SAMPLES = 3

    def __init__(self, *args, **kwargs):
        """Initialise the base dataset, then print the active configuration."""
        super().__init__(*args, **kwargs)
        # NOTE(review): if the base class's ``metainfo`` returns a copy, this
        # update has no lasting effect - confirm against the base class.
        self.metainfo.update(self.METAINFO)

        # Category overview.
        print(f"π ChartDataset initialized with {len(self.METAINFO['classes'])} categories:")
        for i, cls_name in enumerate(self.METAINFO['classes']):
            print(f" {i}: {cls_name}")

        # Chart-type filtering overview.
        print("π― Chart-type specific filtering enabled:")
        for chart_type, mapping in self.CHART_TYPE_ELEMENT_MAPPING.items():
            allowed = mapping.get('allowed_data_elements', set())
            forbidden = mapping.get('forbidden_data_elements', set())
            print(f" β’ {chart_type}: β {allowed} | π« {forbidden}")

        # Data locations.
        print("π Dataset configuration:")
        print(f" β’ data_root: {getattr(self, 'data_root', 'None')}")
        print(f" β’ data_prefix: {getattr(self, 'data_prefix', 'None')}")
        print(f" β’ ann_file: {getattr(self, 'ann_file', 'None')}")

    def _auto_detect_ann_file(self, split):
        """Return the best existing annotation file for ``split``.

        Looks in ``<data_root>/annotations_JSON`` for the patterns listed in
        ``_ANN_FILE_PATTERNS`` and returns the first one that exists.

        Raises:
            FileNotFoundError: If none of the candidate files exists.
        """
        ann_dir = osp.join(getattr(self, 'data_root', None) or '', 'annotations_JSON')
        for pattern in self._ANN_FILE_PATTERNS:
            candidate = pattern.format(split=split)
            full_path = osp.join(ann_dir, candidate)
            if osp.exists(full_path):
                print(f"π ChartDataset using {candidate}")
                return full_path
        raise FileNotFoundError(f"No annotation files found in {ann_dir}")

    def _resolve_ann_file(self):
        """Work out which annotation file to load.

        With an explicit ``ann_file`` the path ``data_root/ann_file`` is
        preferred.  If that does not exist but ``ann_file`` on its own does,
        the latter is used.
        NOTE(review): this presumably covers base classes that have already
        prefixed ``ann_file`` with ``data_root`` - confirm against the base
        dataset.  Without ``ann_file`` the ``train`` and then ``val`` splits
        are auto-detected.

        Raises:
            FileNotFoundError: If auto-detection finds nothing.
        """
        ann_file = getattr(self, 'ann_file', None)
        if ann_file:
            data_root = getattr(self, 'data_root', None) or ''
            joined = osp.join(data_root, ann_file)
            if not osp.exists(joined) and osp.exists(ann_file):
                return ann_file
            # Returned even when missing, so open() reports the joined path.
            return joined

        for split in ('train', 'val'):
            try:
                return self._auto_detect_ann_file(split)
            except FileNotFoundError:
                continue
        raise FileNotFoundError("Could not find any annotation files")

    @classmethod
    def _build_instance(cls, ann_record):
        """Convert one COCO annotation record into an instance dict.

        The ``[x, y, width, height]`` box is converted to ``[x1, y1, x2, y2]``.
        """
        x, y, box_w, box_h = ann_record['bbox'][:4]
        instance = {
            'bbox': [x, y, x + box_w, y + box_h],
            'bbox_label': ann_record['category_id'],
            'ignore_flag': 0,
            'annotation_id': ann_record.get('id', -1),
            'area': ann_record.get('area', box_w * box_h),
            'element_type': ann_record.get('element_type', 'unknown')
        }
        for key in cls._EXTRA_ANN_KEYS:
            if key in ann_record:
                instance[key] = ann_record[key]
        return instance

    def load_data_list(self):
        """Load the annotation file and build the raw per-image data list.

        Returns:
            List of dicts, one per image, holding ``img_id``, ``img_path``,
            ``height``, ``width``, ``instances`` and the enriched chart
            metadata found on the image record.

        Raises:
            FileNotFoundError: If no annotation file can be located or opened.
            KeyError: If the file has no ``images`` list or a record lacks a
                required key.
        """
        ann_file_path = self._resolve_ann_file()

        with open(ann_file_path, 'r', encoding='utf-8') as f:
            coco = json.load(f)
        print(f"π Loading from {ann_file_path}")
        print(f" β’ Images: {len(coco.get('images', []))}")
        print(f" β’ Annotations: {len(coco.get('annotations', []))}")

        # Keyed by id so duplicated image ids collapse to a single entry.
        images_by_id = {img['id']: img for img in coco['images']}

        anns_by_image = {}
        for record in coco.get('annotations', []):
            anns_by_image.setdefault(record['image_id'], []).append(record)

        # Either setting may be None; treat that like an empty mapping.
        filter_cfg = getattr(self, 'filter_cfg', None) or {}
        skip_empty = filter_cfg.get('filter_empty_gt', False)
        img_prefix = (getattr(self, 'data_prefix', None) or {}).get('img')

        data_list = []
        for img_id, img_info in images_by_id.items():
            records = anns_by_image.get(img_id, [])
            if not records and skip_empty:
                continue

            filename = img_info['file_name']
            # Prefix with the image directory when one is configured.
            img_path = osp.join(img_prefix, filename) if img_prefix else filename

            data_list.append({
                'img_id': img_info['id'],
                'img_path': img_path,
                'height': img_info['height'],
                'width': img_info['width'],
                'instances': [self._build_instance(r) for r in records],
                # Enriched metadata, with neutral defaults when absent.
                'chart_type': img_info.get('chart_type', ''),
                'plot_bb': img_info.get('plot_bb', {}),
                'data_series': img_info.get('data_series', []),
                'data_series_stats': img_info.get('data_series_stats', {}),
                'axes_info': img_info.get('axes_info', {}),
                'element_counts': img_info.get('element_counts', {}),
                'source': img_info.get('source', 'unknown')
            })

        print(f"β Loaded {len(data_list)} images with enhanced metadata")
        return data_list

    @staticmethod
    def _compute_scale_factor(data_series_stats, plot_bb):
        """Return the float32 ``[x_scale, y_scale]`` array for one image.

        Starts from ``[1.0, 1.0]`` and, when both the plot box size and a
        non-degenerate data range are available, replaces an axis by plot size
        divided by data range (negated for y).
        """
        scale_factor = np.array([1.0, 1.0], dtype=np.float32)
        has_box = plot_bb and all(k in plot_bb for k in ('width', 'height'))
        if data_series_stats and has_box:
            x_range = data_series_stats.get('x_range')
            y_range = data_series_stats.get('y_range')
            if x_range and len(x_range) == 2 and x_range[1] != x_range[0]:
                scale_factor[0] = plot_bb['width'] / (x_range[1] - x_range[0])
            if y_range and len(y_range) == 2 and y_range[1] != y_range[0]:
                scale_factor[1] = -plot_bb['height'] / (y_range[1] - y_range[0])
        return scale_factor

    def parse_data_info(self, raw_data_info):
        """Turn one raw data-list entry into a training-ready dict.

        Steps: drop data elements forbidden for the image's chart type, clamp
        boxes to the image, discard degenerate boxes, enlarge boxes below the
        (class-specific) minimum size, then attach ground-truth arrays, the
        scale factor and shape metadata.

        Args:
            raw_data_info: One dict produced by ``load_data_list``.

        Returns:
            A shallow copy of ``raw_data_info`` extended with ``gt_bboxes``,
            ``gt_bboxes_labels``, ``scale_factor``, shape/flip fields and an
            ``img_info`` summary dict.
        """
        d = raw_data_info.copy()

        # Per-instance sample counter; only the first few samples are verbose.
        self._debug_count = getattr(self, '_debug_count', 0) + 1
        verbose = self._debug_count <= self._MAX_DEBUG_SAMPLES

        if verbose:
            data_root = getattr(self, 'data_root', None)
            print(f"π Path verification debug #{self._debug_count}:")
            print(f" β’ img_path from load_data_list: {d['img_path']}")
            print(f" β’ data_root: {getattr(self, 'data_root', 'None')}")
            full_path = osp.join(data_root, d['img_path']) if data_root else d['img_path']
            print(f" β’ Full absolute path: {full_path}")
            print(f" β’ Path exists: {osp.exists(full_path)}")

        img_h, img_w = d['height'], d['width']
        class_names = self.METAINFO['classes']

        filter_cfg = getattr(self, 'filter_cfg', None) or {}
        min_size = filter_cfg.get('min_size', 1)
        class_specific_min_sizes = filter_cfg.get('class_specific_min_sizes', {})

        # A null/missing chart type disables chart-type filtering.
        chart_type = str(d.get('chart_type') or '').lower()
        chart_mapping = self.CHART_TYPE_ELEMENT_MAPPING.get(chart_type, {})
        allowed_data_elements = chart_mapping.get('allowed_data_elements', set())
        forbidden_data_elements = chart_mapping.get('forbidden_data_elements', set())

        bboxes, labels = [], []
        filtered_count = 0
        enlarged_count = 0
        chart_type_filtered_count = 0

        for inst in d.get('instances', []):
            label_id = inst['bbox_label']
            in_range = 0 <= label_id < len(class_names)
            class_name = class_names[label_id] if in_range else 'unknown'

            # Drop data elements that do not belong to this chart type.
            if chart_type and class_name in forbidden_data_elements:
                chart_type_filtered_count += 1
                if verbose and chart_type_filtered_count <= 3:
                    print(f" π« Filtered {class_name} from {chart_type} chart (inappropriate data element)")
                continue

            if verbose and chart_type and class_name in allowed_data_elements:
                print(f" β Keeping {class_name} for {chart_type} chart (appropriate data element)")

            # Clamp to the image; x2/y2 can never end up left of/above x1/y1.
            x1, y1, x2, y2 = inst['bbox']
            x1 = max(0, min(x1, img_w))
            y1 = max(0, min(y1, img_h))
            x2 = max(x1, min(x2, img_w))
            y2 = max(y1, min(y2, img_h))

            # Zero-area boxes carry no signal.
            if x2 <= x1 or y2 <= y1:
                filtered_count += 1
                continue

            bbox_w = x2 - x1
            bbox_h = y2 - y1
            required_min_size = class_specific_min_sizes.get(class_name, min_size)

            # Too-small boxes are grown symmetrically rather than discarded;
            # growth stops at the image border.
            if min(bbox_w, bbox_h) < required_min_size:
                expand_w = max(0, required_min_size - bbox_w) / 2
                expand_h = max(0, required_min_size - bbox_h) / 2
                x1 = max(0, x1 - expand_w)
                y1 = max(0, y1 - expand_h)
                x2 = min(img_w, x2 + expand_w)
                y2 = min(img_h, y2 + expand_h)
                enlarged_count += 1
                if verbose and enlarged_count <= 3:
                    print(f" π Enlarged {class_name} bbox: {bbox_w:.1f}x{bbox_h:.1f} β {(x2-x1):.1f}x{(y2-y1):.1f}")

            bboxes.append([x1, y1, x2, y2])
            labels.append(label_id)

        if verbose:
            print(f" π Bbox processing: {len(bboxes)} kept, {filtered_count} filtered (invalid), {chart_type_filtered_count} filtered (chart-type), {enlarged_count} enlarged")
            if chart_type:
                print(f" π Chart type: {chart_type} | Allowed data elements: {allowed_data_elements}")
                if forbidden_data_elements:
                    print(f" π« Forbidden data elements for {chart_type}: {forbidden_data_elements}")

        # Explicit empty shapes keep downstream code shape-safe.
        d['gt_bboxes'] = np.array(bboxes, dtype=np.float32) if bboxes else np.zeros((0, 4), dtype=np.float32)
        d['gt_bboxes_labels'] = np.array(labels, dtype=np.int64) if labels else np.zeros((0,), dtype=np.int64)

        # Null metadata is treated like missing metadata.
        data_series_stats = d.get('data_series_stats') or {}
        plot_bb = d.get('plot_bb') or {}
        d['scale_factor'] = self._compute_scale_factor(data_series_stats, plot_bb)

        # Shape/flip bookkeeping fields.
        d.update({
            'img_shape': (img_h, img_w, 3),
            'ori_shape': (img_h, img_w, 3),
            'pad_shape': (img_h, img_w, 3),
            'flip': False,
            'flip_direction': None,
            'img_fields': ['img'],
            'bbox_fields': ['bbox'],
        })

        # Summary dict with a private copy of the scale factor.
        d['img_info'] = {
            'height': img_h,
            'width': img_w,
            'img_shape': d['img_shape'],
            'ori_shape': d['ori_shape'],
            'pad_shape': d['pad_shape'],
            'scale_factor': d['scale_factor'].copy(),
            'flip': d['flip'],
            'flip_direction': d['flip_direction'],
            'chart_type': d.get('chart_type', ''),
            'num_data_points': data_series_stats.get('num_data_points', 0),
            'element_counts': d.get('element_counts', {})
        }
        return d
def print_missing_image_summary():
    """Print how many images could not be loaded and were replaced by dummies.

    Reads the ``missing_count`` class attribute directly, so it does not
    depend on ``get_missing_count`` being callable on the class itself.
    """
    count = RobustLoadImageFromFile.missing_count
    if count > 0:
        print(f"π MISSING IMAGES SUMMARY: {count} images were not found and replaced with dummy images")
    else:
        print("β All images loaded successfully!")
def print_dataset_summary():
    """Print a fixed, human-readable summary of the dataset plugin's features."""
    summary_lines = (
        "π ENHANCED CHART DATASET SUMMARY:",
        " β’ 21 categories supported for comprehensive chart element detection",
        " β’ Auto-detects best annotation files (enriched_with_info > enriched > regular)",
        " β’ Enhanced metadata: chart_type, data_series_stats, element_counts, axes_info",
        " β’ Robust image loading with fallback to dummy images",
        " β’ Multiple annotations per image (not just plot areas)",
    )
    # One print call per line, matching the original output line by line.
    for line in summary_lines:
        print(line)
# Import-time banner: announce the plugin, then print the feature summary.
print("β [PLUGIN] Enhanced ChartDataset + transforms registered!")
print_dataset_summary()