Itcast's picture
add some diff
c8cc003
raw
history blame
1.19 kB
import os
import subprocess
import sys
import warnings


def _pip_install(requirement):
    """Best-effort ``pip install`` of *requirement* into the running interpreter.

    A failed install only emits a warning instead of aborting, like the
    original fire-and-forget ``os.system`` calls. The app then still starts
    with whatever version is already present in the environment.
    """
    # ``sys.executable -m pip`` targets the interpreter executing this script;
    # a bare ``pip`` on PATH may belong to a different Python installation.
    result = subprocess.run([sys.executable, "-m", "pip", "install", requirement])
    if result.returncode != 0:
        warnings.warn(
            f"pip install {requirement!r} failed with exit code {result.returncode}"
        )


# Installed one at a time (same order as before) so a failure for one pin
# does not prevent the others from being installed.
# NOTE(review): installing at import time is slow and fragile; declaring these
# pins in requirements.txt is the usual approach for a hosted demo.
for _requirement in ("tensorflow==2.3.0", "tensorflow_hub", "numpy==1.16.4"):
    _pip_install(_requirement)
import tensorflow as tf
# Have tensorflow_hub fetch models in their compressed (archive) form.
# It is set here so it is already in the environment when hub.load() runs below.
os.environ.update(TFHUB_MODEL_LOAD_FORMAT="COMPRESSED")
import numpy as np
import PIL.Image
import time
import functools
def tensor_to_image(tensor):
    """Convert a float image tensor to a ``PIL.Image.Image``.

    Values are scaled by 255 and stored as 8-bit pixels, so the input is
    expected to hold intensities in [0, 1]; anything outside that range is
    clipped rather than wrapped.

    Args:
        tensor: Array-like image (a NumPy array or an eager TensorFlow tensor).
            It may carry a leading batch axis of size 1, which is dropped.

    Returns:
        The image as a ``PIL.Image.Image``.

    Raises:
        ValueError: If the input has more than 3 dimensions and its leading
            (batch) axis does not have size 1.
    """
    pixels = np.asarray(tensor) * 255
    # Clip before the uint8 cast: values slightly above 1.0 (e.g. 1.02 -> 260.1)
    # or below 0 would otherwise wrap around or hit an undefined float->uint8
    # conversion, producing speckled pixels.
    pixels = np.clip(pixels, 0, 255).astype(np.uint8)
    if pixels.ndim > 3:
        # Only a single-image batch can become one picture. Raise explicitly
        # instead of using assert, which is stripped under ``python -O``.
        if pixels.shape[0] != 1:
            raise ValueError(
                f"expected a batch of exactly one image, got shape {pixels.shape}"
            )
        pixels = pixels[0]
    return PIL.Image.fromarray(pixels)
import tensorflow_hub as hub
# TF Hub handle for Magenta's arbitrary image stylization model (256 variant).
# It is resolved through tensorflow_hub at import time, so start-up needs either
# network access to tfhub.dev or an already-populated TF Hub cache.
_STYLE_MODEL_HANDLE = "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
hub_model = hub.load(_STYLE_MODEL_HANDLE)
import gradio as gr
def _to_model_input(image):
    """Convert an 8-bit image into a batched float32 tensor with values in [0, 1].

    The Magenta arbitrary-image-stylization model takes float32 images of
    shape (batch, height, width, 3) with values in [0, 1]. Feeding a raw
    uint8 image (values 0-255, no batch axis) does not match that signature.
    Assumes 8-bit channel values (0-255), which is what PIL RGB images provide.
    """
    pixels = np.asarray(image, dtype=np.float32)
    if pixels.ndim == 2:
        # Grayscale input: replicate the single channel so the model sees RGB.
        pixels = np.stack([pixels] * 3, axis=-1)
    # Keep exactly three channels (drops an alpha channel if one is present),
    # then rescale 0-255 -> 0-1.
    pixels = pixels[..., :3] / 255.0
    # Add the leading batch axis of size 1.
    return tf.constant(pixels[np.newaxis, ...])


def inference(content_image, style_image):
    """Stylize *content_image* with the style of *style_image*.

    Args:
        content_image: Image whose content is kept (from the Gradio UI; the
            interface requests PIL images).
        style_image: Image whose style is transferred onto the content.

    Returns:
        The stylized result as a ``PIL.Image.Image``.
    """
    # NOTE(review): the model's documentation recommends style images of about
    # 256x256; resizing style_image before this call may improve results.
    outputs = hub_model(_to_model_input(content_image), _to_model_input(style_image))
    # The first output is the stylized image batch of shape (1, H, W, 3).
    return tensor_to_image(outputs[0])
title = "TTT"

# Gradio front end: two image inputs (content, style) feed
# inference(content_image, style_image), and one image output shows the result.
# The inputs must be passed together as a list. Passed as separate positional
# arguments, the style input landed in the ``outputs`` slot and the real
# output component spilled into the next positional parameter of Interface.
# NOTE(review): gr.inputs / gr.outputs, enable_queue and a boolean
# allow_flagging are legacy Gradio APIs; confirm against the pinned version.
demo = gr.Interface(
    inference,
    inputs=[
        gr.inputs.Image(type="pil", label="content_image"),
        gr.inputs.Image(type="pil", label="style_image"),
    ],
    outputs=gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description="",
    enable_queue=True,
    allow_flagging=False,
)
demo.launch()