panelforge commited on
Commit
ab8a2ac
·
verified ·
1 Parent(s): e6ca68d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -35
app.py CHANGED
@@ -5,8 +5,9 @@ import torch
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
 
8
- from tags import TAGS
9
- import tags_extra
 
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
@@ -17,13 +18,14 @@ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
  MAX_SEED = np.iinfo(np.int32).max
18
  MAX_IMAGE_SIZE = 1024
19
 
20
- # Create checkbox groups for original tags
21
- tag_categories = list(TAGS.keys())
22
- tag_checkboxes = [gr.CheckboxGroup(choices=list(TAGS[k].keys()), label=f"{k} Tags") for k in tag_categories]
 
23
 
24
- # Create checkbox groups for extra tags
25
- extra_tag_categories = list(tags_extra.TAGS_EXTRA.keys())
26
- extra_tag_checkboxes = [gr.CheckboxGroup(choices=list(tags_extra.TAGS_EXTRA[k].keys()), label=f"{k} Tags (Extra)") for k in extra_tag_categories]
27
 
28
 
29
  @spaces.GPU
@@ -31,23 +33,6 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
31
  guidance_scale, num_inference_steps, active_tab, *tag_selections,
32
  progress=gr.Progress(track_tqdm=True)):
33
 
34
- if active_tab == "Prompt Input":
35
- final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
36
- else:
37
- combined_tags = []
38
-
39
- if active_tab == "Tag Selection":
40
- for (tag_name, tag_dict), selected in zip(TAGS.items(), tag_selections[:len(TAGS)]):
41
- combined_tags.extend([tag_dict[tag] for tag in selected])
42
- elif active_tab == "Extra Tag Selection":
43
- offset = len(TAGS)
44
- for (tag_name, tag_dict), selected in zip(tags_extra.TAGS_EXTRA.items(),
45
- tag_selections[offset:offset+len(tags_extra.TAGS_EXTRA)]):
46
- combined_tags.extend([tag_dict[tag] for tag in selected])
47
-
48
- tag_string = ", ".join(combined_tags)
49
- final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_string}"
50
-
51
  negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
52
  full_negative_prompt = f"{negative_base}, {negative_prompt}"
53
 
@@ -56,6 +41,33 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
56
 
57
  generator = torch.Generator().manual_seed(seed)
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  image = pipe(
60
  prompt=final_prompt,
61
  negative_prompt=full_negative_prompt,
@@ -127,15 +139,20 @@ with gr.Blocks(css=css) as demo:
127
  prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt")
128
  prompt_tab.select(lambda: "Prompt Input", outputs=active_tab)
129
 
130
- with gr.TabItem("Tag Selection") as tag_tab:
131
- for tag_box in tag_checkboxes:
132
- tag_box.render()
133
- tag_tab.select(lambda: "Tag Selection", outputs=active_tab)
 
 
 
 
 
134
 
135
- with gr.TabItem("Extra Tag Selection") as extra_tag_tab:
136
- for tag_box in extra_tag_checkboxes:
137
- tag_box.render()
138
- extra_tag_tab.select(lambda: "Extra Tag Selection", outputs=active_tab)
139
 
140
  run_button.click(
141
  fn=infer,
@@ -143,8 +160,9 @@ with gr.Blocks(css=css) as demo:
143
  prompt, negative_prompt, seed, randomize_seed,
144
  width, height, guidance_scale, num_inference_steps,
145
  active_tab,
146
- *tag_checkboxes,
147
- *extra_tag_checkboxes
 
148
  ],
149
  outputs=[result, seed, prompt_info]
150
  )
 
5
  import spaces
6
  from diffusers import DiffusionPipeline
7
 
8
+ from tags_straight import TAGS_STRAIGHT
9
+ from tags_lesbian import TAGS_LESBIAN
10
+ from tags_gay import TAGS_GAY
11
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
 
