Spaces:
Sleeping
Sleeping

Integrate DeepFashion2 dataset: add evaluation module, utilities, and API endpoints for dataset management and analysis
f8b306b
""" | |
DeepFashion2 Dataset Integration Utilities | |
Provides tools for loading, processing, and using the DeepFashion2 dataset | |
with the Vestiq fashion analysis system. | |
""" | |
import os | |
import json | |
import torch | |
import numpy as np | |
from PIL import Image | |
from torch.utils.data import Dataset, DataLoader | |
from pathlib import Path | |
from typing import Dict, List, Tuple, Optional, Union | |
import torchvision.transforms as transforms | |
from dataclasses import dataclass, field | |
import requests | |
import zipfile | |
import shutil | |
@dataclass
class DeepFashion2Config:
    """Configuration for the DeepFashion2 dataset.

    This must be a dataclass: ``field(default_factory=list)`` and
    ``__post_init__`` only take effect when the ``@dataclass`` decorator
    generates ``__init__``. Without it ``categories`` would be a raw
    ``dataclasses.Field`` object and keyword construction would fail.

    Attributes:
        dataset_root: Directory that holds the extracted dataset.
        download_url: URL recorded for the dataset archive.
        categories: Category identifiers; filled with the 13 defaults
            below when left empty.
        image_size: Size handed to the resize transform.
        batch_size: Batch size used when building a DataLoader.
        num_workers: Worker process count used when building a DataLoader.
    """
    dataset_root: str = "./data/deepfashion2"
    download_url: str = "https://github.com/switchablenorms/DeepFashion2/releases/download/v1.0/deepfashion2.zip"
    categories: List[str] = field(default_factory=list)
    image_size: Tuple[int, int] = (224, 224)
    batch_size: int = 32
    num_workers: int = 4

    def __post_init__(self):
        """Populate ``categories`` with the default list when none was given."""
        if not self.categories:
            # DeepFashion2 13 categories
            # NOTE(review): the official annotation files appear to use
            # names such as "short sleeve top" -- confirm these identifiers
            # match the `category_name` values found on disk.
            self.categories = [
                'short_sleeved_shirt', 'long_sleeved_shirt', 'short_sleeved_outwear',
                'long_sleeved_outwear', 'vest', 'sling', 'shorts', 'trousers',
                'skirt', 'short_sleeved_dress', 'long_sleeved_dress', 'vest_dress', 'sling_dress'
            ]
class DeepFashion2CategoryMapper:
    """Translate between DeepFashion2 category names and yainage90 model categories."""

    def __init__(self):
        """Build the forward lookup table and derive the reverse one from it."""
        # Forward table: DeepFashion2 name -> yainage90 name.
        forward_pairs = (
            ('short_sleeved_shirt', 'top'),
            ('long_sleeved_shirt', 'top'),
            ('short_sleeved_outwear', 'outer'),
            ('long_sleeved_outwear', 'outer'),
            ('vest', 'top'),
            ('sling', 'top'),
            ('shorts', 'bottom'),
            ('trousers', 'bottom'),
            ('skirt', 'bottom'),
            ('short_sleeved_dress', 'dress'),
            ('long_sleeved_dress', 'dress'),
            ('vest_dress', 'dress'),
            ('sling_dress', 'dress'),
        )
        self.df2_to_yainage90 = dict(forward_pairs)

        # Reverse table: yainage90 name -> list of DeepFashion2 names,
        # in the order they appear in the forward table.
        self.yainage90_to_df2 = {}
        for source_name, target_name in forward_pairs:
            self.yainage90_to_df2.setdefault(target_name, []).append(source_name)

    def map_to_yainage90(self, df2_category: str) -> str:
        """Return the yainage90 category for a DeepFashion2 name, or 'unknown'."""
        try:
            return self.df2_to_yainage90[df2_category]
        except KeyError:
            return 'unknown'

    def map_from_yainage90(self, yainage_category: str) -> List[str]:
        """Return the DeepFashion2 names grouped under a yainage90 category.

        An unrecognised category yields an empty list.
        """
        if yainage_category in self.yainage90_to_df2:
            return self.yainage90_to_df2[yainage_category]
        return []
