# Author: nextussocial
# Commit 06966eb: Implement complete Design Token Extractor system
import colorgram
import cv2
import numpy as np
from PIL import Image
import json
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import functools
class DesignTokenExtractor:
    """Extract design tokens (colors, spacing, typography, components) from UI images.

    Colors come from ``colorgram``, spacing from OpenCV edge/contour analysis, and a
    screen description from the Pix2Struct screen2words model when it can be loaded.
    Each extractor falls back to default tokens instead of raising.
    """

    # Hugging Face checkpoint used for screen-level component descriptions.
    PIX2STRUCT_MODEL_NAME = "google/pix2struct-screen2words-base"

    # Fallback spacing scale; always copied before being returned so callers
    # can mutate the result without corrupting the shared default.
    _DEFAULT_SPACING = {"small": "8px", "medium": "16px", "large": "32px"}

    def __init__(self):
        # Populated by _load_models(); both stay None if loading fails, in which
        # case analyze_components() returns a placeholder result.
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self._load_models()

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _load_pix2struct(model_name):
        """Load and return ``(processor, model)`` for *model_name*.

        Cached at class level so all extractor instances share one copy of the
        weights. ``lru_cache`` does not cache raised exceptions, so a failed load
        is retried the next time an extractor is constructed.
        """
        processor = Pix2StructProcessor.from_pretrained(model_name)
        use_cuda = torch.cuda.is_available()
        model = Pix2StructForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        # float16 weights are only chosen for CUDA, so the model must actually
        # live on the GPU; left on the CPU it would not run against the
        # processor's float32 inputs.
        model.to("cuda" if use_cuda else "cpu")
        return processor, model

    def _load_models(self):
        """Attach the shared Pix2Struct processor and model to this instance.

        On failure a warning is printed and both attributes are left unchanged,
        so the colour/spacing/typography extractors keep working without the model.
        """
        try:
            processor, model = self._load_pix2struct(self.PIX2STRUCT_MODEL_NAME)
        except Exception as e:
            print(f"Warning: Could not load Pix2Struct model: {e}")
            # Continue without the model for basic extraction
            return
        self.pix2struct_processor = processor
        self.pix2struct_model = model

    def extract_colors(self, image_path, num_colors=8):
        """Extract a palette of dominant colors, keyed by semantic role.

        Args:
            image_path: Image source, passed straight to ``colorgram.extract``.
            num_colors: Maximum number of colors to extract.

        Returns:
            Dict mapping role name ("background", "primary", "secondary",
            "accent-N") to {"hex", "rgb", "proportion"}; the default palette if
            extraction fails for any reason.
        """
        try:
            colors = colorgram.extract(image_path, num_colors)
            return self._build_palette(colors)
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()

    @staticmethod
    def _build_palette(colors):
        """Name colorgram-style colors by role, most prominent first.

        Assumes *colors* is ordered most dominant first. The first color is the
        "background" only when it covers more than 30% of the image. The
        remaining colors are "primary", "secondary", then "accent-1",
        "accent-2", and so on.
        """
        has_background = bool(colors) and colors[0].proportion > 0.3
        offset = 1 if has_background else 0
        palette = {}
        for i, color in enumerate(colors):
            rank = i - offset  # position among the non-background colors
            if rank < 0:
                name = "background"
            elif rank == 0:
                name = "primary"
            elif rank == 1:
                name = "secondary"
            else:
                name = f"accent-{rank - 1}"
            r, g, b = color.rgb.r, color.rgb.g, color.rgb.b
            palette[name] = {
                "hex": f"#{r:02x}{g:02x}{b:02x}",
                "rgb": f"rgb({r}, {g}, {b})",
                "proportion": round(color.proportion, 3),
            }
        return palette

    def detect_spacing(self, image):
        """Estimate a spacing scale from vertical whitespace between elements.

        Canny edges are traced into external contours. Contours with area > 100
        are treated as elements, and the vertical gaps between them are reduced
        to a small/medium/large scale. Returns the default scale on any error.
        """
        try:
            # PIL images may be RGBA, L or P; normalize so RGB->gray is valid.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            contours, _ = cv2.findContours(
                edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            # Drop tiny contours (noise) before treating boxes as elements.
            boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100]
            # With fewer than two elements there are no gaps, and the helper
            # returns the default scale instead of falling through to None.
            return self._cluster_spacing_values(self._vertical_gaps(boxes))
        except Exception as e:
            print(f"Error detecting spacing: {e}")
            return dict(self._DEFAULT_SPACING)  # Defaults

    @staticmethod
    def _vertical_gaps(boxes):
        """Return the positive vertical gaps between (x, y, w, h) boxes, top to bottom.

        Each gap is measured from the lowest bottom edge seen so far. A short box
        beside a tall one therefore does not create a gap inside the tall box.
        """
        gaps = []
        lowest_bottom = None
        for _x, y, _w, h in sorted(boxes, key=lambda box: box[1]):
            if lowest_bottom is not None and y > lowest_bottom:
                gaps.append(y - lowest_bottom)
            bottom = y + h
            lowest_bottom = bottom if lowest_bottom is None else max(lowest_bottom, bottom)
        return gaps

    def _cluster_spacing_values(self, gaps):
        """Reduce raw pixel gaps to a spacing scale, ordered by size.

        Returns small/medium/large for three or more distinct values, small/large
        for two, "base" for one, and the default scale when *gaps* is empty.
        The input list is not modified.
        """
        if not gaps:
            return dict(self._DEFAULT_SPACING)
        # sorted(): a set has no defined iteration order, so list(set(...))
        # cannot be indexed by size.
        unique_gaps = sorted(set(gaps))
        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        if len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        return {"base": f"{unique_gaps[0]}px"}

    def analyze_components(self, image):
        """Describe the screen with the Pix2Struct screen2words model.

        Returns:
            {"detected_elements": caption or status text, "layout": str}.
            "layout" is "responsive" only when the generated caption contains that
            word, otherwise "fixed". Placeholder values are returned when the
            model is unavailable or generation fails.
        """
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            # Fallback if model loading failed
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive",
            }
        try:
            model = self.pix2struct_model
            raw_inputs = self.pix2struct_processor(images=image, return_tensors="pt")
            # Move tensors to the model's device. Floating tensors are also cast
            # to its dtype (float16 on CUDA); other tensors keep their dtype.
            inputs = {}
            for key, value in raw_inputs.items():
                if torch.is_tensor(value):
                    if torch.is_floating_point(value):
                        value = value.to(device=model.device, dtype=model.dtype)
                    else:
                        value = value.to(model.device)
                inputs[key] = value
            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_length=100)
            description = self.pix2struct_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]
            return {
                "detected_elements": description,
                "layout": "responsive" if "responsive" in description.lower() else "fixed",
            }
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive",
            }

    def detect_typography(self, image):
        """Return a fixed placeholder type scale.

        *image* is currently unused; no text or font analysis is performed yet.
        """
        return {
            "heading": {"family": "sans-serif", "size": "32px", "weight": "700"},
            "body": {"family": "sans-serif", "size": "16px", "weight": "400"},
            "caption": {"family": "sans-serif", "size": "14px", "weight": "400"},
        }

    def _get_default_colors(self):
        """Return the fallback palette used when color extraction fails."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20},
        }

    def resize_for_processing(self, image, max_dimension=1024):
        """Downscale *image* so its longer side is at most *max_dimension* pixels.

        The aspect ratio is preserved. Images already within the limit are
        returned unchanged.
        """
        if max(image.size) <= max_dimension:
            return image
        new_size = self._scaled_size(image.size, max_dimension)
        return image.resize(new_size, Image.Resampling.LANCZOS)

    @staticmethod
    def _scaled_size(size, max_dimension):
        """Scale a (width, height) pair so its longer side equals *max_dimension*."""
        ratio = max_dimension / max(size)
        # round() guards against float error truncating a side by one pixel;
        # max(1, ...) stops extreme aspect ratios from producing a 0-px side.
        return tuple(max(1, round(dim * ratio)) for dim in size)