Nina.Konovalova commited on
Commit
ef5bd5d
·
1 Parent(s): f5adbed
Files changed (3) hide show
  1. app.py +245 -0
  2. inference_pb2.py +40 -0
  3. inference_pb2_grpc.py +97 -0
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+
4
+ import gradio as gr
5
+ import grpc
6
+ from PIL import Image
7
+ import pandas as pd
8
+
9
+ import numpy as np
10
+ import time
11
+
12
+ from io import BytesIO
13
+ from inference_pb2 import LoraRequest, LoraResponse
14
+ from inference_pb2_grpc import LoraServiceStub
15
+ import grpc
16
+
17
+
18
+
19
+ PREFIX = "/home/jovyan/bobkov/lora_demo/DEMO/"
20
+
21
+ info = {
22
+ 'image': PREFIX + 'demo/{0}/{0}.jpg',
23
+ 'weights_path': PREFIX + 'demo_results/flux-lora-{0}_aug-rank16',
24
+ 'caption': PREFIX + 'demo/{0}_aug/data.csv',
25
+ 'aug_path': PREFIX + 'demo/{0}_aug_filter/'
26
+ }
27
+
28
+ params = {
29
+ 'cup': {'switch_t' : 7, 'aug_image' : None, 'checkpoint' : 1000},
30
+ 'face_lifting': {'switch_t' : 7, 'aug_image' : None, 'checkpoint' : 1000},
31
+ 'coffe_machine': {'switch_t' : 7, 'aug_image' : None, 'checkpoint' : 1000},
32
+ 'kettle': {'switch_t' : 3, 'aug_image' : None, 'checkpoint' : 1000},
33
+ 'body_lotion': {'switch_t' : 7, 'aug_image' : None, 'checkpoint' : 1000},
34
+ 'toy': {'switch_t' : 3, 'aug_image' : None, 'checkpoint' : 1000},
35
+ 'bag': {'switch_t' : 3, 'aug_image' : None, 'checkpoint' : 1000},
36
+ 'armchair': {'switch_t' : 3, 'aug_image' : None, 'checkpoint' : 600},
37
+ 'pendant': {'switch_t' : -1, 'aug_image' : None, 'checkpoint' : 1000},
38
+ 'car': {'switch_t' : 7, 'aug_image' : 'car_aug_2.jpg', 'checkpoint' : 600},
39
+ }
40
+
41
+ table = pd.read_csv("/home/jovyan/bobkov/lora_demo/DEMO/demo/data.csv")
42
+ CAPTIONS = {}
43
+ for line in table.values:
44
+ CAPTIONS[line[0]] = line[1]
45
+
46
+
47
+ def bytes_to_image(image: bytes) -> Image.Image:
48
+ image = Image.open(BytesIO(image))
49
+ return image
50
+
51
+
52
+ def generate_image(concept, prompt, progress=gr.Progress(track_tqdm=True)):
53
+ with grpc.insecure_channel(os.environ["SERVER"]) as channel:
54
+ stub = LoraServiceStub(channel)
55
+
56
+ output = stub.generate(
57
+ LoraRequest(prompt=prompt, concept=concept, use_cache=False)
58
+ )
59
+
60
+ return gr.update(value=bytes_to_image(output.res1)), gr.update(), gr.update(), gr.update()
61
+
62
+
63
+ temaplte = """
64
+ <div style="font-size: 18px;">
65
+ <b>Product description:</b> {}
66
+ </div>
67
+ """
68
+
69
+
70
+ def action1():
71
+ concept = "kettle" #############
72
+ img = Image.open(info["image"].format(concept))
73
+ description = temaplte.format(CAPTIONS[concept])
74
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
75
+
76
+ def action2():
77
+ concept = "face_lifting" ###################
78
+ img = Image.open(info["image"].format(concept))
79
+ description = temaplte.format(CAPTIONS[concept])
80
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
81
+
82
+ def action3():
83
+ concept = "pendant" #############
84
+ img = Image.open(info["image"].format(concept))
85
+ description = temaplte.format(CAPTIONS[concept])
86
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
87
+
88
+ def action4():
89
+ concept = "car"
90
+ img = Image.open(info["image"].format(concept))
91
+ description = temaplte.format(CAPTIONS[concept])
92
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
93
+
94
+ def action5():
95
+ concept = "body_lotion" #################
96
+ img = Image.open(info["image"].format(concept))
97
+ description = temaplte.format(CAPTIONS[concept])
98
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
99
+
100
+ def action6():
101
+ concept = "toy" ############
102
+ img = Image.open(info["image"].format(concept))
103
+ description = temaplte.format(CAPTIONS[concept])
104
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
105
+
106
+ def action7():
107
+ concept = "bag" ##############
108
+ img = Image.open(info["image"].format(concept))
109
+ description = temaplte.format(CAPTIONS[concept])
110
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
111
+
112
+ def action8():
113
+ concept = "armchair" ############
114
+ img = Image.open(info["image"].format(concept))
115
+ description = temaplte.format(CAPTIONS[concept])
116
+ return gr.update(value=img, visible=True), gr.update(value=description, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value=concept)
117
+
118
+
119
+ css2 = """
120
+
121
+ .my-custom-button {
122
+ width: 100px; /* Button size */
123
+ height: 130px;
124
+ padding: 0; /* Remove default padding */
125
+ margin: 0px; /* Optional spacing between buttons */
126
+ display: flex;
127
+ align-items: center;
128
+ justify-content: center;
129
+ background-color: transparent;
130
+ border: none;
131
+ overflow: hidden; /* Ensures the image doesn't overflow */
132
+ --text-xl: 150px
133
+ }
134
+
135
+ .my-custom-button img {
136
+ max-width: 100%;
137
+ max-height: 100%;
138
+ object-fit: contain; /* Ensure icon scales properly */
139
+ }
140
+
141
+ .input_image_container {
142
+ width: 350px !important;
143
+ height: 350px !important;
144
+ overflow: hidden;
145
+ display: flex;
146
+ align-items: center;
147
+ justify-content: center;
148
+ background-color: #f0f0f0;
149
+ }
150
+ .input_image_container img {
151
+ max-width: 100%;
152
+ max-height: 100%;
153
+ width: 350px;
154
+ height: 350px;
155
+ object-fit: contain;
156
+ display: block;
157
+ margin: 0 auto;
158
+ }
159
+
160
+ .prompt input {
161
+ font-size: 20px;
162
+ }
163
+
164
+ .prompt input::placeholder {
165
+ font-size: 20px;
166
+ }
167
+
168
+ .prompt label {
169
+ font-size: 20px !important;
170
+ }
171
+ """
172
+
173
+ def get_demo():
174
+ with gr.Blocks(css="""
175
+ .centered {
176
+ display: flex;
177
+ justify-content: center;
178
+ align-items: center;
179
+ height: 100%;
180
+ }
181
+ .centered img {
182
+ margin: auto;
183
+ object-fit: contain;
184
+ }
185
+ """ + css2) as demo:
186
+ gr.Markdown("## Showcase Commercial Products with Stunning Natural Backgrounds")
187
+ with gr.Row():
188
+ with gr.Column(elem_classes=["centered"]):
189
+ with gr.Row():
190
+ btn1 = gr.Button("", icon=info["image"].format("kettle"), elem_classes=["my-custom-button"])
191
+ btn2 = gr.Button("", icon=info["image"].format("face_lifting"), elem_classes=["my-custom-button"])
192
+ btn3 = gr.Button("", icon=info["image"].format("pendant"), elem_classes=["my-custom-button"])
193
+ btn4 = gr.Button("", icon=info["image"].format("car"), elem_classes=["my-custom-button"])
194
+
195
+ with gr.Row():
196
+ btn5 = gr.Button("", icon=info["image"].format("body_lotion"), elem_classes=["my-custom-button"])
197
+ btn6 = gr.Button("", icon=info["image"].format("toy"), elem_classes=["my-custom-button"])
198
+ btn7 = gr.Button("", icon=info["image"].format("bag"), elem_classes=["my-custom-button"])
199
+ btn8 = gr.Button("", icon=info["image"].format("armchair"), elem_classes=["my-custom-button"])
200
+
201
+ prod_desc = gr.Markdown(value="""<div style="font-size: 20px;">Choose the product you want to showcase </div"> 🠕""", visible=True)
202
+
203
+ input_image = gr.Image(label="Chosen product", type="pil", height=300, width=300, visible=False, interactive=False, container=True, elem_classes=["input_image_container"])
204
+ descr = gr.Markdown(value=temaplte.format(""), visible=False)
205
+
206
+ concept = gr.Textbox("", visible=False)
207
+ prompt = gr.Textbox("", placeholder="is in the cozy kitchen", label="Describe the enviroment for your product", submit_btn=False, max_lines=1, visible=False, elem_classes=["prompt"])
208
+
209
+ btn_generate = gr.Button("Generate images", visible=False)
210
+
211
+
212
+ btn1.click(fn=action1, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
213
+ btn2.click(fn=action2, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
214
+ btn3.click(fn=action3, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
215
+ btn4.click(fn=action4, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
216
+ btn5.click(fn=action5, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
217
+ btn6.click(fn=action6, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
218
+ btn7.click(fn=action7, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
219
+ btn8.click(fn=action8, outputs=[input_image, descr, prod_desc, prompt, btn_generate, concept])
220
+
221
+ with gr.Column():
222
+ with gr.Row():
223
+ res1 = gr.Image(label="Result", visible=True)#, height=450, width=450)
224
+ res2 = gr.Image(label="Result 2", visible=False)
225
+ with gr.Row():
226
+ res3 = gr.Image(label="Result 2", visible=False, height=450, width=450)
227
+ res4 = gr.Image(label="Result 4", visible=False) #<div style="text-align: center; font-size: 18px;">
228
+ gr.Markdown('''
229
+ <div style="display: flex; align-items: center; gap: 10px; font-size: 20px; text-align: center; margin-left: 200px;">
230
+ <div>Made by FusionBrainLab, AIRI</div><img src="https://static.tildacdn.com/tild3633-6662-4437-a333-646631346335/Airinet.png" style="width: 70px; height: auto;">
231
+ </div>
232
+ ''')
233
+
234
+ btn_generate.click(
235
+ fn=generate_image,
236
+ inputs=[concept, prompt],
237
+ outputs=[res1, res2, res3, res4] # font-family: Arial, sans-serif;
238
+ )
239
+
240
+ return demo
241
+
242
+
243
+ if __name__ == '__main__':
244
+ demo = get_demo()
245
+ demo.launch(server_name="0.0.0.0", server_port=7860)
inference_pb2.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
4
+ # source: inference.proto
5
+ # Protobuf Python Version: 6.31.0
6
+ """Generated protocol buffer code."""
7
+ from google.protobuf import descriptor as _descriptor
8
+ from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
10
+ from google.protobuf import symbol_database as _symbol_database
11
+ from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 6,
15
+ 31,
16
+ 0,
17
+ '',
18
+ 'inference.proto'
19
+ )
20
+ # @@protoc_insertion_point(imports)
21
+
22
+ _sym_db = _symbol_database.Default()
23
+
24
+
25
+
26
+
27
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\"A\n\x0bLoraRequest\x12\x0f\n\x07\x63oncept\x18\x01 \x01(\t\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x11\n\tuse_cache\x18\x03 \x01(\x08\"F\n\x0cLoraResponse\x12\x0c\n\x04res1\x18\x01 \x01(\x0c\x12\x0c\n\x04res2\x18\x02 \x01(\x0c\x12\x0c\n\x04res3\x18\x03 \x01(\x0c\x12\x0c\n\x04res4\x18\x04 \x01(\x0c\x32J\n\x0bLoraService\x12;\n\x08generate\x12\x16.inference.LoraRequest\x1a\x17.inference.LoraResponseb\x06proto3')
28
+
29
+ _globals = globals()
30
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
31
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', _globals)
32
+ if not _descriptor._USE_C_DESCRIPTORS:
33
+ DESCRIPTOR._loaded_options = None
34
+ _globals['_LORAREQUEST']._serialized_start=30
35
+ _globals['_LORAREQUEST']._serialized_end=95
36
+ _globals['_LORARESPONSE']._serialized_start=97
37
+ _globals['_LORARESPONSE']._serialized_end=167
38
+ _globals['_LORASERVICE']._serialized_start=169
39
+ _globals['_LORASERVICE']._serialized_end=243
40
+ # @@protoc_insertion_point(module_scope)
inference_pb2_grpc.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
2
+ """Client and server classes corresponding to protobuf-defined services."""
3
+ import grpc
4
+ import warnings
5
+
6
+ import inference_pb2 as inference__pb2
7
+
8
+ GRPC_GENERATED_VERSION = '1.73.1'
9
+ GRPC_VERSION = grpc.__version__
10
+ _version_not_supported = False
11
+
12
+ try:
13
+ from grpc._utilities import first_version_is_lower
14
+ _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
15
+ except ImportError:
16
+ _version_not_supported = True
17
+
18
+ if _version_not_supported:
19
+ raise RuntimeError(
20
+ f'The grpc package installed is at version {GRPC_VERSION},'
21
+ + f' but the generated code in inference_pb2_grpc.py depends on'
22
+ + f' grpcio>={GRPC_GENERATED_VERSION}.'
23
+ + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
24
+ + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
25
+ )
26
+
27
+
28
+ class LoraServiceStub(object):
29
+ """Missing associated documentation comment in .proto file."""
30
+
31
+ def __init__(self, channel):
32
+ """Constructor.
33
+
34
+ Args:
35
+ channel: A grpc.Channel.
36
+ """
37
+ self.generate = channel.unary_unary(
38
+ '/inference.LoraService/generate',
39
+ request_serializer=inference__pb2.LoraRequest.SerializeToString,
40
+ response_deserializer=inference__pb2.LoraResponse.FromString,
41
+ _registered_method=True)
42
+
43
+
44
+ class LoraServiceServicer(object):
45
+ """Missing associated documentation comment in .proto file."""
46
+
47
+ def generate(self, request, context):
48
+ """Missing associated documentation comment in .proto file."""
49
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
50
+ context.set_details('Method not implemented!')
51
+ raise NotImplementedError('Method not implemented!')
52
+
53
+
54
+ def add_LoraServiceServicer_to_server(servicer, server):
55
+ rpc_method_handlers = {
56
+ 'generate': grpc.unary_unary_rpc_method_handler(
57
+ servicer.generate,
58
+ request_deserializer=inference__pb2.LoraRequest.FromString,
59
+ response_serializer=inference__pb2.LoraResponse.SerializeToString,
60
+ ),
61
+ }
62
+ generic_handler = grpc.method_handlers_generic_handler(
63
+ 'inference.LoraService', rpc_method_handlers)
64
+ server.add_generic_rpc_handlers((generic_handler,))
65
+ server.add_registered_method_handlers('inference.LoraService', rpc_method_handlers)
66
+
67
+
68
+ # This class is part of an EXPERIMENTAL API.
69
+ class LoraService(object):
70
+ """Missing associated documentation comment in .proto file."""
71
+
72
+ @staticmethod
73
+ def generate(request,
74
+ target,
75
+ options=(),
76
+ channel_credentials=None,
77
+ call_credentials=None,
78
+ insecure=False,
79
+ compression=None,
80
+ wait_for_ready=None,
81
+ timeout=None,
82
+ metadata=None):
83
+ return grpc.experimental.unary_unary(
84
+ request,
85
+ target,
86
+ '/inference.LoraService/generate',
87
+ inference__pb2.LoraRequest.SerializeToString,
88
+ inference__pb2.LoraResponse.FromString,
89
+ options,
90
+ channel_credentials,
91
+ insecure,
92
+ call_credentials,
93
+ compression,
94
+ wait_for_ready,
95
+ timeout,
96
+ metadata,
97
+ _registered_method=True)