# File size: 5,328 Bytes
# 3dc5131 31fff2a 3dc5131  (apparently revision ids left over from the file viewer)
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from .model import ConditionalUNet
from huggingface_hub import hf_hub_download
def load_models(ENV, device):
    """Load the diffusion model and the two flow-matching models.

    When ``ENV == "DEPLOY"`` each checkpoint is fetched with
    ``hf_hub_download`` (cached under ``models``); for any other value the
    same relative path is opened from the local filesystem.

    Args:
        ENV: Environment selector; only the value ``"DEPLOY"`` triggers a
            Hub download.
        device: Passed to ``.to(...)`` and used as ``map_location`` when the
            checkpoints are read.

    Returns:
        A tuple ``(diffusion, flow_standard, flow_localized)`` of
        ``ConditionalUNet`` instances, each switched to eval mode.
    """
    hub_repo = "CristianLazoQuispe/MNIST_Diff_Flow_matching"

    def _checkpoint_path(relative_path):
        # The Hub filename and the local path are the same string.
        if ENV == "DEPLOY":
            return hf_hub_download(repo_id=hub_repo, filename=relative_path, cache_dir="models")
        return relative_path

    def _restore(checkpoint):
        net = ConditionalUNet().to(device)
        # NOTE(review): torch.load unpickles the file; consider
        # weights_only=True if the installed torch version supports it.
        net.load_state_dict(torch.load(checkpoint, map_location=device))
        net.eval()
        return net

    diffusion_ckpt = _checkpoint_path("outputs/diffusion/diffusion_model.pth")
    # The message is printed in both modes, even when nothing was downloaded.
    print("Diff Downloaded!")
    diffusion_net = _restore(diffusion_ckpt)

    standard_ckpt = _checkpoint_path("outputs/flow_matching/flow_model.pth")
    localized_ckpt = _checkpoint_path("outputs/flow_matching/flow_model_localized_noise.pth")
    print("Flow Downloaded!")
    flow_standard_net = _restore(standard_ckpt)
    flow_localized_net = _restore(localized_ckpt)

    return diffusion_net, flow_standard_net, flow_localized_net
