File size: 5,328 Bytes
3dc5131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31fff2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3dc5131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from .model import ConditionalUNet
from huggingface_hub import hf_hub_download


def load_models(ENV, device):
    """Load the diffusion model and both flow-matching models.

    Each model is obtained from its dedicated loader in this module
    (``load_model_diff``, ``load_model_flow_standard`` and
    ``load_model_flow_localized``) so the checkpoint locations and the
    loading procedure are defined in exactly one place.  Every loader prints
    its own progress message.

    Args:
        ENV: Environment selector forwarded to the loaders; ``"DEPLOY"``
            makes them fetch checkpoints with ``hf_hub_download``, any other
            value makes them use the repository-relative paths.
        device: Device forwarded to the loaders (used for ``.to(...)`` and
            as ``map_location``).

    Returns:
        Tuple ``(model_diff_standard, model_flow_standard,
        model_flow_localized)``.
    """
    model_diff_standard = load_model_diff(ENV, device)
    model_flow_standard = load_model_flow_standard(ENV, device)
    model_flow_localized = load_model_flow_localized(ENV, device)
    return model_diff_standard, model_flow_standard, model_flow_localized


def load_model_diff(ENV, device):
    """Build the diffusion ``ConditionalUNet`` and load its checkpoint.

    Args:
        ENV: Environment selector.  When equal to ``"DEPLOY"`` the
            checkpoint is fetched with ``hf_hub_download`` (cache directory
            ``models``); for any other value the repository-relative path
            is used directly.
        device: Passed to ``.to(...)`` on the model and to ``torch.load``
            as ``map_location``.

    Returns:
        The model with the checkpoint's state dict loaded and ``eval()``
        applied.
    """
    checkpoint_file = "outputs/diffusion/diffusion_model.pth"
    weights_path = (
        hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint_file,
            cache_dir="models",
        )
        if ENV == "DEPLOY"
        else checkpoint_file
    )
    # Printed on both paths, i.e. also when the local file is used.
    print("Diff Downloaded!")

    network = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the file; consider
    # weights_only=True if the installed torch version supports it.
    state = torch.load(weights_path, map_location=device)
    network.load_state_dict(state)
    network.eval()
    return network

def load_model_flow_standard(ENV, device):
    """Build the standard flow-matching ``ConditionalUNet`` and load its checkpoint.

    Args:
        ENV: Environment selector.  ``"DEPLOY"`` fetches the checkpoint with
            ``hf_hub_download`` (cache directory ``models``); any other
            value uses the repository-relative path directly.
        device: Passed to ``.to(...)`` on the model and to ``torch.load``
            as ``map_location``.

    Returns:
        The model with the checkpoint's state dict loaded and ``eval()``
        applied.
    """
    checkpoint_file = "outputs/flow_matching/flow_model.pth"
    if ENV == "DEPLOY":
        weights_path = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint_file,
            cache_dir="models",
        )
    else:
        weights_path = checkpoint_file
    # Printed on both paths, i.e. also when the local file is used.
    print("Flow Downloaded!")

    network = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the file; consider
    # weights_only=True if the installed torch version supports it.
    state = torch.load(weights_path, map_location=device)
    network.load_state_dict(state)
    network.eval()
    return network

def load_model_flow_localized(ENV, device):
    """Build the localized-noise flow-matching ``ConditionalUNet`` and load its checkpoint.

    Args:
        ENV: Environment selector.  ``"DEPLOY"`` fetches the checkpoint with
            ``hf_hub_download`` (cache directory ``models``); any other
            value uses the repository-relative path directly.
        device: Passed to ``.to(...)`` on the model and to ``torch.load``
            as ``map_location``.

    Returns:
        The model with the checkpoint's state dict loaded and ``eval()``
        applied.
    """
    checkpoint_file = "outputs/flow_matching/flow_model_localized_noise.pth"
    use_hub = ENV == "DEPLOY"
    weights_path = checkpoint_file
    if use_hub:
        weights_path = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint_file,
            cache_dir="models",
        )
    # Printed on both paths, i.e. also when the local file is used.
    print("Flow Downloaded!")

    network = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the file; consider
    # weights_only=True if the installed torch version supports it.
    state = torch.load(weights_path, map_location=device)
    network.load_state_dict(state)
    network.eval()
    return network


