|
import os |
|
from io import BytesIO |
|
|
|
import gradio as gr |
|
import grpc |
|
from PIL import Image |
|
import pandas as pd |
|
|
|
import numpy as np |
|
import time |
|
|
|
from io import BytesIO |
|
from inference_pb2 import LoraRequest, LoraResponse |
|
from inference_pb2_grpc import LoraServiceStub |
|
import grpc |
|
|
|
|
|
|
|
# Root that every asset path template below is prefixed with.
PREFIX = "./"

# Path templates keyed by asset kind. The "{0}" slot is filled with a concept
# name through str.format, e.g. info["image"].format("kettle").
info = {
    kind: PREFIX + relative
    for kind, relative in (
        ("image", "preview/{0}.jpg"),
        ("weights_path", "demo_results/flux-lora-{0}_aug-rank16"),
        ("caption", "demo/{0}_aug/data.csv"),
        ("aug_path", "demo/{0}_aug_filter/"),
    )
}
|
|
|
# Per-concept settings keyed by concept name.
# NOTE(review): nothing in the visible code reads this table; the meaning of
# switch_t / aug_image / checkpoint is presumably defined by the generation
# backend - confirm before changing values. Keys (including the spelling
# "coffe_machine") are runtime data and are kept exactly as they were.
params = {
    "cup": dict(switch_t=7, aug_image=None, checkpoint=1000),
    "face_lifting": dict(switch_t=7, aug_image=None, checkpoint=1000),
    "coffe_machine": dict(switch_t=7, aug_image=None, checkpoint=1000),
    "kettle": dict(switch_t=3, aug_image=None, checkpoint=1000),
    "body_lotion": dict(switch_t=7, aug_image=None, checkpoint=1000),
    "toy": dict(switch_t=3, aug_image=None, checkpoint=1000),
    "bag": dict(switch_t=3, aug_image=None, checkpoint=1000),
    "armchair": dict(switch_t=3, aug_image=None, checkpoint=600),
    "pendant": dict(switch_t=-1, aug_image=None, checkpoint=1000),
    "car": dict(switch_t=7, aug_image="car_aug_2.jpg", checkpoint=600),
}
|
|
|
# Read the caption table once at import time. The first column is used as the
# lookup key and the second as the text shown as the product description.
table = pd.read_csv("./data.csv")

CAPTIONS = {}
for row in table.values:
    # Each loaded row is echoed to stdout, exactly as before.
    print(row)
    CAPTIONS[row[0]] = row[1]
|
|
|
|
|
def bytes_to_image(image: bytes) -> Image.Image:
    """Open an encoded image held in memory as a PIL image.

    Args:
        image: Raw bytes of an encoded image.

    Returns:
        The object returned by ``Image.open`` on an in-memory buffer that
        wraps ``image``.
    """
    buffer = BytesIO(image)
    return Image.open(buffer)
|
|
|
|
|
def generate_image(concept, prompt, progress=gr.Progress(track_tqdm=True)):
    """Ask the generation service for an image and return UI updates.

    The gRPC server address is read from the ``SERVER`` environment variable.
    Configuration and transport failures are re-raised as ``gr.Error`` so the
    user sees a readable message instead of an opaque failure.

    Args:
        concept: Name of the selected product concept, forwarded to the server.
        prompt: User-entered description of the environment, forwarded as-is.
        progress: Gradio progress tracker. It is not used directly in the
            body; it is declared so Gradio can attach progress tracking.

    Returns:
        A 4-tuple of ``gr.update`` objects for the four result images. Only
        the first carries a value (the decoded ``res1`` field of the
        response); the other three are no-op updates.

    Raises:
        gr.Error: If ``SERVER`` is unset/empty or the gRPC call fails.
    """
    server = os.environ.get("SERVER")
    if not server:
        raise gr.Error(
            "The SERVER environment variable (address of the generation "
            "service) is not set."
        )

    request = LoraRequest(prompt=prompt, concept=concept, use_cache=False)
    try:
        with grpc.insecure_channel(server) as channel:
            output = LoraServiceStub(channel).generate(request)
    except grpc.RpcError as err:
        # Keep the original gRPC error chained for the server-side traceback.
        raise gr.Error(
            "Image generation failed: could not get a result from the "
            "generation service."
        ) from err

    # The response is fully received, so decoding can happen after the
    # channel has been closed.
    return (
        gr.update(value=bytes_to_image(output.res1)),
        gr.update(),
        gr.update(),
        gr.update(),
    )