class DeepFashion2Dataset(Dataset):
    """PyTorch Dataset over one split of a DeepFashion2 directory tree.

    Images are read from ``<root_dir>/<split>/image/*.jpg``. When
    ``load_annotations`` is set, the annotation for an image is read from
    ``<root_dir>/<split>/annos/<image stem>.json`` if that file exists.
    """

    def __init__(self,
                 root_dir: str,
                 split: str = 'train',
                 # Quoted so the annotation is not evaluated at definition time.
                 transform: "Optional[transforms.Compose]" = None,
                 load_annotations: bool = True):
        """
        Initialize DeepFashion2 dataset

        Args:
            root_dir: Root directory of DeepFashion2 dataset
            split: Dataset split ('train', 'validation', 'test')
            transform: Image transformations applied to each loaded image
            load_annotations: Whether to load the per-image annotation JSON
        """
        self.root_dir = Path(root_dir)
        self.split = split
        self.transform = transform
        self.load_annotations = load_annotations
        self.category_mapper = DeepFashion2CategoryMapper()

        # Load dataset metadata
        self.images_dir = self.root_dir / split / "image"
        self.annos_dir = self.root_dir / split / "annos"

        # Get all image files. glob() yields entries in filesystem order,
        # so sort to keep the index -> image mapping reproducible.
        self.image_files = []
        if self.images_dir.exists():
            self.image_files = sorted(self.images_dir.glob("*.jpg"))
        print(f"Found {len(self.image_files)} images in {split} split")

    def __len__(self):
        """Return the number of image files discovered for this split."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Dict with keys 'image' (RGB image, transformed when a transform
            was supplied), 'image_path', 'image_name' (file stem) and
            'annotations' (parsed JSON, or None when annotations are
            disabled or the file is missing).
        """
        image_path = self.image_files[idx]
        image_name = image_path.stem

        # Load image. convert() returns a fully loaded copy, so the
        # underlying file handle can be closed straight away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        # Load annotations if requested
        annotations = None
        if self.load_annotations:
            anno_path = self.annos_dir / f"{image_name}.json"
            if anno_path.exists():
                with open(anno_path, 'r', encoding='utf-8') as f:
                    annotations = json.load(f)

        # Apply transforms
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'image_path': str(image_path),
            'image_name': image_name,
            'annotations': annotations
        }

    def get_categories_in_image(self, annotations: Dict) -> List[str]:
        """Extract the distinct category names from one annotation dict.

        Two layouts are accepted:

        * a nested ``{'item': {<id>: {...}}}`` mapping, and
        * top-level ``item1``, ``item2``, ... keys, which is how the
          published DeepFashion2 annotation files are laid out
          (NOTE(review): confirm against the files on disk).

        Returns:
            Category names in first-seen order without duplicates; an empty
            list when there is nothing to extract.
        """
        if not annotations:
            return []

        nested_items = annotations.get('item')
        if isinstance(nested_items, dict):
            item_records = list(nested_items.values())
        else:
            item_records = [
                value for key, value in annotations.items()
                if isinstance(key, str) and key.startswith('item')
            ]

        categories = [
            record['category_name']
            for record in item_records
            if isinstance(record, dict) and 'category_name' in record
        ]
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(categories))
class DeepFashion2Downloader:
    """Download and setup DeepFashion2 dataset.

    No automatic download is performed: ``download_dataset`` prints manual
    setup instructions, and ``verify_dataset`` checks the expected
    directory layout under the configured dataset root.
    """

    def __init__(self, config: DeepFashion2Config):
        """Store the config and resolve its dataset root as a ``Path``."""
        self.config = config
        self.dataset_root = Path(config.dataset_root)

    def download_dataset(self, force_download: bool = False) -> bool:
        """
        Report whether the dataset is present, printing setup steps if not.

        Args:
            force_download: Skip the "already exists" shortcut and print the
                manual instructions even when the dataset root is present

        Returns:
            True if the dataset root already exists (and force_download is
            False), False after printing the manual setup instructions
        """
        if self.dataset_root.exists() and not force_download:
            print(f"Dataset already exists at {self.dataset_root}")
            return True

        print("DeepFashion2 dataset download requires manual setup.")
        print("Please follow these steps:")
        print("1. Visit: https://github.com/switchablenorms/DeepFashion2")
        print("2. Follow the dataset download instructions")
        print("3. Extract the dataset to:", self.dataset_root)
        print("4. Ensure the directory structure is:")
        # Box-drawing characters restored; they had been mis-decoded into
        # unreadable text.
        print("   deepfashion2/")
        print("   ├── train/")
        print("   │   ├── image/")
        print("   │   └── annos/")
        print("   ├── validation/")
        print("   │   ├── image/")
        print("   │   └── annos/")
        print("   └── test/")
        print("       ├── image/")
        print("       └── annos/")
        return False

    def verify_dataset(self) -> bool:
        """Verify dataset structure.

        Only the train and validation splits are required; the test split
        is not checked.

        Returns:
            True when every required directory exists, False at the first
            missing one (its path is printed).
        """
        required_dirs = [
            self.dataset_root / "train" / "image",
            self.dataset_root / "train" / "annos",
            self.dataset_root / "validation" / "image",
            self.dataset_root / "validation" / "annos"
        ]
        for dir_path in required_dirs:
            # is_dir() rather than exists(): a plain file with the same
            # name does not satisfy the layout.
            if not dir_path.is_dir():
                print(f"Missing required directory: {dir_path}")
                return False
        print("Dataset structure verified successfully")
        return True
