panelforge committed on
Commit
83fcdda
·
verified ·
1 Parent(s): 5598721

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -38
app.py CHANGED
@@ -9,6 +9,14 @@ from tags_straight import TAGS_STRAIGHT
9
  from tags_lesbian import TAGS_LESBIAN
10
  from tags_gay import TAGS_GAY
11
 
 
 
 
 
 
 
 
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
14
 
@@ -29,44 +37,27 @@ gay_checkboxes, gay_categories = create_checkboxes(TAGS_GAY, "Gay")
29
 
30
 
31
  @spaces.GPU
32
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height,
33
- guidance_scale, num_inference_steps, active_tab, *tag_selections,
34
- progress=gr.Progress(track_tqdm=True)):
35
-
36
- negative_base = "worst quality, bad quality, jpeg artifacts, source_cartoon, 3d, (censor), monochrome, blurry, lowres, watermark"
37
- full_negative_prompt = f"{negative_base}, {negative_prompt}"
38
-
39
- if randomize_seed:
40
- seed = random.randint(0, MAX_SEED)
41
-
42
- generator = torch.Generator().manual_seed(seed)
43
-
44
- if active_tab == "Prompt Input":
45
- final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {prompt}"
46
-
47
- else:
48
- combined_tags = []
49
- # The tag_selections come in order: straight, lesbian, gay
50
- if active_tab == "Straight":
51
- # slice first len(straight_checkboxes) from tag_selections
52
- selected_sets = tag_selections[:len(straight_checkboxes)]
53
- for cat, selected in zip(straight_categories, selected_sets):
54
- combined_tags.extend([TAGS_STRAIGHT[cat][tag] for tag in selected])
55
-
56
- elif active_tab == "Lesbian":
57
- offset = len(straight_checkboxes)
58
- selected_sets = tag_selections[offset:offset + len(lesbian_checkboxes)]
59
- for cat, selected in zip(lesbian_categories, selected_sets):
60
- combined_tags.extend([TAGS_LESBIAN[cat][tag] for tag in selected])
61
-
62
- elif active_tab == "Gay":
63
- offset = len(straight_checkboxes) + len(lesbian_checkboxes)
64
- selected_sets = tag_selections[offset:offset + len(gay_checkboxes)]
65
- for cat, selected in zip(gay_categories, selected_sets):
66
- combined_tags.extend([TAGS_GAY[cat][tag] for tag in selected])
67
-
68
- tag_string = ", ".join(combined_tags)
69
- final_prompt = f"score_9, score_8_up, score_7_up, source_anime, {tag_string}"
70
 
71
  image = pipe(
72
  prompt=final_prompt,
 
9
  from tags_lesbian import TAGS_LESBIAN
10
  from tags_gay import TAGS_GAY
11
 
12
# Per-tab prompt prefixes. NOTE: values must NOT end with a separator —
# the caller joins them as f"{prefix}, {tag_string}", so a trailing ", "
# here would produce a double comma ("…source_anime, , tag") in the final
# prompt. All entries are kept separator-free for consistency with the
# "Prompt Input" entry and the .get() fallback default used downstream.
PROMPT_PREFIXES = {
    "Prompt Input": "score_9, score_8_up, score_7_up, source_anime",
    "Straight": "score_9, score_8_up, score_7_up, source_anime",
    "Lesbian": "score_9, score_8_up, score_7_up, source_anime",
    "Gay": "score_9, score_8_up, score_7_up, source_anime, yaoi",
    # Add more tabs if needed
}
19
+
20
  device = "cuda" if torch.cuda.is_available() else "cpu"
21
  torch_dtype = torch.float16 if device == "cuda" else torch.float32
22
 
 
37
 
38
 
39
  @spaces.GPU
40
+ prefix = PROMPT_PREFIXES.get(active_tab, "score_9, score_8_up, score_7_up, source_anime")
41
+
42
+ if active_tab == "Prompt Input":
43
+ final_prompt = f"{prefix}, {prompt}"
44
+ else:
45
+ combined_tags = []
46
+
47
+ if active_tab == "Straight":
48
+ for (tag_name, tag_dict), selected in zip(TAGS_STRAIGHT.items(), tag_selections[:len(TAGS_STRAIGHT)]):
49
+ combined_tags.extend([tag_dict[tag] for tag in selected])
50
+ elif active_tab == "Lesbian":
51
+ offset = len(TAGS_STRAIGHT)
52
+ for (tag_name, tag_dict), selected in zip(TAGS_LESBIAN.items(), tag_selections[offset:offset+len(TAGS_LESBIAN)]):
53
+ combined_tags.extend([tag_dict[tag] for tag in selected])
54
+ elif active_tab == "Gay":
55
+ offset = len(TAGS_STRAIGHT) + len(TAGS_LESBIAN)
56
+ for (tag_name, tag_dict), selected in zip(TAGS_GAY.items(), tag_selections[offset:offset+len(TAGS_GAY)]):
57
+ combined_tags.extend([tag_dict[tag] for tag in selected])
58
+
59
+ tag_string = ", ".join(combined_tags)
60
+ final_prompt = f"{prefix}, {tag_string}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  image = pipe(
63
  prompt=final_prompt,