import colorgram
import cv2
import numpy as np
from PIL import Image
import json
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor


class DesignTokenExtractor:
    """Extracts design tokens (colors, spacing, typography, components) from UI screenshots."""

    def __init__(self):
        # Load models once at startup so repeated extractions reuse them
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self._load_models()
    def _load_models(self):
        """Load models once so repeated calls don't re-initialize them."""
        try:
            self.pix2struct_processor = Pix2StructProcessor.from_pretrained(
                "google/pix2struct-screen2words-base"
            )
            self.pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
                "google/pix2struct-screen2words-base",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load Pix2Struct model: {e}")
            # Continue without the model; color/spacing extraction still works
    def extract_colors(self, image_path, num_colors=8):
        """Extract dominant colors using colorgram and assign semantic roles."""
        try:
            colors = colorgram.extract(image_path, num_colors)
            palette = {}
            # Assign roles in order of dominance; the first color counts as
            # the background only when it covers a large share of the image
            roles = ["primary", "secondary"]
            accent_index = 1
            for i, color in enumerate(colors):
                if i == 0 and color.proportion > 0.3:
                    name = "background"
                elif roles:
                    name = roles.pop(0)
                else:
                    name = f"accent-{accent_index}"
                    accent_index += 1
                palette[name] = {
                    "hex": f"#{color.rgb.r:02x}{color.rgb.g:02x}{color.rgb.b:02x}",
                    "rgb": f"rgb({color.rgb.r}, {color.rgb.g}, {color.rgb.b})",
                    "proportion": round(color.proportion, 3),
                }
            return palette
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()
    def detect_spacing(self, image):
        """Analyze vertical spacing patterns using OpenCV edge detection."""
        default = {"small": "8px", "medium": "16px", "large": "32px"}
        try:
            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            # Find contours for element detection
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Keep only contours large enough to be real UI elements
            bounding_boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100]
            if len(bounding_boxes) < 2:
                return default  # Too few elements to measure gaps between
            # Sort by y-coordinate to measure vertical gaps between elements
            bounding_boxes.sort(key=lambda box: box[1])
            vertical_gaps = []
            for i in range(len(bounding_boxes) - 1):
                gap = bounding_boxes[i + 1][1] - (bounding_boxes[i][1] + bounding_boxes[i][3])
                if gap > 0:
                    vertical_gaps.append(gap)
            # Group similar gap values into a small/medium/large scale
            return self._cluster_spacing_values(vertical_gaps)
        except Exception as e:
            print(f"Error detecting spacing: {e}")
            return default
    def _cluster_spacing_values(self, gaps):
        """Group similar spacing values into a small/medium/large scale."""
        if not gaps:
            return {"small": "8px", "medium": "16px", "large": "32px"}
        # De-duplicate and sort; set() discards order, so sort its result
        unique_gaps = sorted(set(gaps))
        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        if len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        return {"base": f"{unique_gaps[0]}px"}
    def analyze_components(self, image):
        """Use Pix2Struct to generate a high-level description of the UI."""
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            # Fallback if model loading failed
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive",
            }
        try:
            inputs = self.pix2struct_processor(images=image, return_tensors="pt")
            # Match the model's device and dtype (float16 on GPU, float32 on CPU)
            inputs = {
                k: v.to(self.device, self.pix2struct_model.dtype)
                if torch.is_floating_point(v) else v.to(self.device)
                for k, v in inputs.items()
            }
            with torch.no_grad():
                generated_ids = self.pix2struct_model.generate(**inputs, max_length=100)
            description = self.pix2struct_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
            # Parse the generated description for layout hints
            return {
                "detected_elements": description,
                "layout": "responsive" if "responsive" in description.lower() else "fixed",
            }
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive",
            }
    def detect_typography(self, image):
        """Basic typography detection."""
        # Simplified placeholder values; OCR-based detection (e.g. EasyOCR)
        # is deferred for the initial implementation
        return {
            "heading": {
                "family": "sans-serif",
                "size": "32px",
                "weight": "700",
            },
            "body": {
                "family": "sans-serif",
                "size": "16px",
                "weight": "400",
            },
            "caption": {
                "family": "sans-serif",
                "size": "14px",
                "weight": "400",
            },
        }
    def _get_default_colors(self):
        """Return a fallback color palette."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20},
        }
    def resize_for_processing(self, image, max_dimension=1024):
        """Resize large images while maintaining aspect ratio."""
        if max(image.size) > max_dimension:
            ratio = max_dimension / max(image.size)
            new_size = tuple(int(dim * ratio) for dim in image.size)
            return image.resize(new_size, Image.Resampling.LANCZOS)
        return image
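

# --- Usage sketch ---
# Assumptions: a local screenshot saved as "screenshot.png"; the file name and
# the token-dictionary layout below are illustrative, not part of the class.
if __name__ == "__main__":
    extractor = DesignTokenExtractor()
    image = Image.open("screenshot.png").convert("RGB")
    image = extractor.resize_for_processing(image)
    tokens = {
        "colors": extractor.extract_colors("screenshot.png"),
        "spacing": extractor.detect_spacing(image),
        "typography": extractor.detect_typography(image),
        "components": extractor.analyze_components(image),
    }
    print(json.dumps(tokens, indent=2))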