lionelgarnier committed on
Commit
6894e88
·
1 Parent(s): 08f5d28

test mistral + flux

Browse files
Files changed (1) hide show
  1. app.py +29 -7
app.py CHANGED
@@ -4,9 +4,7 @@ import random
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
7
- from huggingface_hub import InferenceClient
8
-
9
- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
10
 
11
 
12
  dtype = torch.bfloat16
@@ -17,6 +15,19 @@ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", tor
17
  MAX_SEED = np.iinfo(np.int32).max
18
  MAX_IMAGE_SIZE = 2048
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  @spaces.GPU()
21
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
22
  if randomize_seed:
@@ -62,7 +73,7 @@ with gr.Blocks(css=css) as demo:
62
  container=False,
63
  )
64
 
65
- run_button = gr.Button("Run", scale=0)
66
 
67
  refined_prompt = gr.Text(
68
  label="Refined Prompt",
@@ -71,10 +82,13 @@ with gr.Blocks(css=css) as demo:
71
  placeholder="Prompt refined by Mistral",
72
  container=False
73
  )
 
 
 
74
 
75
  result = gr.Image(label="Result", show_label=False)
76
 
77
- with gr.Accordion("Advanced Settings istral", open=False):
78
  gr.Slider(
79
  label="Temperature",
80
  value=0.9,
@@ -161,10 +175,18 @@ with gr.Blocks(css=css) as demo:
161
  cache_examples="lazy"
162
  )
163
 
 
 
 
 
 
 
 
 
164
  gr.on(
165
- triggers=[run_button.click, prompt.submit],
166
  fn = infer,
167
- inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
168
  outputs = [result, seed]
169
  )
170
 
 
4
  import spaces
5
  import torch
6
  from diffusers import DiffusionPipeline
7
+ from transformers import pipeline
 
 
8
 
9
 
10
  dtype = torch.bfloat16
 
15
  MAX_SEED = np.iinfo(np.int32).max
16
  MAX_IMAGE_SIZE = 2048
17
 
18
+
19
+
20
+
21
# Cache the Mistral chat pipeline at module level: building it downloads and
# loads a 7B-parameter model, which is far too expensive to repeat per request.
_chatbot = None


def refine_prompt(prompt):
    """Expand a short product request into a detailed, Flux-ready description.

    Sends *prompt* to Mistral-7B-Instruct together with a product-designer
    system message and returns the assistant's generated text.

    Parameters
    ----------
    prompt : str
        The user's basic product request (e.g. "a castle schoolbag").

    Returns
    -------
    str
        The refined prompt text produced by the model.
    """
    global _chatbot
    if _chatbot is None:
        # Lazy one-time initialization; reused across all calls.
        _chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")

    messages = [
        {"role": "system", "content": "You are a product designer. You will get a basic prompt of product request and you need to imagine a new product design to satisfy that need. Produce an extended description of product front view that will be use by Flux to generate a visual"},
        # Bug fix: the user turn carries the caller's prompt. The original code
        # hard-coded the example "a castle schoolbag", threw that result away,
        # and then called the model again on the bare prompt WITHOUT the system
        # message — so the designer context was never applied.
        {"role": "user", "content": prompt},
    ]

    outputs = _chatbot(messages)
    # A chat-format text-generation pipeline returns
    # [{"generated_text": [<input messages..., assistant reply>]}]; the refined
    # prompt is the content of the final (assistant) message. Returning the
    # plain string (instead of the raw pipeline output) is what the
    # `refined_prompt` gr.Text component downstream actually needs.
    return outputs[0]["generated_text"][-1]["content"]
30
+
31
  @spaces.GPU()
32
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
33
  if randomize_seed:
 
73
  container=False,
74
  )
75
 
76
+ prompt_button = gr.Button("Refine prompt", scale=0)
77
 
78
  refined_prompt = gr.Text(
79
  label="Refined Prompt",
 
82
  placeholder="Prompt refined by Mistral",
83
  container=False
84
  )
85
+
86
+
87
+ run_button = gr.Button("Create visual", scale=0)
88
 
89
  result = gr.Image(label="Result", show_label=False)
90
 
91
+ with gr.Accordion("Advanced Settings Mistral", open=False):
92
  gr.Slider(
93
  label="Temperature",
94
  value=0.9,
 
175
  cache_examples="lazy"
176
  )
177
 
178
+
179
+ gr.on(
180
+ triggers=[prompt_button.click, prompt.submit],
181
+ fn = refine_prompt,
182
+ inputs = [prompt],
183
+ outputs = [refined_prompt]
184
+ )
185
+
186
  gr.on(
187
+ triggers=[run_button.click],
188
  fn = infer,
189
+ inputs = [refined_prompt, seed, randomize_seed, width, height, num_inference_steps],
190
  outputs = [result, seed]
191
  )
192