18
  MAX_SEED = np.iinfo(np.int32).max
19
  MAX_IMAGE_SIZE = 1024
20
 
21
+ # Create checkbox groups for each tag set
22
+ def create_checkboxes(tag_dict, suffix):
23
+ categories = list(tag_dict.keys())
24
+ return [gr.CheckboxGroup(choices=list(tag_dict[cat].keys()), label=f"{cat} Tags ({suffix})") for cat in categories], categories
25
 
26
+ straight_checkboxes, straight_categories = create_checkboxes(TAGS_STRAIGHT, "Straight")
27
+ lesbian_checkboxes, lesbian_categories = create_checkboxes(TAGS_LESBIAN, "Lesbian")
28
+ gay_checkboxes, gay_categories = create_checkboxes(TAGS_GAY, "Gay")
29
 
30
 
31
  @spaces.GPU
 
33
  guidance_scale, num_inference_steps, active_tab, *tag_selections,
34
  progress=gr.Progress(track_tqdm=True)):
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
37
  full_negative_prompt = f"{negative_base}, {negative_prompt}"
38
 
 
41
 
42
  generator = torch.Generator().manual_seed(seed)
43
 
44
+ if active_tab == "Prompt Input":
45
+ final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
46
+
47
+ else:
48
+ combined_tags = []
49
+ # The tag_selections come in order: straight, lesbian, gay
50
+ if active_tab == "Straight":
51
+ # slice first len(straight_checkboxes) from tag_selections
52
+ selected_sets = tag_selections[:len(straight_checkboxes)]
53
+ for cat, selected in zip(straight_categories, selected_sets):
54
+ combined_tags.extend([TAGS_STRAIGHT[cat][tag] for tag in selected])
55
+
56
+ elif active_tab == "Lesbian":
57
+ offset = len(straight_checkboxes)
58
+ selected_sets = tag_selections[offset:offset + len(lesbian_checkboxes)]
59
+ for cat, selected in zip(lesbian_categories, selected_sets):
60
+ combined_tags.extend([TAGS_LESBIAN[cat][tag] for tag in selected])
61
+
62
+ elif active_tab == "Gay":
63
+ offset = len(straight_checkboxes) + len(lesbian_checkboxes)
64
+ selected_sets = tag_selections[offset:offset + len(gay_checkboxes)]
65
+ for cat, selected in zip(gay_categories, selected_sets):
66
+ combined_tags.extend([TAGS_GAY[cat][tag] for tag in selected])
67
+
68
+ tag_string = ", ".join(combined_tags)
69
+ final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_string}"
70
+
71
  image = pipe(
72
  prompt=final_prompt,
73
  negative_prompt=full_negative_prompt,
 
139
  prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt")
140
  prompt_tab.select(lambda: "Prompt Input", outputs=active_tab)
141
 
142
+ with gr.TabItem("Straight") as straight_tab:
143
+ for cb in straight_checkboxes:
144
+ cb.render()
145
+ straight_tab.select(lambda: "Straight", outputs=active_tab)
146
+
147
+ with gr.TabItem("Lesbian") as lesbian_tab:
148
+ for cb in lesbian_checkboxes:
149
+ cb.render()
150
+ lesbian_tab.select(lambda: "Lesbian", outputs=active_tab)
151
 
152
+ with gr.TabItem("Gay") as gay_tab:
153
+ for cb in gay_checkboxes:
154
+ cb.render()
155
+ gay_tab.select(lambda: "Gay", outputs=active_tab)
156
 
157
  run_button.click(
158
  fn=infer,
 
160
  prompt, negative_prompt, seed, randomize_seed,
161
  width, height, guidance_scale, num_inference_steps,
162
  active_tab,
163
+ *straight_checkboxes,
164
+ *lesbian_checkboxes,
165
+ *gay_checkboxes
166
  ],
167
  outputs=[result, seed, prompt_info]
168
  )