|
|
|
|
|
# HTML snippet shown as the product description; the "{}" slot receives the
# caption text via str.format.
# NOTE: the name is misspelled ("template") but is kept because other code in
# this file refers to it.
temaplte = (
    "\n"
    '<div style="font-size: 18px;">\n'
    "<b>Product description:</b> {}\n"
    "</div>\n"
)
|
|
|
|
|
def _select_product(concept, placeholder):
    """Build the UI updates that put one product into the "selected" state.

    Shared implementation behind ``action1`` .. ``action8``, which differ only
    in the concept name and the prompt placeholder.

    Args:
        concept: Concept name; used to locate the preview image through
            ``info["image"]`` and to look up the caption in ``CAPTIONS``.
        placeholder: Example text shown in the empty prompt box.

    Returns:
        A 6-tuple of ``gr.update`` objects, in the order the click handlers
        are wired to their outputs: input image (value set, shown),
        description (value set, shown), hint (hidden), prompt box (shown with
        the new placeholder), generate button (shown), hidden concept field
        (value set to ``concept``).

    Raises:
        KeyError: If ``concept`` has no entry in ``CAPTIONS``.
    """
    img = Image.open(info["image"].format(concept))
    description = temaplte.format(CAPTIONS[concept])
    return (
        gr.update(value=img, visible=True),
        gr.update(value=description, visible=True),
        gr.update(visible=False),
        gr.update(visible=True, placeholder=placeholder),
        gr.update(visible=True),
        gr.update(value=concept),
    )


def action1():
    """Select the "kettle" product."""
    return _select_product("kettle", "in a cozy kitchen")


def action2():
    """Select the "face_lifting" product."""
    return _select_product("face_lifting", "in a sunny bathroom with green plants")


def action3():
    """Select the "pendant" product."""
    return _select_product("pendant", "on a beautiful blonde woman")


def action4():
    """Select the "car" product."""
    return _select_product("car", "driving in a desert")


def action5():
    """Select the "body_lotion" product."""
    return _select_product("body_lotion", "in a cozy bathroom")


def action6():
    """Select the "toy" product."""
    return _select_product("toy", "in a cozy living room")


def action7():
    """Select the "bag" product."""
    return _select_product("bag", "in a cozy living room")


def action8():
    """Select the "armchair" product."""
    return _select_product("armchair", "in a cozy living room")