def create_deepfashion2_transforms(image_size: Tuple[int, int] = (224, 224)) -> transforms.Compose:
    """Build the default preprocessing pipeline for DeepFashion2 images.

    The pipeline resizes to ``image_size``, converts to a tensor, and
    normalizes each channel with fixed mean/std values (these match the
    commonly used ImageNet statistics).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def _collate_deepfashion2_batch(samples):
    """Merge DeepFashion2 samples into a batch without collating annotations.

    The default collate function raises on ``None`` annotations and on
    annotation dicts whose keys or list lengths differ between images.
    Images are stacked when they are all tensors; every other field is
    returned as a plain per-sample list.
    """
    images = [sample['image'] for sample in samples]
    if images and all(isinstance(image, torch.Tensor) for image in images):
        images = torch.stack(images)
    return {
        'image': images,
        'image_path': [sample['image_path'] for sample in samples],
        'image_name': [sample['image_name'] for sample in samples],
        'annotations': [sample['annotations'] for sample in samples],
    }


def create_deepfashion2_dataloader(config: DeepFashion2Config,
                                   split: str = 'train',
                                   shuffle: bool = True) -> DataLoader:
    """Create DataLoader for DeepFashion2 dataset.

    Args:
        config: Supplies dataset_root, image_size, batch_size and num_workers.
        split: Dataset split to load ('train', 'validation', 'test').
        shuffle: Whether the loader shuffles samples.

    Returns:
        DataLoader yielding dicts with a stacked 'image' tensor and
        per-sample lists for 'image_path', 'image_name' and 'annotations'
        (annotation entries may be None).
    """
    transform = create_deepfashion2_transforms(config.image_size)
    dataset = DeepFashion2Dataset(
        root_dir=config.dataset_root,
        split=split,
        transform=transform,
        load_annotations=True
    )
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        num_workers=config.num_workers,
        pin_memory=torch.cuda.is_available(),
        # Module-level function so it stays picklable for worker processes.
        collate_fn=_collate_deepfashion2_batch
    )
def get_deepfashion2_statistics(config: DeepFashion2Config) -> Dict:
    """Get statistics about the DeepFashion2 dataset.

    Each of the 'train', 'validation' and 'test' splits is opened and a
    strided sample of its items is inspected for category names.

    Args:
        config: Supplies dataset_root and the category list to count.

    Returns:
        Dict with keys:
            'splits': per-split dict with 'num_images' and
                'categories_found' (sorted), or {'error': message} when
                processing that split raised.
            'total_images': sum of 'num_images' over the splits that
                were processed successfully.
            'categories': config.categories.
            'category_counts': per-category count of *sampled* images
                containing that category (not a full-dataset count).
    """
    stats = {
        'splits': {},
        'total_images': 0,
        'categories': config.categories,
        'category_counts': {cat: 0 for cat in config.categories}
    }
    for split in ['train', 'validation', 'test']:
        # Best-effort per split: a failure is recorded and the remaining
        # splits are still processed.
        try:
            dataset = DeepFashion2Dataset(
                root_dir=config.dataset_root,
                split=split,
                transform=None,
                load_annotations=True
            )
            num_images = len(dataset)
            categories_found = set()

            # Sample a few images to get category statistics. An empty
            # split is skipped: computing the stride for it would divide
            # by zero and misreport the split as an error.
            if num_images:
                sample_size = min(100, num_images)
                stride = max(1, num_images // sample_size)
                for i in range(0, num_images, stride):
                    item = dataset[i]
                    if item['annotations']:
                        categories = dataset.get_categories_in_image(item['annotations'])
                        categories_found.update(categories)
                        for cat in categories:
                            if cat in stats['category_counts']:
                                stats['category_counts'][cat] += 1

            stats['splits'][split] = {
                'num_images': num_images,
                # Sorted for a deterministic result; key=str tolerates
                # non-string names coming from the JSON.
                'categories_found': sorted(categories_found, key=str)
            }
            stats['total_images'] += num_images
        except Exception as e:
            print(f"Error processing {split} split: {e}")
            stats['splits'][split] = {'error': str(e)}
    return stats