Spaces:
Sleeping
Sleeping
File size: 7,123 Bytes
06966eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import colorgram
import cv2
import numpy as np
from PIL import Image
import json
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import functools
class DesignTokenExtractor:
    """Extract design tokens (colors, spacing, typography, components) from screenshots.

    Colors come from ``colorgram``, spacing from OpenCV edge/contour analysis,
    and component descriptions from the Pix2Struct ``screen2words`` checkpoint
    when it could be loaded. Each extractor falls back to fixed default tokens
    instead of raising.
    """

    # Hugging Face checkpoint used by analyze_components().
    MODEL_NAME = "google/pix2struct-screen2words-base"
    # The most common color is labelled "background" only when it covers more
    # than this fraction of the extracted colors.
    BACKGROUND_MIN_PROPORTION = 0.3
    # Spacing scale used when no gaps can be measured. Callers always receive a
    # copy so the shared default cannot be mutated.
    _DEFAULT_SPACING = {"small": "8px", "medium": "16px", "large": "32px"}

    def __init__(self):
        # Populated by _load_models(); both stay None if loading fails.
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self._load_models()

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _load_pix2struct(model_name):
        """Load and return ``(processor, model)`` for *model_name*.

        Cached per model name at class level, so every extractor instance
        shares one copy of the weights and instances are not pinned in memory
        by the cache. lru_cache does not cache exceptions, so a failed load is
        retried by the next caller.
        """
        processor = Pix2StructProcessor.from_pretrained(model_name)
        use_cuda = torch.cuda.is_available()
        model = Pix2StructForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        if use_cuda:
            # float16 weights are intended for the GPU; left on the CPU they
            # do not match the float32 inputs produced by the processor.
            model = model.to("cuda")
        return processor, model

    def _load_models(self):
        """Attach the shared Pix2Struct processor and model, or leave both None on failure."""
        try:
            processor, model = self._load_pix2struct(self.MODEL_NAME)
        except Exception as e:
            print(f"Warning: Could not load Pix2Struct model: {e}")
            # Continue without the model for basic extraction
            return
        self.pix2struct_processor = processor
        self.pix2struct_model = model

    def extract_colors(self, image_path, num_colors=8):
        """Extract up to *num_colors* dominant colors and name them by role.

        Args:
            image_path: Image source passed straight to ``colorgram.extract``
                (e.g. a file path).
            num_colors: Maximum number of colors to extract.

        Returns:
            Dict mapping a role name ("background", "primary", "secondary",
            "accent-N") to ``{"hex", "rgb", "proportion"}``. On any error the
            palette from ``_get_default_colors()`` is returned instead.
        """
        try:
            colors = colorgram.extract(image_path, num_colors)
            roles = self._color_roles([color.proportion for color in colors])
            palette = {}
            for name, color in zip(roles, colors):
                r, g, b = color.rgb.r, color.rgb.g, color.rgb.b
                palette[name] = {
                    "hex": f"#{r:02x}{g:02x}{b:02x}",
                    "rgb": f"rgb({r}, {g}, {b})",
                    "proportion": round(color.proportion, 3),
                }
            return palette
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()

    def _color_roles(self, proportions):
        """Return one role name per color, given proportions ordered most common first.

        The first color is "background" when it covers more than
        BACKGROUND_MIN_PROPORTION. The remaining colors are named "primary",
        "secondary", then "accent-1", "accent-2", ... by rank.
        """
        roles = []
        remaining = len(proportions)
        if proportions and proportions[0] > self.BACKGROUND_MIN_PROPORTION:
            roles.append("background")
            remaining -= 1
        # Rank among non-background colors, so every color gets a well-formed
        # name even when no color is dominant enough to be the background.
        for rank in range(remaining):
            if rank == 0:
                roles.append("primary")
            elif rank == 1:
                roles.append("secondary")
            else:
                roles.append(f"accent-{rank - 1}")
        return roles

    def detect_spacing(self, image):
        """Estimate a spacing scale from vertical gaps between detected elements.

        Args:
            image: Image convertible with ``np.array`` (e.g. a PIL image); RGB,
                RGBA or single-channel.

        Returns:
            Spacing dict as produced by ``_cluster_spacing_values``, or the
            default scale when detection fails or finds fewer than two elements.
        """
        try:
            pixels = np.array(image)
            # Single-channel images are already grayscale; cvtColor would reject them.
            if pixels.ndim == 2:
                gray = pixels
            else:
                gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            # Outer contours approximate the individual UI elements.
            contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Ignore tiny contours (noise, text fragments).
            boxes = [cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100]
        except Exception as e:
            print(f"Error detecting spacing: {e}")
            return dict(self._DEFAULT_SPACING)
        if len(boxes) < 2:
            # No pair of elements to measure a gap between.
            return dict(self._DEFAULT_SPACING)
        return self._cluster_spacing_values(self._vertical_gaps(boxes))

    def _vertical_gaps(self, boxes):
        """Return positive vertical gaps between consecutive ``(x, y, w, h)`` boxes sorted by top edge."""
        ordered = sorted(boxes, key=lambda box: box[1])
        gaps = []
        for (_, top, _, height), (_, next_top, _, _) in zip(ordered, ordered[1:]):
            gap = next_top - (top + height)
            # Overlapping or side-by-side boxes give gap <= 0; skip them.
            if gap > 0:
                gaps.append(gap)
        return gaps

    def _cluster_spacing_values(self, gaps):
        """Reduce measured gaps (pixels) to a small spacing scale.

        Does not modify *gaps*. Returns small/medium/large for three or more
        distinct values, small/large for two, "base" for one, and the default
        scale for none.
        """
        if not gaps:
            return dict(self._DEFAULT_SPACING)
        # set() does not preserve order, so sort explicitly to keep small <= large.
        unique_gaps = sorted(set(gaps))
        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        if len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        return {"base": f"{unique_gaps[0]}px"}

    def analyze_components(self, image):
        """Describe the screen's components with Pix2Struct.

        Returns:
            ``{"detected_elements": <generated text>, "layout": "responsive" |
            "fixed"}``. The layout is a keyword heuristic on the generated text.
            Placeholder values are returned when the model is unavailable or
            generation fails.
        """
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            # Fallback if model loading failed
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive"
            }
        try:
            model = self.pix2struct_model
            encoded = self.pix2struct_processor(images=image, return_tensors="pt")
            # Match the model's device and floating dtype (float16 on GPU);
            # integer tensors only move device.
            inputs = {
                key: (
                    value.to(device=model.device, dtype=model.dtype)
                    if value.is_floating_point()
                    else value.to(model.device)
                )
                for key, value in encoded.items()
            }
            with torch.no_grad():
                generated_ids = model.generate(**inputs, max_length=100)
            description = self.pix2struct_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive"
            }
        return {
            "detected_elements": description,
            "layout": "responsive" if "responsive" in description.lower() else "fixed"
        }

    def detect_typography(self, image):
        """Return a fixed heading/body/caption type scale.

        Placeholder implementation: *image* is currently not inspected.
        """
        # Simplified typography detection without EasyOCR for initial implementation
        return {
            "heading": {
                "family": "sans-serif",
                "size": "32px",
                "weight": "700"
            },
            "body": {
                "family": "sans-serif",
                "size": "16px",
                "weight": "400"
            },
            "caption": {
                "family": "sans-serif",
                "size": "14px",
                "weight": "400"
            }
        }

    def _get_default_colors(self):
        """Return the fallback palette used when color extraction fails."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20}
        }

    def resize_for_processing(self, image, max_dimension=1024):
        """Downscale *image* so its longest side is at most *max_dimension*.

        Aspect ratio is preserved. An image that is already small enough is
        returned unchanged (the same object).
        """
        longest = max(image.size)
        if longest <= max_dimension:
            return image
        ratio = max_dimension / longest
        # Clamp to 1px: extreme aspect ratios would otherwise truncate the
        # short side to 0, which PIL rejects.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        return image.resize(new_size, Image.Resampling.LANCZOS)