|
|
|
|
|
# Stylesheet for the demo widgets. It is appended to the inline CSS passed to
# gr.Blocks in get_demo(); the class names here match the elem_classes used
# there (my-custom-button, input_image_container, prompt, hint, airi).
css2 = """

.my-custom-button {
    width: 100px; /* Button size */
    height: 130px;
    padding: 0; /* Remove default padding */
    margin: 0px; /* Optional spacing between buttons */
    display: flex;
    align-items: center;
    justify-content: center;
    background-color: transparent;
    border: none;
    overflow: hidden; /* Ensures the image doesn't overflow */
    --text-xl: 150px
}

.my-custom-button img {
    max-width: 100%;
    max-height: 100%;
    object-fit: contain; /* Ensure icon scales properly */
}

.input_image_container {
    width: 350px !important;
    height: 350px !important;
    overflow: hidden;
    display: flex;
    align-items: center;
    justify-content: center;
    background-color: #f0f0f0;
}
.input_image_container img {
    max-width: 100%;
    max-height: 100%;
    width: 350px;
    height: 350px;
    object-fit: contain;
    display: block;
    margin: 0 auto;
}

.prompt input {
    font-size: 20px;
}

.prompt input::placeholder {
    font-size: 20px;
}

.prompt label {
    font-size: 20px !important;
}

.hint {
    font-size: 20px !important;
    text-align: center;
    --text-md: 20px;
}

.airi-content {
    display: inline-flex;
    align-items: center;
    gap: 10px;
    white-space: nowrap;
}
.airi {
    font-size: 20px;
    text-align: center;
    --text-md: 20px;
    width: 100%;
    padding: 10px 0;
}
.airi img {
    vertical-align: middle;
    max-height: 1.2em;
    margin-left: 5px;
}
"""
|
|
|
def get_demo():
    """Build the Gradio Blocks UI for the product-showcase demo.

    Layout: a title, then one row with two columns. The left column holds two
    rows of product buttons, a hint, and the initially hidden widgets for the
    chosen product (image, description, prompt box, generate button, hidden
    concept field). The right column holds the result images. A credit line
    follows the row.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    layout_css = (
        "\n"
        ".centered {\n"
        "    display: flex;\n"
        "    justify-content: center;\n"
        "    align-items: center;\n"
        "    height: 100%;\n"
        "}\n"
        ".centered img {\n"
        "    margin: auto;\n"
        "    object-fit: contain;\n"
        "}\n"
    )

    # One inner list per button row: (concept used for the icon, click handler).
    button_rows = [
        [("kettle", action1), ("face_lifting", action2),
         ("pendant", action3), ("car", action4)],
        [("body_lotion", action5), ("toy", action6),
         ("bag", action7), ("armchair", action8)],
    ]

    with gr.Blocks(css=layout_css + css2) as demo:
        gr.Markdown("## Showcase Commercial Products with Stunning Natural Backgrounds")
        with gr.Row():
            with gr.Column(elem_classes=["centered"]):
                product_buttons = []
                for row in button_rows:
                    with gr.Row():
                        for concept_name, handler in row:
                            button = gr.Button(
                                "",
                                icon=info["image"].format(concept_name),
                                elem_classes=["my-custom-button"],
                            )
                            product_buttons.append((button, handler))

                # NOTE(review): the trailing glyph was mis-encoded in the
                # source. U+1F446 (pointing up) is used because the buttons
                # render above this hint - confirm the intended emoji.
                prod_desc = gr.Markdown(
                    "Choose the product you want to showcase \U0001F446",
                    elem_classes=["hint"],
                )

                input_image = gr.Image(
                    label="Chosen product",
                    type="pil",
                    height=300,
                    width=300,
                    visible=False,
                    interactive=False,
                    container=True,
                    elem_classes=["input_image_container"],
                )
                descr = gr.Markdown(value=temaplte.format(""), visible=False)

                # Hidden field carrying the selected concept to generate_image.
                concept = gr.Textbox("", visible=False)
                prompt = gr.Textbox(
                    "",
                    placeholder="is in the cozy kitchen",
                    label="Describe the environment for your product",
                    submit_btn=False,
                    max_lines=1,
                    visible=False,
                    elem_classes=["prompt"],
                )

                btn_generate = gr.Button("Generate images", visible=False)

                # Order must match the 6-tuple returned by the action handlers.
                selection_outputs = [
                    input_image, descr, prod_desc, prompt, btn_generate, concept,
                ]
                for button, handler in product_buttons:
                    button.click(fn=handler, outputs=selection_outputs)

            with gr.Column():
                with gr.Row():
                    res1 = gr.Image(label="Result", visible=True)
                    res2 = gr.Image(label="Result 2", visible=False)
                with gr.Row():
                    res3 = gr.Image(label="Result 3", visible=False, height=450, width=450)
                    res4 = gr.Image(label="Result 4", visible=False)

        gr.Markdown(
            '<span class="airi-content">Made by FusionBrainLab, AIRI <img src="https://static.tildacdn.com/tild3633-6662-4437-a333-646631346335/Airinet.png" style="width: 70px; height: auto;"></span>',
            elem_classes=["airi"],
        )

        # generate_image returns one update per result image, in this order.
        btn_generate.click(
            fn=generate_image,
            inputs=[concept, prompt],
            outputs=[res1, res2, res3, res4],
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it on port 7860, listening on all interfaces.
    demo = get_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|