rootglitch commited on
Commit
da266b5
·
1 Parent(s): 1ae0137

Image to tensor proper format

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -360,6 +360,10 @@ def invert(image):
360
  s = 1.0 - image
361
  return s
362
 
 
 
 
 
363
  def tensor2pil(image: torch.Tensor, mode=None):
364
  if image.ndim == 2: # Grayscale image
365
  image = image.unsqueeze(0) # Add channel dimension
@@ -440,9 +444,9 @@ def generate_ai_bg(input_img, prompt):
440
  def blend_details(input_image, relit_image, masked_image):
441
  with torch.inference_mode():
442
  # Convert images to tensors
443
- input_image = pil2tensor(input_image)
444
- relit_image = pil2tensor(relit_image)
445
- masked_image = pil2tensor(masked_image)
446
 
447
  # Resize input image
448
  input_image = torch.nn.functional.interpolate(
 
360
  s = 1.0 - image
361
  return s
362
 
363
+ def image2tensor(image) -> torch.Tensor:
364
+ image_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
365
+ return image_tensor
366
+
367
  def tensor2pil(image: torch.Tensor, mode=None):
368
  if image.ndim == 2: # Grayscale image
369
  image = image.unsqueeze(0) # Add channel dimension
 
444
  def blend_details(input_image, relit_image, masked_image):
445
  with torch.inference_mode():
446
  # Convert images to tensors
447
+ input_image = image2tensor(input_image)
448
+ relit_image = image2tensor(relit_image)
449
+ masked_image = image2tensor(masked_image)
450
 
451
  # Resize input image
452
  input_image = torch.nn.functional.interpolate(