def resize(image, size=(200, 200)):
    """Return ``image`` resampled to ``size`` with bilinear interpolation.

    Args:
        image: Array accepted by ``cv2.resize``.
        size: Target size handed to ``cv2.resize`` as its ``dsize``
            argument (OpenCV documents this as ``(width, height)``).

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
        


def plot_diff(outputs, x, t, noise_pred):
    """Store visualisation frames for diffusion timestep ``t`` in ``outputs``.

    Slots written (slot 0 is never touched here):
      * 1-5   -- 200x200 snapshots of the sample at t = 400, 300, 200, 100, 1
      * 6-11  -- 200x200 colour maps of ``|noise_pred|`` at
                 t = 499, 399, 299, 199, 99, 0
      * 12    -- 300x300 view of the current sample, refreshed on every call

    Args:
        outputs: Indexable container supporting ``outputs[k] = ...`` for the
            slots above; mutated in place.
        x: Current sample tensor; only ``x[0, 0]`` is rendered, after the
            mapping ``(x + 1) / 2``.
            # assumes x is in [-1, 1] so the result lands in [0, 1] -- TODO confirm
        t: Current timestep, compared against the fixed timesteps above
            (must be hashable, e.g. a plain ``int``).
        noise_pred: Model prediction tensor; only ``noise_pred[0, 0]`` is
            rendered.

    Returns:
        The same ``outputs`` object.
    """
    # Single source of truth for "which timestep goes to which slot".
    noise_slots = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}
    snapshot_slots = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5}

    if t in noise_slots:
        magnitude = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
        # Min-max normalise to [0, 1); the epsilon guards a constant map.
        magnitude = (magnitude - magnitude.min()) / (
            magnitude.max() - magnitude.min() + 1e-5
        )
        # Colormap returns RGBA floats; keep RGB and convert to uint8.
        colored = plt.get_cmap("coolwarm")(magnitude)[:, :, :3]
        colored = (colored * 255).astype(np.uint8)
        outputs[noise_slots[t]] = resize(colored)

    # Convert the sample once and reuse it for every slot that needs it.
    frame = ((x + 1) / 2.0)[0, 0].cpu().numpy()
    # Slot 12 always shows the latest sample, so at t == 0 it already holds
    # the final image and needs no second write.
    outputs[12] = resize(frame, (300, 300))

    if t in snapshot_slots:
        outputs[snapshot_slots[t]] = resize(frame)

    return outputs

def plot_flow(outputs, i, x, dt, v):
    """Store visualisation frames for flow-matching step ``i`` in ``outputs``.

    Slots written (slot 0 is never touched here):
      * 1-5   -- 200x200 snapshots of the sample at i = 10, 20, 30, 40, 48
      * 6-11  -- 200x200 colour maps of ``dt * |v|`` at
                 i = 0, 11, 21, 31, 41, 49
      * 12    -- 300x300 view of the current sample, refreshed on every call

    Args:
        outputs: Indexable container supporting ``outputs[k] = ...`` for the
            slots above; mutated in place.
        x: Current sample tensor; only ``x[0, 0]`` is rendered, after the
            mapping ``(x + 1) / 2`` and a clamp to ``[0, 1]``.
        i: Current integration step index, compared against the fixed steps
            above (must be hashable, e.g. a plain ``int``).
        dt: Step size multiplied into the velocity magnitude before
            normalisation.
        v: Velocity tensor; only ``v[0, 0]`` is rendered.

    Returns:
        The same ``outputs`` object.
    """
    # Single source of truth for "which step goes to which slot".
    snapshot_slots = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5}
    velocity_slots = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}

    # Convert the sample once and reuse it for every slot that needs it.
    frame = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
    # Slot 12 always shows the latest sample, so at i == 49 it already holds
    # the final image and needs no second write.
    outputs[12] = resize(frame, (300, 300))

    if i in snapshot_slots:
        outputs[snapshot_slots[i]] = resize(frame)

    if i in velocity_slots:
        # Clamp the magnitude to a fixed maximum for better contrast.
        magnitude = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()
        # Min-max normalise to [0, 1); the epsilon guards a constant map.
        magnitude = (magnitude - magnitude.min()) / (
            magnitude.max() - magnitude.min() + 1e-5
        )
        # Colormap returns RGBA floats of shape (H, W, 4); keep RGB as uint8.
        colored = plt.get_cmap("coolwarm")(magnitude)[:, :, :3]
        colored = (colored * 255).astype(np.uint8)
        outputs[velocity_slots[i]] = resize(colored)

    return outputs