import colorgram
import cv2
import numpy as np
from PIL import Image
import json
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor


class DesignTokenExtractor:
    def __init__(self):
        # Load models once at startup
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._load_models()

    def _load_models(self):
        """Load models once; repeated calls are no-ops."""
        if self.pix2struct_model is not None:
            return
        try:
            self.pix2struct_processor = Pix2StructProcessor.from_pretrained(
                "google/pix2struct-screen2words-base"
            )
            self.pix2struct_model = Pix2StructForConditionalGeneration.from_pretrained(
                "google/pix2struct-screen2words-base",
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            ).to(self.device)
        except Exception as e:
            print(f"Warning: Could not load Pix2Struct model: {e}")
            # Continue without the model for basic extraction

    def extract_colors(self, image_path, num_colors=8):
        """Extract dominant colors using colorgram."""
        try:
            colors = colorgram.extract(image_path, num_colors)
            palette = {}
            for i, color in enumerate(colors):
                # Assign a semantic role based on the color's rank and coverage
                if i == 0 and color.proportion > 0.3:
                    name = "background"
                elif i == 1:
                    name = "primary"
                elif i == 2:
                    name = "secondary"
                else:
                    name = f"accent-{i - 2}"
                palette[name] = {
                    "hex": f"#{color.rgb.r:02x}{color.rgb.g:02x}{color.rgb.b:02x}",
                    "rgb": f"rgb({color.rgb.r}, {color.rgb.g}, {color.rgb.b})",
                    "proportion": round(color.proportion, 3),
                }
            return palette
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()

    def detect_spacing(self, image):
        """Analyze vertical spacing patterns using OpenCV edge detection."""
        try:
            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)

            # Find contours for element detection
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

            # Bounding boxes of detected elements, ignoring tiny contours (noise)
            bounding_boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100]

            if len(bounding_boxes) > 1:
                # Sort by y-coordinate to measure vertical gaps between stacked elements
                bounding_boxes.sort(key=lambda box: box[1])
                vertical_gaps = []
                for i in range(len(bounding_boxes) - 1):
                    # Gap = next element's top edge minus current element's bottom edge
                    gap = bounding_boxes[i + 1][1] - (bounding_boxes[i][1] + bounding_boxes[i][3])
                    if gap > 0:
                        vertical_gaps.append(gap)

                # Group similar gap sizes into a small/medium/large spacing system
                return self._cluster_spacing_values(vertical_gaps)

            # Fewer than two elements detected: fall back to defaults
            return {"small": "8px", "medium": "16px", "large": "32px"}
        except Exception as e:
            print(f"Error detecting spacing: {e}")
            return {"small": "8px", "medium": "16px", "large": "32px"}  # Defaults

    def _cluster_spacing_values(self, gaps):
        """Group similar spacing values."""
        if not gaps:
            return {"small": "8px", "medium": "16px", "large": "32px"}

        # sorted(set(...)) deduplicates AND guarantees ascending order;
        # list(set(...)) alone would not preserve ordering
        unique_gaps = sorted(set(gaps))
        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        elif len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        # Exactly one unique gap; the empty case was handled above
        return {"base": f"{unique_gaps[0]}px"}

    def analyze_components(self, image):
        """Use Pix2Struct for component understanding."""
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            # Fallback if model loading failed
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive",
            }
        try:
            inputs = self.pix2struct_processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                generated_ids = self.pix2struct_model.generate(**inputs, max_length=100)
            description = self.pix2struct_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]

            # Parse the generated description for layout hints
            return {
                "detected_elements": description,
                "layout": "responsive" if "responsive" in description.lower() else "fixed",
            }
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive",
            }

    def detect_typography(self, image):
        """Basic typography detection."""
        # Simplified static defaults; OCR-based detection (e.g. EasyOCR) is
        # deferred to a later iteration
        return {
            "heading": {"family": "sans-serif", "size": "32px", "weight": "700"},
            "body": {"family": "sans-serif", "size": "16px", "weight": "400"},
            "caption": {"family": "sans-serif", "size": "14px", "weight": "400"},
        }

    def _get_default_colors(self):
        """Return default color palette."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20},
        }

    def resize_for_processing(self, image, max_dimension=1024):
        """Resize large images while maintaining aspect ratio."""
        if max(image.size) > max_dimension:
            ratio = max_dimension / max(image.size)
            new_size = tuple(int(dim * ratio) for dim in image.size)
            return image.resize(new_size, Image.Resampling.LANCZOS)
        return image
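

# --- Usage sketch (assumption: "screenshot.png" is a hypothetical placeholder
# path, not part of the class above). Shows the typical pipeline: downscale the
# source image, run each extractor, and assemble the results into a single
# design-token dictionary.
if __name__ == "__main__":
    extractor = DesignTokenExtractor()

    # Load the screenshot and shrink it before analysis to bound processing cost
    image = Image.open("screenshot.png").convert("RGB")
    image = extractor.resize_for_processing(image)

    tokens = {
        "colors": extractor.extract_colors("screenshot.png"),
        "spacing": extractor.detect_spacing(image),
        "typography": extractor.detect_typography(image),
        "components": extractor.analyze_components(image),
    }
    print(json.dumps(tokens, indent=2))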