|
import cv2 |
|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from .model import ConditionalUNet |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
def load_models(ENV, device):
    """Load all three conditional UNets used by the demo.

    This is a thin wrapper over ``load_model_diff``, ``load_model_flow_standard``
    and ``load_model_flow_localized``. The repo id, checkpoint paths and
    loading steps for each model are defined once in those functions instead
    of being duplicated here.

    Args:
        ENV: when equal to ``"DEPLOY"`` the checkpoints are fetched with
            ``hf_hub_download`` (cached under ``models/``); any other value
            reads them from the local ``outputs/`` directory.
        device: device the models are moved to and that the weights are
            mapped onto via ``torch.load(..., map_location=device)``.

    Returns:
        Tuple ``(model_diff_standard, model_flow_standard, model_flow_localized)``,
        each a ``ConditionalUNet`` already switched to eval mode.
    """
    model_diff_standard = load_model_diff(ENV, device)
    model_flow_standard = load_model_flow_standard(ENV, device)
    model_flow_localized = load_model_flow_localized(ENV, device)
    return model_diff_standard, model_flow_standard, model_flow_localized
|
|
|
|
|
def load_model_diff(ENV, device):
    """Build the diffusion ``ConditionalUNet`` and load its trained weights.

    Args:
        ENV: when equal to ``"DEPLOY"`` the checkpoint is fetched from the
            Hugging Face Hub with ``hf_hub_download`` (cached in ``models/``);
            otherwise the local ``outputs/diffusion/`` file is used.
        device: device the model is moved to and ``map_location`` for the
            checkpoint tensors.

    Returns:
        The diffusion model in eval mode.
    """
    weights_file = "outputs/diffusion/diffusion_model.pth"
    if ENV == "DEPLOY":
        # The Hub repo mirrors the local layout, so the same relative
        # path doubles as the remote filename.
        weights_file = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=weights_file,
            cache_dir="models",
        )
    print("Diff Downloaded!")

    net = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the checkpoint, which is downloaded
    # data in DEPLOY mode. Consider weights_only=True (torch >= 1.13) unless
    # the repo is fully trusted.
    state = torch.load(weights_file, map_location=device)
    net.load_state_dict(state)
    net.eval()
    return net
|
|
|
def load_model_flow_standard(ENV, device):
    """Build the standard flow-matching ``ConditionalUNet`` and load its weights.

    Args:
        ENV: when equal to ``"DEPLOY"`` the checkpoint is fetched from the
            Hugging Face Hub with ``hf_hub_download`` (cached in ``models/``);
            otherwise the local ``outputs/flow_matching/`` file is used.
        device: device the model is moved to and ``map_location`` for the
            checkpoint tensors.

    Returns:
        The standard flow-matching model in eval mode.
    """
    checkpoint = "outputs/flow_matching/flow_model.pth"
    if ENV == "DEPLOY":
        # Remote filename matches the local relative path.
        checkpoint = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=checkpoint,
            cache_dir="models",
        )
    print("Flow Downloaded!")

    flow_net = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the checkpoint, which is downloaded
    # data in DEPLOY mode. Consider weights_only=True (torch >= 1.13).
    flow_net.load_state_dict(torch.load(checkpoint, map_location=device))
    flow_net.eval()
    return flow_net
|
|
|
def load_model_flow_localized(ENV, device):
    """Build the localized-noise flow-matching ``ConditionalUNet`` and load its weights.

    Args:
        ENV: when equal to ``"DEPLOY"`` the checkpoint is fetched from the
            Hugging Face Hub with ``hf_hub_download`` (cached in ``models/``);
            otherwise the local ``outputs/flow_matching/`` file is used.
        device: device the model is moved to and ``map_location`` for the
            checkpoint tensors.

    Returns:
        The localized-noise flow-matching model in eval mode.
    """
    ckpt_path = "outputs/flow_matching/flow_model_localized_noise.pth"
    if ENV == "DEPLOY":
        # Remote filename matches the local relative path.
        ckpt_path = hf_hub_download(
            repo_id="CristianLazoQuispe/MNIST_Diff_Flow_matching",
            filename=ckpt_path,
            cache_dir="models",
        )
    print("Flow Downloaded!")

    localized_net = ConditionalUNet().to(device)
    # NOTE(review): torch.load unpickles the checkpoint, which is downloaded
    # data in DEPLOY mode. Consider weights_only=True (torch >= 1.13).
    weights = torch.load(ckpt_path, map_location=device)
    localized_net.load_state_dict(weights)
    localized_net.eval()
    return localized_net
