File size: 7,123 Bytes
06966eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import colorgram
import cv2
import numpy as np
from PIL import Image
import json
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import functools


class DesignTokenExtractor:
    """Extract design tokens (colors, spacing, typography, layout) from UI images.

    Colors come from colorgram, spacing from OpenCV contour analysis, and the
    component/layout description from an optional Pix2Struct model. The model
    is loaded once per process and shared by all instances; if loading fails
    the extractor keeps working with fallback values.
    """

    # Hugging Face checkpoint used by analyze_components.
    PIX2STRUCT_MODEL = "google/pix2struct-screen2words-base"

    # Spacing scale returned whenever no gaps can be measured. Always hand out
    # a copy so callers cannot mutate this shared dict.
    _DEFAULT_SPACING = {"small": "8px", "medium": "16px", "large": "32px"}

    def __init__(self):
        # Both stay None if loading fails; analyze_components then falls back.
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self._load_models()

    @staticmethod
    @functools.lru_cache(maxsize=1)
    def _load_pix2struct(model_name):
        """Load and return ``(processor, model)`` for *model_name*.

        Cached at class level (not per instance, unlike an lru_cache on a
        bound method, which keys on ``self``) so several extractors share one
        copy of the weights. lru_cache does not cache exceptions, so a failed
        load is retried on the next call.
        """
        processor = Pix2StructProcessor.from_pretrained(model_name)
        use_cuda = torch.cuda.is_available()
        model = Pix2StructForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        if use_cuda:
            # float16 weights are intended for the GPU; left on the CPU next
            # to float32 inputs, generate() fails with a dtype mismatch.
            model = model.to("cuda")
        return processor, model

    def _load_models(self):
        """Populate the Pix2Struct processor/model attributes (best effort)."""
        try:
            self.pix2struct_processor, self.pix2struct_model = self._load_pix2struct(
                self.PIX2STRUCT_MODEL
            )
        except Exception as e:
            print(f"Warning: Could not load Pix2Struct model: {e}")
            # Continue without the model for basic extraction

    @staticmethod
    def _color_role_names(proportions, background_threshold=0.3):
        """Return one semantic role name per extracted color, in order.

        The first color is called "background" only when its proportion
        exceeds *background_threshold*. The remaining colors are named
        "primary", "secondary", then "accent-1", "accent-2", ...

        Args:
            proportions: Sequence of color proportions, most dominant first.
            background_threshold: Minimum share for the first color to be
                treated as the background.
        """
        names = []
        if proportions and proportions[0] > background_threshold:
            names.append("background")
        for rank in range(len(proportions) - len(names)):
            if rank == 0:
                names.append("primary")
            elif rank == 1:
                names.append("secondary")
            else:
                names.append(f"accent-{rank - 1}")
        return names

    def extract_colors(self, image_path, num_colors=8):
        """Extract dominant colors with colorgram and assign semantic roles.

        Args:
            image_path: Passed straight to ``colorgram.extract``.
            num_colors: Maximum number of colors to extract.

        Returns:
            Dict mapping a role name ("background", "primary", "secondary",
            "accent-N") to ``{"hex", "rgb", "proportion"}``. Returns
            ``_get_default_colors()`` on error or when nothing was extracted.
        """
        try:
            colors = colorgram.extract(image_path, num_colors)
            if not colors:
                return self._get_default_colors()

            # NOTE(review): assumes colorgram returns colors ordered by
            # descending proportion, as the role assignment expects — confirm.
            names = self._color_role_names([color.proportion for color in colors])

            palette = {}
            for name, color in zip(names, colors):
                r, g, b = color.rgb.r, color.rgb.g, color.rgb.b
                palette[name] = {
                    "hex": f"#{r:02x}{g:02x}{b:02x}",
                    "rgb": f"rgb({r}, {g}, {b})",
                    "proportion": round(color.proportion, 3),
                }
            return palette
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()

    @staticmethod
    def _to_grayscale(image):
        """Return a single-channel array suitable for OpenCV.

        PIL images are converted to RGB first, so grayscale ("L"/"LA") and
        palette ("P") images no longer make ``cv2.cvtColor`` raise. Arrays
        are accepted as-is: 2-D arrays are returned unchanged, and 4-channel
        arrays are treated as RGBA.
        """
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        pixels = np.asarray(image)
        if pixels.ndim == 2:
            return pixels
        if pixels.shape[2] == 4:
            return cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
        return cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)

    def detect_spacing(self, image):
        """Estimate a vertical spacing scale from gaps between detected elements.

        Elements are the bounding boxes of Canny-edge contours with area > 100.
        Boxes are sorted by top edge. Each positive gap between one box's
        bottom and the next box's top is fed to ``_cluster_spacing_values``.

        Returns:
            A spacing dict (see ``_cluster_spacing_values``), or the default
            scale when fewer than two elements are found or on error.
        """
        try:
            gray = self._to_grayscale(image)
            edges = cv2.Canny(gray, 50, 150)

            # [-2] selects the contour list under both the OpenCV 3 (3-tuple)
            # and OpenCV 4 (2-tuple) return conventions.
            contours = cv2.findContours(
                edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )[-2]

            boxes = sorted(
                (cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100),
                key=lambda box: box[1],
            )

            if len(boxes) > 1:
                vertical_gaps = []
                for upper, lower in zip(boxes, boxes[1:]):
                    # boundingRect gives (x, y, w, h): next top minus this bottom.
                    gap = lower[1] - (upper[1] + upper[3])
                    if gap > 0:
                        vertical_gaps.append(gap)
                return self._cluster_spacing_values(vertical_gaps)
        except Exception as e:
            print(f"Error detecting spacing: {e}")

        return dict(self._DEFAULT_SPACING)

    def _cluster_spacing_values(self, gaps):
        """Map measured pixel gaps onto a small spacing scale.

        Args:
            gaps: Iterable of pixel gaps; it is not modified.

        Returns:
            With >= 3 distinct values: the smallest, the middle distinct
            value and the largest as "small"/"medium"/"large". With 2:
            "small"/"large". With 1: "base". With none: the default scale.
        """
        # Sort explicitly: set iteration order is not numeric order
        # (e.g. list({2, 8}) == [8, 2] in CPython).
        unique_gaps = sorted(set(gaps))

        if not unique_gaps:
            return dict(self._DEFAULT_SPACING)

        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        if len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        return {"base": f"{unique_gaps[0]}px"}

    def analyze_components(self, image):
        """Describe the image with Pix2Struct and guess the layout type.

        Returns:
            ``{"detected_elements": <generated text>, "layout": "responsive"
            if that word appears in the text else "fixed"}``. Returns a
            placeholder dict when the model is unavailable or generation fails.
        """
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            # Fallback if model loading failed
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive",
            }

        try:
            model = self.pix2struct_model
            raw_inputs = self.pix2struct_processor(images=image, return_tensors="pt")

            # Move inputs to the model's device and cast floating tensors to its
            # dtype (float16 on CUDA). Processor outputs are presumably float32
            # CPU tensors, which would otherwise clash with the weights.
            inputs = {}
            for key, value in raw_inputs.items():
                if isinstance(value, torch.Tensor):
                    if torch.is_floating_point(value):
                        value = value.to(device=model.device, dtype=model.dtype)
                    else:
                        value = value.to(model.device)
                inputs[key] = value

            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_length=100)

            description = self.pix2struct_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]

            return {
                "detected_elements": description,
                "layout": "responsive" if "responsive" in description.lower() else "fixed",
            }
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive",
            }

    def detect_typography(self, image):
        """Return a fixed typography scale (heading/body/caption).

        Placeholder: *image* is currently ignored; no text detection is done.
        """
        # Simplified typography detection without EasyOCR for initial implementation
        return {
            "heading": {
                "family": "sans-serif",
                "size": "32px",
                "weight": "700",
            },
            "body": {
                "family": "sans-serif",
                "size": "16px",
                "weight": "400",
            },
            "caption": {
                "family": "sans-serif",
                "size": "14px",
                "weight": "400",
            },
        }

    def _get_default_colors(self):
        """Return the fallback palette used when color extraction fails."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20},
        }

    def resize_for_processing(self, image, max_dimension=1024):
        """Downscale *image* so its longest side is at most *max_dimension*.

        The aspect ratio is preserved and LANCZOS resampling is used. Images
        already within the limit are returned unchanged (same object).
        """
        longest = max(image.size)
        if longest <= max_dimension:
            return image
        ratio = max_dimension / longest
        # round() instead of int(): truncation can land one pixel short
        # (e.g. 1023.999... -> 1023). max(1, ...) keeps extreme aspect ratios
        # from producing a degenerate 0-pixel side.
        new_size = tuple(max(1, round(dim * ratio)) for dim in image.size)
        return image.resize(new_size, Image.Resampling.LANCZOS)