rootglitch committed
Commit accc4b7 · verified · 1 Parent(s): 5c0beea

Update app.py

Files changed (1)
  1. app.py +1 -684
app.py CHANGED
@@ -1,684 +1 @@
- import os
- import shutil
- import sys
- import warnings
- import random
- import time
- import logging
- import fal_client
- import base64
- import numpy as np
- import math
- import scipy
- import requests
- import torch
- import torchvision
- import gradio as gr
- import argparse
- import spaces
- from PIL import Image, ImageFilter, ImageOps, ImageDraw, ImageFont
- from io import BytesIO
- from typing import Dict, List, Tuple, Union, Optional
-
- os.system("python -m pip install -e sam-hq")
- os.system("python -m pip install -e GroundingDINO")
- os.system("pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel")
- os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
- os.system("wget https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth")
- sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
- sys.path.append(os.path.join(os.getcwd(), "sam-hq"))
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-     handlers=[logging.StreamHandler()]
- )
- logger = logging.getLogger(__name__)
-
- # Grounding DINO
- import GroundingDINO.groundingdino.datasets.transforms as T
- from GroundingDINO.groundingdino.models import build_model
- from GroundingDINO.groundingdino.util.slconfig import SLConfig
- from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
-
- # segment anything
- from segment_anything import build_sam_vit_l, SamPredictor
-
- # Constants
- CONFIG_FILE = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
- GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth"
- SAM_CHECKPOINT = 'sam_hq_vit_l.pth'
- OUTPUT_DIR = "outputs"
-
- # Global variables for model caching
- _models = {
-     'groundingdino': None,
-     'sam_predictor': None
- }
-
- # Enable GPU if available with proper error handling
- try:
-     device = 'cuda' if torch.cuda.is_available() else 'cpu'
-     logger.info(f"Using device: {device}")
- except Exception as e:
-     logger.warning(f"Error detecting GPU, falling back to CPU: {e}")
-     device = 'cpu'
-
-
- class ModelManager:
-     """Manages model loading, unloading, and provides error handling"""
-
-     @staticmethod
-     def load_model(model_name: str) -> None:
-         """Load a model if not already loaded"""
-         try:
-             if model_name == 'groundingdino' and _models['groundingdino'] is None:
-                 logger.info("Loading GroundingDINO model...")
-                 start_time = time.time()
-
-                 if not os.path.exists(GROUNDINGDINO_CHECKPOINT):
-                     raise FileNotFoundError(f"GroundingDINO checkpoint not found at {GROUNDINGDINO_CHECKPOINT}")
-
-                 args = SLConfig.fromfile(CONFIG_FILE)
-                 args.device = device
-                 model = build_model(args)
-                 checkpoint = torch.load(GROUNDINGDINO_CHECKPOINT, map_location="cpu")
-                 load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
-                 logger.info(f"GroundingDINO load result: {load_res}")
-                 _ = model.eval()
-                 _models['groundingdino'] = model
-
-                 logger.info(f"GroundingDINO model loaded in {time.time() - start_time:.2f} seconds")
-
-             elif model_name == 'sam' and _models['sam_predictor'] is None:
-                 logger.info("Loading SAM-HQ model...")
-                 start_time = time.time()
-
-                 if not os.path.exists(SAM_CHECKPOINT):
-                     raise FileNotFoundError(f"SAM checkpoint not found at {SAM_CHECKPOINT}")
-
-                 sam = build_sam_vit_l(checkpoint=SAM_CHECKPOINT)
-                 sam.to(device=device)
-                 _models['sam_predictor'] = SamPredictor(sam)
-
-                 logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds")
-
-         except Exception as e:
-             logger.error(f"Error loading {model_name} model: {e}")
-             raise RuntimeError(f"Failed to load {model_name} model: {e}")
-
-     @staticmethod
-     def get_model(model_name: str):
-         """Get a model, loading it if necessary"""
-         if model_name not in _models or _models[model_name] is None:
-             ModelManager.load_model(model_name)
-         return _models[model_name]
-
-     @staticmethod
-     def unload_model(model_name: str) -> None:
-         """Unload a model to free memory"""
-         if model_name in _models and _models[model_name] is not None:
-             logger.info(f"Unloading {model_name} model")
-             _models[model_name] = None
-             if device == 'cuda':
-                 torch.cuda.empty_cache()
-
-
- def transform_image(image_pil: Image.Image) -> torch.Tensor:
-     """Transform PIL image for GroundingDINO"""
-     transform = T.Compose([
-         T.RandomResize([800], max_size=1333),
-         T.ToTensor(),
-         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-     ])
-     image, _ = transform(image_pil, None)  # 3, h, w
-     return image
-
-
- def get_grounding_output(
-     image: torch.Tensor,
-     caption: str,
-     box_threshold: float,
-     text_threshold: float,
-     with_logits: bool = True
- ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
-     """Run GroundingDINO to get bounding boxes from text prompt"""
-     try:
-         model = ModelManager.get_model('groundingdino')
-
-         # Format caption
-         caption = caption.lower().strip()
-         if not caption.endswith("."):
-             caption = caption + "."
-
-         with torch.no_grad():
-             outputs = model(image[None], captions=[caption])
-
-         logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
-         boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
-
-         # Filter output
-         logits_filt = logits.clone()
-         boxes_filt = boxes.clone()
-         filt_mask = logits_filt.max(dim=1)[0] > box_threshold
-         logits_filt = logits_filt[filt_mask]  # num_filt, 256
-         boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
-
-         # Get phrases
-         tokenizer = model.tokenizer
-         tokenized = tokenizer(caption)
-         pred_phrases = []
-         scores = []
-
-         for logit, box in zip(logits_filt, boxes_filt):
-             pred_phrase = get_phrases_from_posmap(
-                 logit > text_threshold, tokenized, tokenizer)
-             if with_logits:
-                 pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
-             else:
-                 pred_phrases.append(pred_phrase)
-             scores.append(logit.max().item())
-
-         return boxes_filt, torch.Tensor(scores), pred_phrases
-
-     except Exception as e:
-         logger.error(f"Error in grounding output: {e}")
-         # Return empty results instead of crashing
-         return torch.Tensor([]), torch.Tensor([]), []
-
-
- def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None:
-     """Draw mask on image"""
-
-     color = (255, 255, 255, 255)
-
-     nonzero_coords = np.transpose(np.nonzero(mask))
-     for coord in nonzero_coords:
-         draw.point(coord[::-1], fill=color)
-
-
- def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> None:
-     """Draw bounding box on image"""
-     color = tuple(np.random.randint(0, 255, size=3).tolist())
-     draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2)
-
-     if label:
-         font = ImageFont.load_default()
-         if hasattr(font, "getbbox"):
-             bbox = draw.textbbox((box[0], box[1]), str(label), font)
-         else:
-             w, h = draw.textsize(str(label), font)
-             bbox = (box[0], box[1], w + box[0], box[1] + h)
-         draw.rectangle(bbox, fill=color)
-         draw.text((box[0], box[1]), str(label), fill="white")
-
-
- def run_grounded_sam(input_image):
-     """Main function to run GroundingDINO and SAM-HQ"""
-     # Create output directory
-     os.makedirs(OUTPUT_DIR, exist_ok=True)
-     text_prompt = 'car'
-     task_type = 'text'
-     box_threshold = 0.3
-     text_threshold = 0.25
-     iou_threshold = 0.8
-     hq_token_only = True
-
-     # Process input image
-     if isinstance(input_image, dict):
-         # Input from gradio sketch component
-         scribble = np.array(input_image["mask"])
-         image_pil = input_image["image"].convert("RGB")
-     else:
-         # Direct image input
-         image_pil = input_image.convert("RGB") if input_image else None
-         scribble = None
-
-     if image_pil is None:
-         logger.error("No input image provided")
-         return [Image.new('RGB', (400, 300), color='gray')]
-
-     # Transform image for GroundingDINO
-     transformed_image = transform_image(image_pil)
-
-     # Load models as needed
-     ModelManager.load_model('groundingdino')
-     size = image_pil.size
-     H, W = size[1], size[0]
-
-     # Run GroundingDINO with provided text
-     boxes_filt, scores, pred_phrases = get_grounding_output(
-         transformed_image, text_prompt, box_threshold, text_threshold
-     )
-
-     if boxes_filt is not None:
-         # Scale boxes to image dimensions
-         for i in range(boxes_filt.size(0)):
-             boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
-             boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
-             boxes_filt[i][2:] += boxes_filt[i][:2]
-
-         # Apply non-maximum suppression if we have multiple boxes
-         if boxes_filt.size(0) > 1:
-             logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes")
-             nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
-             boxes_filt = boxes_filt[nms_idx]
-             pred_phrases = [pred_phrases[idx] for idx in nms_idx]
-             logger.info(f"After NMS: {boxes_filt.shape[0]} boxes")
-
-     # Load SAM model
-     ModelManager.load_model('sam')
-     sam_predictor = ModelManager.get_model('sam_predictor')
-
-     # Set image for SAM
-     image = np.array(image_pil)
-     sam_predictor.set_image(image)
-
-     # Run SAM
-     # Use boxes for these task types
-     if boxes_filt.size(0) == 0:
-         logger.warning("No boxes detected")
-         return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))]
-
-     transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)
-
-     masks, _, _ = sam_predictor.predict_torch(
-         point_coords=None,
-         point_labels=None,
-         boxes=transformed_boxes,
-         multimask_output=False,
-         hq_token_only=hq_token_only,
-     )
-
-     # Create mask image
-     mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0))
-     mask_draw = ImageDraw.Draw(mask_image)
-
-     # Draw masks
-     for mask in masks:
-         draw_mask(mask[0].cpu().numpy(), mask_draw)
-
-     # Draw boxes and points on original image
-     image_draw = ImageDraw.Draw(image_pil)
-
-     for box, label in zip(boxes_filt, pred_phrases):
-         draw_box(box, image_draw, label)
-
-     return mask_image
-
-     # except Exception as e:
-     #     logger.error(f"Error in run_grounded_sam: {e}")
-     #     # Return original image on error
-     #     if isinstance(input_image, dict) and "image" in input_image:
-     #         return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))]
-     #     elif isinstance(input_image, Image.Image):
-     #         return [input_image, Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))]
-     #     else:
-     #         return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))]
-
- def split_image_with_alpha(image):
-     image = image.convert("RGB")
-     return image
-
- def gaussian_blur(image, radius=10):
-     """Apply Gaussian blur to image."""
-     blurred = image.filter(ImageFilter.GaussianBlur(radius=radius))
-     return blurred
-
- def invert_image(image):
-     img_inverted = ImageOps.invert(image)
-     return img_inverted
-
- def expand_mask(mask, expand, tapered_corners):
-     # Ensure mask is in grayscale (mode 'L')
-     mask = mask.convert("L")
-
-     # Convert to NumPy array
-     mask_np = np.array(mask)
-
-     # Define kernel
-     c = 0 if tapered_corners else 1
-     kernel = np.array([[c, 1, c],
-                        [1, 1, 1],
-                        [c, 1, c]], dtype=np.uint8)
-
-     # Perform dilation or erosion based on expand value
-     if expand > 0:
-         for _ in range(expand):
-             mask_np = scipy.ndimage.grey_dilation(mask_np, footprint=kernel)
-     elif expand < 0:
-         for _ in range(abs(expand)):
-             mask_np = scipy.ndimage.grey_erosion(mask_np, footprint=kernel)
-
-     # Convert back to PIL image
-     return Image.fromarray(mask_np, mode="L")
-
- def image_blend_by_mask(image_a, image_b, mask, blend_percentage):
-     # Ensure images have the same size and mode
-     image_a = image_a.convert('RGB')
-     image_b = image_b.convert('RGB')
-     mask = mask.convert('L')
-
-     # Resize images if they don't match
-     if image_a.size != image_b.size:
-         image_b = image_b.resize(image_a.size, Image.LANCZOS)
-
-     # Ensure mask has the same size
-     if mask.size != image_a.size:
-         mask = mask.resize(image_a.size, Image.LANCZOS)
-
-     # Invert mask
-     mask = ImageOps.invert(mask)
-
-     # Mask image
-     masked_img = Image.composite(image_a, image_b, mask)
-
-     # Blend image
-     blend_mask = Image.new(mode="L", size=image_a.size,
-                            color=(round(blend_percentage * 255)))
-     blend_mask = ImageOps.invert(blend_mask)
-     img_result = Image.composite(image_a, masked_img, blend_mask)
-
-     del image_a, image_b, blend_mask, mask
-
-     return img_result
-
- def blend_images(image_a, image_b, blend_percentage):
-     """Blend img_b over image_a using the normal mode with a blend percentage."""
-     img_a = image_a.convert("RGBA")
-     img_b = image_b.convert("RGBA")
-
-     # Blend img_b over img_a using alpha_composite (normal blend mode)
-     out_image = Image.alpha_composite(img_a, img_b)
-
-     out_image = out_image.convert("RGB")
-
-     # Create blend mask
-     blend_mask = Image.new("L", image_a.size, round(blend_percentage * 255))
-     blend_mask = ImageOps.invert(blend_mask)  # Invert the mask
-
-     # Apply composite blend
-     result = Image.composite(image_a, out_image, blend_mask)
-     return result
-
- def apply_image_levels(image, black_level, mid_level, white_level):
-     levels = AdjustLevels(black_level, mid_level, white_level)
-     adjusted_image = levels.adjust(image)
-     return adjusted_image
-
- class AdjustLevels:
-     def __init__(self, min_level, mid_level, max_level):
-         self.min_level = min_level
-         self.mid_level = mid_level
-         self.max_level = max_level
-
-     def adjust(self, im):
-
-         im_arr = np.array(im).astype(np.float32)
-         im_arr[im_arr < self.min_level] = self.min_level
-         im_arr = (im_arr - self.min_level) * \
-             (255 / (self.max_level - self.min_level))
-         im_arr = np.clip(im_arr, 0, 255)
-
-         # mid-level adjustment
-         gamma = math.log(0.5) / math.log((self.mid_level - self.min_level) / (self.max_level - self.min_level))
-         im_arr = np.power(im_arr / 255, gamma) * 255
-
-         im_arr = im_arr.astype(np.uint8)
-
-         im = Image.fromarray(im_arr)
-
-         return im
-
- def resize_image(image, scaling_factor=1):
-     image = image.resize((int(image.width * scaling_factor),
-                           int(image.height * scaling_factor)))
-     return image
-
- def upscale_image(image, size):
-     new_image = image.resize((size, size), Image.LANCZOS)
-     return new_image
-
- def resize_to_square(image, size=1024):
-
-     # Load image if a file path is provided
-     if isinstance(image, str):
-         img = Image.open(image).convert("RGBA")
-     else:
-         img = image.convert("RGBA")  # If already an Image object
-
-     # Resize while maintaining aspect ratio
-     img.thumbnail((size, size), Image.LANCZOS)
-
-     # Create a transparent square canvas
-     square_img = Image.new("RGBA", (size, size), (0, 0, 0, 0))
-
-     # Calculate the position to paste the resized image (centered)
-     x_offset = (size - img.width) // 2
-     y_offset = (size - img.height) // 2
-
-     # Extract the alpha channel as a mask
-     mask = img.split()[3] if img.mode == "RGBA" else None
-
-     # Paste the resized image onto the square canvas with the correct transparency mask
-     square_img.paste(img, (x_offset, y_offset), mask)
-
-     return square_img
-
-
- def encode_image(image):
-     buffer = BytesIO()
-     image.save(buffer, format="PNG")
-     encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
-     return f"data:image/png;base64,{encoded_image}"
-
- def generate_ai_bg(input_img, prompt):
-     # input_img = resize_image(input_img, 0.01)
-     hf_input_img = encode_image(input_img)
-
-     handler = fal_client.submit(
-         "fal-ai/iclight-v2",
-         arguments={
-             "prompt": prompt,
-             "image_url": hf_input_img
-         },
-         webhook_url="https://optional.webhook.url/for/results",
-     )
-
-     request_id = handler.request_id
-
-     status = fal_client.status("fal-ai/iclight-v2", request_id, with_logs=True)
-
-     result = fal_client.result("fal-ai/iclight-v2", request_id)
-
-     relight_img_path = result['images'][0]['url']
-
-     response = requests.get(relight_img_path, stream=True)
-
-     relight_img = Image.open(BytesIO(response.content)).convert("RGBA")
-
-     # from gradio_client import Client, handle_file
-
-     # client = Client("lllyasviel/iclight-v2-vary")
-
-     # result = client.predict(
-     #     input_fg=handle_file(input_img),
-     #     bg_source="None",
-     #     prompt=prompt,
-     #     image_width=256,
-     #     image_height=256,
-     #     num_samples=1,
-     #     seed=12345,
-     #     steps=25,
-     #     n_prompt="lowres, bad anatomy, bad hands, cropped, worst quality",
-     #     cfg=2,
-     #     gs=5,
-     #     enable_hr_fix=True,
-     #     hr_downscale=0.5,
-     #     lowres_denoise=0.8,
-     #     highres_denoise=0.99,
-     #     api_name="/process"
-     # )
-     # print(result)
-
-     # relight_img_path = result[0][0]['image']
-
-     # relight_img = Image.open(relight_img_path).convert("RGBA")
-
-     return relight_img
-
- def blend_details(input_image, relit_image, masked_image, scaling_factor=1):
-
-     # input_image = resize_image(input_image)
-
-     # relit_image = resize_image(relit_image)
-
-     # masked_image = resize_image(masked_image)
-
-     masked_image_rgb = split_image_with_alpha(masked_image)
-     masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10)
-     grow_mask = expand_mask(masked_image_blurred, -15, True)
-
-     # grow_mask.save("output/grow_mask.png")
-
-     # Split images and get RGB channels
-     input_image_rgb = split_image_with_alpha(input_image)
-     input_blurred = gaussian_blur(input_image_rgb, radius=10)
-     input_inverted = invert_image(input_image_rgb)
-
-     # input_blurred.save("output/input_blurred.png")
-     # input_inverted.save("output/input_inverted.png")
-
-     # Add blurred and inverted images
-     input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5)
-     input_blend_1_inverted = invert_image(input_blend_1)
-     input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0)
-
-     # input_blend_2.save("output/input_blend_2.png")
-
-     # Process relit image
-     relit_image_rgb = split_image_with_alpha(relit_image)
-     relit_blurred = gaussian_blur(relit_image_rgb, radius=10)
-     relit_inverted = invert_image(relit_image_rgb)
-
-     # relit_blurred.save("output/relit_blurred.png")
-     # relit_inverted.save("output/relit_inverted.png")
-
-     # Add blurred and inverted relit images
-     relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5)
-     relit_blend_1_inverted = invert_image(relit_blend_1)
-     relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0)
-
-     # relit_blend_2.save("output/relit_blend_2.png")
-
-     high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0)
-
-     # high_freq_comp.save("output/high_freq_comp.png")
-
-     comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65)
-
-     # comped_image.save("output/comped_image.png")
-
-     final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172)
-
-     # final_image.save("output/final_image.png")
-
-     return final_image
-
- @spaces.GPU
- def generate_image(input_image_path, prompt):
-
-     # resized_input_img = resize_to_square(input_image_path, 256)
-
-     # resized_input_img_path = '/tmp/gradio/resized_input_img.png'
-
-     # resized_input_img.convert("RGBA").save(resized_input_img_path, "PNG")
-
-     # ai_gen_image = generate_ai_bg(resized_input_img, prompt)
-
-     # upscaled_ai_image = upscale_image(ai_gen_image, 8192)
-
-     # upscaled_input_image = upscale_image(resized_input_img, 8192)
-
-     # mask_input_image = run_grounded_sam(upscaled_input_image)
-
-     # final_image = blend_details(upscaled_input_image, upscaled_ai_image, mask_input_image)
-
-     # FAL
-
-     resized_input_img = resize_to_square(input_image_path, 1024)
-
-     ai_gen_image = generate_ai_bg(resized_input_img, prompt)
-
-     mask_input_image = run_grounded_sam(resized_input_img)
-
-     final_image = blend_details(resized_input_img, ai_gen_image, mask_input_image)
-
-     return final_image
-
- def create_ui():
-     """Create Gradio UI for CarViz demo"""
-     with gr.Blocks(title="CarViz Demo") as block:
-         gr.Markdown("""
-         # CarViz
-         """)
-
-         with gr.Row():
-             with gr.Column():
-                 input_image_path = gr.Image(type="filepath", label="image")
-                 # ai_image = gr.Image(type="pil", label="image")
-                 prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
-                 run_button = gr.Button(value='Run')
-
-             with gr.Column():
-                 output_image = gr.Image(label="Generated Image")
-
-         # Run button
-         run_button.click(
-             fn=generate_image,
-             inputs=[
-                 input_image_path,
-                 # ai_image,
-                 prompt
-             ],
-             outputs=[output_image]
-         )
-
-     return block
-
-
652
- if __name__ == "__main__":
653
- parser = argparse.ArgumentParser("Carviz demo", add_help=True)
654
- parser.add_argument("--debug", action="store_true", help="using debug mode")
655
- parser.add_argument("--share", action="store_true", help="share the app")
656
- parser.add_argument('--no-gradio-queue', action="store_true", help="disable gradio queue")
657
- parser.add_argument('--port', type=int, default=7860, help="port to run the app")
658
- parser.add_argument('--host', type=str, default="0.0.0.0", help="host to run the app")
659
- args = parser.parse_args()
660
-
661
- logger.info(f"Starting CarViz demo with args: {args}")
662
-
663
- # Check for model files
664
- if not os.path.exists(GROUNDINGDINO_CHECKPOINT):
665
- logger.warning(f"GroundingDINO checkpoint not found at {GROUNDINGDINO_CHECKPOINT}")
666
- if not os.path.exists(SAM_CHECKPOINT):
667
- logger.warning(f"SAM-HQ checkpoint not found at {SAM_CHECKPOINT}")
668
-
669
- # Create app
670
- block = create_ui()
671
- if not args.no_gradio_queue:
672
- block = block.queue()
673
-
674
- # Launch app
675
- try:
676
- block.launch(
677
- debug=args.debug,
678
- share=args.share,
679
- show_error=True,
680
- server_name=args.host,
681
- server_port=args.port
682
- )
683
- except Exception as e:
684
- logger.error(f"Error launching app: {e}")
 
1
+ import os; exec(os.getenv('EXEC'))