File size: 3,918 Bytes
61554bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eaec10
61554bf
 
7eaec10
 
 
 
 
 
 
 
61554bf
 
 
 
7eaec10
 
61554bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7eaec10
 
 
61554bf
 
 
 
 
7eaec10
 
 
 
 
61554bf
 
 
 
2bff9bb
61554bf
 
c63008c
61554bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import tempfile
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from colour.io.luts.iridas_cube import read_LUT_IridasCube

def get_available_luts():
    """List the ``*.cube`` LUT file names found in the ``cube_luts`` folder.

    Returns:
        The matching file names (not full paths) in sorted order, or an
        empty list when the ``cube_luts`` folder does not exist.
    """
    lut_dir = Path('cube_luts')
    if lut_dir.exists():
        names = [entry.name for entry in lut_dir.glob('*.cube')]
        names.sort()
        return names
    return []

def apply_lut(image, lut_name, gamma_correction=True, clip_values=True, strength=1.0):
    if image is None or lut_name is None:
        return None
        
    # Convert gradio image to torch tensor
    image = torch.from_numpy(image).float() / 255.0
    
    # Get full path to LUT file
    lut_file = Path('cube_luts') / lut_name
    
    # Read LUT file with error handling for different encodings
    try:
        lut = read_LUT_IridasCube(str(lut_file))
    except UnicodeDecodeError:
        # Try different encodings if utf-8 fails
        try:
            with open(str(lut_file), 'r', encoding='latin-1') as f:
                lut = read_LUT_IridasCube(f)
        except Exception as e:
            print(f"Error reading LUT file with latin-1 encoding: {e}")
            return image.numpy() * 255.0
    except Exception as e:
        print(f"Error reading LUT file: {e}")
        return image.numpy() * 255.0

    lut.name = lut_name

    # Handle clipping
    if clip_values:
        if lut.domain[0].max() == lut.domain[0].min() and lut.domain[1].max() == lut.domain[1].min():
            lut.table = np.clip(lut.table, lut.domain[0, 0], lut.domain[1, 0])
        else:
            if len(lut.table.shape) == 2:  # 3x1D
                for dim in range(3):
                    lut.table[:, dim] = np.clip(lut.table[:, dim], lut.domain[0, dim], lut.domain[1, dim])
            else:  # 3D
                for dim in range(3):
                    lut.table[:, :, :, dim] = np.clip(lut.table[:, :, :, dim], lut.domain[0, dim], lut.domain[1, dim])

    # Process image
    lut_img = image.numpy().copy()
    
    is_non_default_domain = not np.array_equal(lut.domain, np.array([[0., 0., 0.], [1., 1., 1.]]))
    dom_scale = None
    if is_non_default_domain:
        dom_scale = lut.domain[1] - lut.domain[0]
        lut_img = lut_img * dom_scale + lut.domain[0]
    
    if gamma_correction:
        lut_img = lut_img ** (1/2.2)
    
    lut_img = lut.apply(lut_img)
    
    if gamma_correction:
        lut_img = lut_img ** (2.2)
    
    if is_non_default_domain:
        lut_img = (lut_img - lut.domain[0]) / dom_scale

    # Ensure values are in valid range
    lut_img = np.clip(lut_img, 0, 1)
    
    lut_img = torch.from_numpy(lut_img).float()
    
    if strength < 1.0:
        lut_img = strength * lut_img + (1 - strength) * image

    # Convert back to uint8 range and ensure proper bounds
    result = (lut_img.numpy() * 255.0)
    result = np.clip(result, 0, 255).astype(np.uint8)
    
    return result

def create_lut_tab():
    """Build the "LUT" tab and wire its button to ``apply_lut``.

    The left column holds the input image, the LUT dropdown, the gamma and
    clipping checkboxes, the strength slider and the "Apply LUT" button. The
    right column shows the output image. The dropdown lists the LUTs found
    when the tab is built and preselects the first one; if there are none,
    nothing is preselected.
    """
    lut_choices = get_available_luts()
    default_lut = lut_choices[0] if lut_choices else None

    with gr.Tab("LUT"):
        with gr.Row():
            with gr.Column():
                source_img = gr.Image(label="Input Image", height=256)
                lut_picker = gr.Dropdown(
                    choices=lut_choices,
                    label="Select LUT",
                    value=default_lut,
                )
                gamma_box = gr.Checkbox(label="Gamma Correction", value=True)
                clip_box = gr.Checkbox(label="Clip Values", value=True)
                strength_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    label="Effect Strength",
                )
                run_button = gr.Button("Apply LUT")

            with gr.Column():
                result_img = gr.Image(label="Output Image")

        # The inputs order must match apply_lut's positional parameters.
        run_button.click(
            fn=apply_lut,
            inputs=[source_img, lut_picker, gamma_box, clip_box, strength_slider],
            outputs=result_img,
        )