Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	
		Linoy Tsaban
		
	commited on
		
		
					Commit 
							
							·
						
						d19d91b
	
1
								Parent(s):
							
							5d7ba0f
								
Update utils.py
Browse files
    	
        utils.py
    CHANGED
    
    | @@ -3,7 +3,7 @@ from PIL import Image, ImageDraw ,ImageFont | |
| 3 | 
             
            from matplotlib import pyplot as plt
         | 
| 4 | 
             
            import torchvision.transforms as T
         | 
| 5 | 
             
            import os
         | 
| 6 | 
            -
            import torch | 
| 7 | 
             
            import yaml
         | 
| 8 |  | 
| 9 | 
             
            def show_torch_img(img):
         | 
| @@ -20,14 +20,14 @@ def tensor_to_pil(tensor_imgs): | |
| 20 | 
             
                    tensor_imgs = torch.cat(tensor_imgs)
         | 
| 21 | 
             
                tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
         | 
| 22 | 
             
                to_pil = T.ToPILImage()
         | 
| 23 | 
            -
                pil_imgs = [to_pil(img) for img in tensor_imgs] | 
| 24 | 
             
                return pil_imgs
         | 
| 25 |  | 
| 26 | 
             
            def pil_to_tensor(pil_imgs):
         | 
| 27 | 
             
                to_torch = T.ToTensor()
         | 
| 28 | 
             
                if type(pil_imgs) == PIL.Image.Image:
         | 
| 29 | 
             
                    tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1
         | 
| 30 | 
            -
                elif type(pil_imgs) == list: | 
| 31 | 
             
                    tensor_imgs = torch.cat([to_torch(pil_imgs).unsqueeze(0)*2-1 for img in pil_imgs]).to(device)
         | 
| 32 | 
             
                else:
         | 
| 33 | 
             
                    raise Exception("Input need to be PIL.Image or list of PIL.Image")
         | 
| @@ -40,30 +40,30 @@ def pil_to_tensor(pil_imgs): | |
| 40 | 
             
            # num_col = n // num_rows
         | 
| 41 | 
             
            # num_col  = num_col + 1 if n % num_rows else num_col
         | 
| 42 | 
             
            # num_col
         | 
| 43 | 
            -
            def add_margin(pil_img, top = 0, right = 0, bottom = 0, | 
| 44 | 
             
                                left = 0, color = (255,255,255)):
         | 
| 45 | 
             
                width, height = pil_img.size
         | 
| 46 | 
             
                new_width = width + right + left
         | 
| 47 | 
             
                new_height = height + top + bottom
         | 
| 48 | 
             
                result = Image.new(pil_img.mode, (new_width, new_height), color)
         | 
| 49 | 
            -
             | 
| 50 | 
             
                result.paste(pil_img, (left, top))
         | 
| 51 | 
             
                return result
         | 
| 52 |  | 
| 53 | 
            -
            def image_grid(imgs, rows = 1, cols = None, | 
| 54 | 
             
                                size = None,
         | 
| 55 | 
             
                               titles = None, text_pos = (0, 0)):
         | 
| 56 | 
             
                if type(imgs) == list and type(imgs[0]) == torch.Tensor:
         | 
| 57 | 
             
                    imgs = torch.cat(imgs)
         | 
| 58 | 
             
                if type(imgs) == torch.Tensor:
         | 
| 59 | 
             
                    imgs = tensor_to_pil(imgs)
         | 
| 60 | 
            -
             | 
| 61 | 
             
                if not size is None:
         | 
| 62 | 
             
                    imgs = [img.resize((size,size)) for img in imgs]
         | 
| 63 | 
             
                if cols is None:
         | 
| 64 | 
             
                    cols = len(imgs)
         | 
| 65 | 
             
                assert len(imgs) >= rows*cols
         | 
| 66 | 
            -
             | 
| 67 | 
             
                top=20
         | 
| 68 | 
             
                w, h = imgs[0].size
         | 
| 69 | 
             
                delta = 0
         | 
| @@ -71,23 +71,23 @@ def image_grid(imgs, rows = 1, cols = None, | |
| 71 | 
             
                    delta = top
         | 
| 72 | 
             
                    h = imgs[1].size[1]
         | 
| 73 | 
             
                if not titles is  None:
         | 
| 74 | 
            -
                    font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf", | 
| 75 | 
             
                                                size = 20, encoding="unic")
         | 
| 76 | 
            -
                    h = top + h | 
| 77 | 
            -
                grid = Image.new('RGB', size=(cols*w, rows*h+delta)) | 
| 78 | 
             
                for i, img in enumerate(imgs):
         | 
| 79 | 
            -
             | 
| 80 | 
             
                    if not titles is  None:
         | 
| 81 | 
             
                        img = add_margin(img, top = top, bottom = 0,left=0)
         | 
| 82 | 
             
                        draw = ImageDraw.Draw(img)
         | 
| 83 | 
            -
                        draw.text(text_pos, titles[i],(0,0,0), | 
| 84 | 
             
                        font = font)
         | 
| 85 | 
             
                    if not delta == 0 and i > 0:
         | 
