panelforge commited on
Commit
1082e14
·
verified ·
1 Parent(s): 5f66166

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -15
app.py CHANGED
@@ -4,7 +4,9 @@ import random
4
  import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
- from tags import TAGS # Centralized dictionary
 
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
@@ -15,10 +17,15 @@ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
15
  MAX_SEED = np.iinfo(np.int32).max
16
  MAX_IMAGE_SIZE = 1024
17
 
18
- # Prepare keys for each tag category for UI and loop usage
19
- tag_categories = list(TAGS.keys()) # e.g. ["Participant", "Tribe", "Skin Tone", ...]
20
  tag_checkboxes = [gr.CheckboxGroup(choices=list(TAGS[k].keys()), label=f"{k} Tags") for k in tag_categories]
21
 
 
 
 
 
 
22
  @spaces.GPU
23
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
24
  guidance_scale, num_inference_steps, active_tab, *tag_selections,
@@ -28,8 +35,16 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
28
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
29
  else:
30
  combined_tags = []
31
- for (tag_name, tag_dict), selected in zip(TAGS.items(), tag_selections):
32
- combined_tags.extend([tag_dict[tag] for tag in selected])
 
 
 
 
 
 
 
 
33
  tag_string = ", ".join(combined_tags)
34
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_string}"
35
 
@@ -117,15 +132,21 @@ with gr.Blocks(css=css) as demo:
117
  tag_box.render()
118
  tag_tab.select(lambda: "Tag Selection", outputs=active_tab)
119
 
120
- run_button.click(
121
- fn=infer,
122
- inputs=[
123
- prompt, negative_prompt, seed, randomize_seed,
124
- width, height, guidance_scale, num_inference_steps,
125
- active_tab,
126
- *tag_checkboxes
127
- ],
128
- outputs=[result, seed, prompt_info]
129
- )
 
 
 
 
 
 
130
 
131
  demo.queue().launch()
 
4
  import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
+
8
+ from tags import TAGS
9
+ import tags_extra
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  MAX_IMAGE_SIZE = 1024
19
 
20
+ # Create checkbox groups for original tags
21
+ tag_categories = list(TAGS.keys())
22
  tag_checkboxes = [gr.CheckboxGroup(choices=list(TAGS[k].keys()), label=f"{k} Tags") for k in tag_categories]
23
 
24
+ # Create checkbox groups for extra tags
25
+ extra_tag_categories = list(tags_extra.TAGS_EXTRA.keys())
26
+ extra_tag_checkboxes = [gr.CheckboxGroup(choices=list(tags_extra.TAGS_EXTRA[k].keys()), label=f"{k} Tags (Extra)") for k in extra_tag_categories]
27
+
28
+
29
  @spaces.GPU
30
  def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
31
  guidance_scale, num_inference_steps, active_tab, *tag_selections,
 
35
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
36
  else:
37
  combined_tags = []
38
+
39
+ if active_tab == "Tag Selection":
40
+ for (tag_name, tag_dict), selected in zip(TAGS.items(), tag_selections[:len(TAGS)]):
41
+ combined_tags.extend([tag_dict[tag] for tag in selected])
42
+ elif active_tab == "Extra Tag Selection":
43
+ offset = len(TAGS)
44
+ for (tag_name, tag_dict), selected in zip(tags_extra.TAGS_EXTRA.items(),
45
+ tag_selections[offset:offset+len(tags_extra.TAGS_EXTRA)]):
46
+ combined_tags.extend([tag_dict[tag] for tag in selected])
47
+
48
  tag_string = ", ".join(combined_tags)
49
  final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_string}"
50
 
 
132
  tag_box.render()
133
  tag_tab.select(lambda: "Tag Selection", outputs=active_tab)
134
 
135
+ with gr.TabItem("Extra Tag Selection") as extra_tag_tab:
136
+ for tag_box in extra_tag_checkboxes:
137
+ tag_box.render()
138
+ extra_tag_tab.select(lambda: "Extra Tag Selection", outputs=active_tab)
139
+
140
+ run_button.click(
141
+ fn=infer,
142
+ inputs=[
143
+ prompt, negative_prompt, seed, randomize_seed,
144
+ width, height, guidance_scale, num_inference_steps,
145
+ active_tab,
146
+ *tag_checkboxes,
147
+ *extra_tag_checkboxes
148
+ ],
149
+ outputs=[result, seed, prompt_info]
150
+ )
151
 
152
  demo.queue().launch()