def load_model_diff(ENV, device):
    """Load the diffusion ``ConditionalUNet`` and return it in eval mode.

    Args:
        ENV: When equal to ``"DEPLOY"`` the checkpoint is fetched with
            ``hf_hub_download`` (cached under ``models``); otherwise the
            local relative path is used as is.
        device: Passed to ``.to(...)`` and used as ``map_location``.

    Returns:
        The ``ConditionalUNet`` with the checkpoint's state dict loaded.
    """
    checkpoint = "outputs/diffusion/diffusion_model.pth"
    if ENV == "DEPLOY":
        # Same relative name on the Hub as on the local disk.
        checkpoint = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint,
            cache_dir="models",
        )
    # Printed in both modes, even when nothing was downloaded.
    print("Diff Downloaded!")

    net = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if the installed torch version supports it.
    weights = torch.load(checkpoint, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
def load_model_flow_standard(ENV, device):
    """Load the standard flow-matching ``ConditionalUNet`` in eval mode.

    Args:
        ENV: When equal to ``"DEPLOY"`` the checkpoint is fetched with
            ``hf_hub_download`` (cached under ``models``); otherwise the
            local relative path is used as is.
        device: Passed to ``.to(...)`` and used as ``map_location``.

    Returns:
        The ``ConditionalUNet`` with the checkpoint's state dict loaded.
    """
    checkpoint = "outputs/flow_matching/flow_model.pth"
    if ENV == "DEPLOY":
        # Same relative name on the Hub as on the local disk.
        checkpoint = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint,
            cache_dir="models",
        )
    # Printed in both modes, even when nothing was downloaded.
    print("Flow Downloaded!")

    net = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if the installed torch version supports it.
    weights = torch.load(checkpoint, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
def load_model_flow_localized(ENV, device):
    """Load the localized-noise flow-matching ``ConditionalUNet`` in eval mode.

    Args:
        ENV: When equal to ``"DEPLOY"`` the checkpoint is fetched with
            ``hf_hub_download`` (cached under ``models``); otherwise the
            local relative path is used as is.
        device: Passed to ``.to(...)`` and used as ``map_location``.

    Returns:
        The ``ConditionalUNet`` with the checkpoint's state dict loaded.
    """
    checkpoint = "outputs/flow_matching/flow_model_localized_noise.pth"
    if ENV == "DEPLOY":
        # Same relative name on the Hub as on the local disk.
        checkpoint = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint,
            cache_dir="models",
        )
    # Printed in both modes, even when nothing was downloaded.
    print("Flow Downloaded!")

    net = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if the installed torch version supports it.
    weights = torch.load(checkpoint, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net
def resize(image, size=(200, 200)):
    """Return ``image`` rescaled to ``size`` with linear interpolation.

    Args:
        image: Array accepted by ``cv2.resize``.
        size: Target size forwarded as OpenCV's ``dsize`` argument, which
            OpenCV interprets as ``(width, height)``. Defaults to
            ``(200, 200)``.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
def plot_diff(outputs, x, t, noise_pred):
    """Store visualisation frames for diffusion timestep ``t`` in ``outputs``.

    Slot layout written by this function:
      * 6-11: colour-mapped noise prediction at t = 499, 399, 299, 199, 99, 0.
      * 1-5:  sample snapshot (default ``resize`` size) at t = 400, 300, 200,
              100, 1.
      * 12:   current sample at 300x300, refreshed on every call.

    Only element ``[0, 0]`` of ``x`` and ``noise_pred`` is rendered
    (presumably the first image / first channel of a batch - confirm with
    the caller).

    Args:
        outputs: Indexable container that receives the frames; mutated in
            place.
        x: Current sample tensor; mapped through ``(x + 1) / 2`` before
            display.
        t: Current timestep.
        noise_pred: Predicted-noise tensor for this step.

    Returns:
        The same ``outputs`` object.
    """
    if t in (499, 399, 299, 199, 99, 0):
        noise_slot = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]

        # Absolute value capped at 3, then min-max scaled; the epsilon keeps
        # the division finite when the map is constant.
        magnitude = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
        lowest = magnitude.min()
        highest = magnitude.max()
        magnitude = (magnitude - lowest) / (highest - lowest + 1e-5)

        # The colormap returns RGBA; keep RGB and convert to 8-bit.
        rgba = plt.get_cmap("coolwarm")(magnitude)
        heatmap = (rgba[:, :, :3] * 255).astype(np.uint8)
        outputs[noise_slot] = resize(heatmap)

    # Large preview of the current sample, updated on every call.
    sample = ((x + 1) / 2.0)[0, 0].cpu().numpy()
    outputs[12] = resize(sample, (300, 300))

    if t in (400, 300, 200, 100, 1, 0):
        sample_slot = {400: 1, 300: 2, 200: 3, 100: 4, 1: 5, 0: 12}[t]
        # t == 0 targets the large preview slot, so it keeps the 300x300 size.
        outputs[sample_slot] = resize(sample, (300, 300)) if t == 0 else resize(sample)

    return outputs
def plot_flow(outputs, i, x, dt, v):
    """Store visualisation frames for flow-matching step ``i`` in ``outputs``.

    Slot layout written by this function:
      * 12:   current sample at 300x300, refreshed on every call.
      * 1-5:  sample snapshot (default ``resize`` size) at i = 10, 20, 30,
              40, 48.
      * 6-11: colour-mapped ``dt``-scaled velocity at i = 0, 11, 21, 31, 41,
              49.

    Only element ``[0, 0]`` of ``x`` and ``v`` is rendered (presumably the
    first image / first channel of a batch - confirm with the caller).

    Args:
        outputs: Indexable container that receives the frames; mutated in
            place.
        i: Current integration step index.
        x: Current sample tensor; mapped through ``(x + 1) / 2`` and clamped
            to ``[0, 1]`` before display.
        dt: Factor multiplied into the velocity magnitude.
        v: Velocity tensor for this step.

    Returns:
        The same ``outputs`` object.
    """
    # Large preview of the current sample, updated on every call.
    sample = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
    outputs[12] = resize(sample, (300, 300))

    if i in (10, 20, 30, 40, 48, 49):
        sample_slot = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5, 49: 12}[i]
        # i == 49 targets the large preview slot, so it keeps the 300x300 size.
        outputs[sample_slot] = resize(sample, (300, 300)) if i == 49 else resize(sample)

    if i in (0, 11, 21, 31, 41, 49):
        # Absolute velocity capped at 3 and scaled by dt, then min-max
        # scaled; the epsilon keeps the division finite for a constant map.
        magnitude = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()
        lowest = magnitude.min()
        highest = magnitude.max()
        magnitude = (magnitude - lowest) / (highest - lowest + 1e-5)

        # The colormap returns RGBA; keep RGB and convert to 8-bit.
        rgba = plt.get_cmap("coolwarm")(magnitude)
        heatmap = (rgba[:, :, :3] * 255).astype(np.uint8)

        velocity_slot = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}[i]
        outputs[velocity_slot] = resize(heatmap)

    return outputs