| 86 | 
             
                       grid.paste(img, box=(i%cols*w, i//cols*h+delta))
         | 
| 87 | 
             
                    else:
         | 
| 88 | 
             
                        grid.paste(img, box=(i%cols*w, i//cols*h))
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                return grid | 
| 91 |  | 
| 92 |  | 
| 93 | 
             
            """
         | 
|  | |
| 3 | 
             
            from matplotlib import pyplot as plt
         | 
| 4 | 
             
            import torchvision.transforms as T
         | 
| 5 | 
             
            import os
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
             
            import yaml
         | 
| 8 |  | 
| 9 | 
             
            def show_torch_img(img):
         | 
|  | |
| 20 | 
             
                    tensor_imgs = torch.cat(tensor_imgs)
         | 
| 21 | 
             
                tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
         | 
| 22 | 
             
                to_pil = T.ToPILImage()
         | 
| 23 | 
            +
                pil_imgs = [to_pil(img) for img in tensor_imgs]
         | 
| 24 | 
             
                return pil_imgs
         | 
| 25 |  | 
| 26 | 
             
            def pil_to_tensor(pil_imgs):
         | 
| 27 | 
             
                to_torch = T.ToTensor()
         | 
| 28 | 
             
                if type(pil_imgs) == PIL.Image.Image:
         | 
| 29 | 
             
                    tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1
         | 
| 30 | 
            +
                elif type(pil_imgs) == list:
         | 
| 31 | 
             
                    tensor_imgs = torch.cat([to_torch(pil_imgs).unsqueeze(0)*2-1 for img in pil_imgs]).to(device)
         | 
| 32 | 
             
                else:
         | 
| 33 | 
             
                    raise Exception("Input need to be PIL.Image or list of PIL.Image")
         | 
|  | |
| 40 | 
             
            # num_col = n // num_rows
         | 
| 41 | 
             
            # num_col  = num_col + 1 if n % num_rows else num_col
         | 
| 42 | 
             
            # num_col
         | 
| 43 | 
            +
            def add_margin(pil_img, top = 0, right = 0, bottom = 0,
         | 
| 44 | 
             
                                left = 0, color = (255,255,255)):
         | 
| 45 | 
             
                width, height = pil_img.size
         | 
| 46 | 
             
                new_width = width + right + left
         | 
| 47 | 
             
                new_height = height + top + bottom
         | 
| 48 | 
             
                result = Image.new(pil_img.mode, (new_width, new_height), color)
         | 
| 49 | 
            +
             | 
| 50 | 
             
                result.paste(pil_img, (left, top))
         | 
| 51 | 
             
                return result
         | 
| 52 |  | 
| 53 | 
            +
            def image_grid(imgs, rows = 1, cols = None,
         | 
| 54 | 
             
                                size = None,
         | 
| 55 | 
             
                               titles = None, text_pos = (0, 0)):
         | 
| 56 | 
             
                if type(imgs) == list and type(imgs[0]) == torch.Tensor:
         | 
| 57 | 
             
                    imgs = torch.cat(imgs)
         | 
| 58 | 
             
                if type(imgs) == torch.Tensor:
         | 
| 59 | 
             
                    imgs = tensor_to_pil(imgs)
         | 
| 60 | 
            +
             | 
| 61 | 
             
                if not size is None:
         | 
| 62 | 
             
                    imgs = [img.resize((size,size)) for img in imgs]
         | 
| 63 | 
             
                if cols is None:
         | 
| 64 | 
             
                    cols = len(imgs)
         | 
| 65 | 
             
                assert len(imgs) >= rows*cols
         | 
| 66 | 
            +
             | 
| 67 | 
             
                top=20
         | 
| 68 | 
             
                w, h = imgs[0].size
         | 
| 69 | 
             
                delta = 0
         | 
|  | |
| 71 | 
             
                    delta = top
         | 
| 72 | 
             
                    h = imgs[1].size[1]
         | 
| 73 | 
             
                if not titles is  None:
         | 
| 74 | 
            +
                    font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
         | 
| 75 | 
             
                                                size = 20, encoding="unic")
         | 
| 76 | 
            +
                    h = top + h
         | 
| 77 | 
            +
                grid = Image.new('RGB', size=(cols*w, rows*h+delta))
         | 
| 78 | 
             
                for i, img in enumerate(imgs):
         | 
| 79 | 
            +
             | 
| 80 | 
             
                    if not titles is  None:
         | 
| 81 | 
             
                        img = add_margin(img, top = top, bottom = 0,left=0)
         | 
| 82 | 
             
                        draw = ImageDraw.Draw(img)
         | 
| 83 | 
            +
                        draw.text(text_pos, titles[i],(0,0,0),
         | 
| 84 | 
             
                        font = font)
         | 
| 85 | 
             
                    if not delta == 0 and i > 0:
         | 
| 86 | 
             
                       grid.paste(img, box=(i%cols*w, i//cols*h+delta))
         | 
| 87 | 
             
                    else:
         | 
| 88 | 
             
                        grid.paste(img, box=(i%cols*w, i//cols*h))
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                return grid
         | 
| 91 |  | 
| 92 |  | 
| 93 | 
             
            """
         | 
