File size: 2,307 Bytes
5b68e3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f79bb39
 
 
5b68e3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python

from __future__ import annotations

import argparse
import functools
import os
import pickle
import sys
import subprocess

import gradio as gr
import numpy as np
import torch
import torch.nn as nn

sys.path.append('.')
sys.path.append('./Time_TravelRephotography')
from utils import torch_helpers as th
from argparse import Namespace
from projector import (
    ProjectorArguments,
    main,
    create_generator,
    make_image,
)

# Input path handed to the projector's argument parser in image_create.
# NOTE(review): empty string — confirm the projector accepts it when only sampling.
input_path = ""
# Suffix selecting checkpoint/encoder/checkpoint_<suffix>.pt in image_create.
spectral_sensitivity = "b"

# Text rendered by the Gradio interface built in main().
TITLE = "Time-TravelRephotography"
DESCRIPTION = (
    "This is an unofficial demo for "
    "https://github.com/Time-Travel-Rephotography.\n"
)
ARTICLE = (
    '<center><img src="https://visitor-badge.glitch.me/badge?'
    'page_id=Time-TravelRephotography" alt="visitor badge"/></center>'
)

   
def image_create(seed: int, truncation_psi: float):
    """Sample one face from the StyleGAN2 generator and return it for Gradio.

    Args:
        seed: Seed for the latent vector, so the same value reproduces the
            same image. Gradio's ``Number`` input delivers a float, so it is
            converted with ``int`` before seeding.
        truncation_psi: Value of the "Truncation psi" slider.
            NOTE(review): not forwarded to the generator yet — the
            generator's truncation API lives in ``projector`` and is not
            visible here; wire it up once confirmed.

    Returns:
        The first element of ``make_image(img_out)`` divided by 255
        (presumably an HxWxC image scaled to [0, 1] — confirm against
        ``make_image``).
    """
    args = ProjectorArguments().parse(
        args=[str(input_path)],
        namespace=Namespace(
            encoder_ckpt=f"checkpoint/encoder/checkpoint_{spectral_sensitivity}.pt",
            log_visual_freq=1000,
        ),
    )
    device = th.device()
    generator = create_generator(
        "stylegan2-ffhq-config-f.pt",
        "feng2022/Time-TravelRephotography_stylegan2-ffhq-config-f",
        args,
        device,
    )
    # Draw the latent from a dedicated CPU RNG: the same seed gives the same
    # latent on every device, and the global RNG state is left untouched.
    rng = torch.Generator().manual_seed(int(seed))
    latent = torch.randn((1, 512), generator=rng).to(device)
    # Pure inference — no need to record an autograd graph.
    with torch.no_grad():
        img_out, _, _ = generator([latent])
    imgs_arr = make_image(img_out)
    return imgs_arr[0] / 255
    
def main():
    """Build the Gradio interface around ``image_create`` and launch it.

    NOTE(review): this definition shadows the ``main`` imported from
    ``projector`` at the top of the file; that import is unreachable under
    this name.
    """
    # Initialise CUDA eagerly only when a GPU is usable. Calling
    # torch.cuda.init() or `nvidia-smi` unconditionally raises on hosts
    # without CUDA / an NVIDIA driver, aborting the demo before it starts.
    if torch.cuda.is_available():
        torch.cuda.init()

    # NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component
    # namespaces (deprecated in 3.x, removed in 4.x); switch to
    # gr.Number / gr.Slider / gr.Image with `value=` if Gradio is upgraded.
    iface = gr.Interface(
        image_create,
        [
            gr.inputs.Number(default=0, label='Seed'),
            gr.inputs.Slider(
                0, 2, step=0.05, default=0.7, label='Truncation psi'),
        ],
        gr.outputs.Image(type='numpy', label='Output'),
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
    )
    iface.launch()


if __name__ == '__main__':
    main()