|
|
|
|
|
def resize(image, size=(200, 200)):
    """Resize ``image`` to ``size`` using bilinear interpolation.

    Args:
        image: array accepted by ``cv2.resize``.
        size: target size in OpenCV ``dsize`` order, i.e. ``(width, height)``.
            Defaults to 200x200.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    return cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
|
|
def plot_diff(outputs, x, t, noise_pred):
    """Record diffusion-sampling visualisations for timestep ``t`` into ``outputs``.

    Slot layout written by this function:
      * slots 6-11: colour-mapped |noise_pred| heatmaps at t = 499, 399, 299,
        199, 99, 0 (200x200, uint8 RGB);
      * slots 1-5: sample snapshots at t = 400, 300, 200, 100, 1 (200x200);
      * slot 12: the current sample at 300x300, refreshed on every call.

    Args:
        outputs: container indexed by the integer slots above. It presumably
            has at least 13 entries (list or dict); confirm against the caller.
        x: current sample tensor. It is assumed to be (batch, channel, H, W)
            in [-1, 1]; only ``[0, 0]`` is shown.
        t: integer timestep. It must be hashable and compare equal to the
            ints above to hit a slot.
        noise_pred: model noise prediction, indexed the same way as ``x``.

    Returns:
        ``outputs``, updated in place.
    """
    if t in (499, 399, 299, 199, 99, 0):
        noise_slot = {499: 6, 399: 7, 299: 8, 199: 9, 99: 10, 0: 11}[t]
        magnitude = noise_pred[0, 0].abs().clamp(0, 3).cpu().numpy()
        # Min-max normalise for the colormap; epsilon guards a flat map.
        magnitude = (magnitude - magnitude.min()) / (
            magnitude.max() - magnitude.min() + 1e-5
        )
        colored = plt.get_cmap("coolwarm")(magnitude)[:, :, :3]  # drop alpha
        outputs[noise_slot] = resize((colored * 255).astype(np.uint8))

    # Map [-1, 1] -> [0, 1] once. Clamp as plot_flow does, so that sampler
    # overshoots do not produce out-of-range float images.
    frame = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
    outputs[12] = resize(frame, (300, 300))

    # t == 0 is already covered by slot 12 above.
    if t in (400, 300, 200, 100, 1):
        outputs[{400: 1, 300: 2, 200: 3, 100: 4, 1: 5}[t]] = resize(frame)

    return outputs
|
|
|
def plot_flow(outputs, i, x, dt, v):
    """Record flow-matching visualisations for integration step ``i`` into ``outputs``.

    Slot layout written by this function:
      * slot 12: the current sample at 300x300, refreshed on every call;
      * slots 1-5: sample snapshots at i = 10, 20, 30, 40, 48 (200x200);
        i = 49 also targets slot 12;
      * slots 6-11: colour-mapped ``dt * |v|`` heatmaps at i = 0, 11, 21,
        31, 41, 49 (200x200, uint8 RGB).

    Args:
        outputs: container indexed by the integer slots above. It presumably
            has at least 13 entries; confirm against the caller.
        i: integer step index.
        x: current sample tensor. It is assumed to be (batch, channel, H, W)
            in [-1, 1]; only ``[0, 0]`` is shown.
        dt: step size that scales the velocity before normalisation.
        v: predicted velocity tensor, indexed the same way as ``x``.

    Returns:
        ``outputs``, updated in place.
    """
    frame = ((x + 1) / 2.0)[0, 0].clamp(0, 1).cpu().numpy()
    large_view = resize(frame, (300, 300))
    outputs[12] = large_view

    if i in (10, 20, 30, 40, 48, 49):
        image_slot = {10: 1, 20: 2, 30: 3, 40: 4, 48: 5, 49: 12}[i]
        # The final step reuses the large render; the others get 200x200.
        outputs[image_slot] = large_view if i == 49 else resize(frame)

    if i in (0, 11, 21, 31, 41, 49):
        speed = dt * v[0, 0].abs().clamp(0, 3).cpu().numpy()
        lo, hi = speed.min(), speed.max()
        speed = (speed - lo) / (hi - lo + 1e-5)
        rgb = plt.get_cmap("coolwarm")(speed)[:, :, :3]  # drop alpha channel
        velocity_slot = {0: 6, 11: 7, 21: 8, 31: 9, 41: 10, 49: 11}[i]
        outputs[velocity_slot] = resize((rgb * 255).astype(np.uint8))

    return outputs