Spaces:
Running
on
Zero
Running
on
Zero
Upload 2 files
Browse files
env.py
CHANGED
|
@@ -40,7 +40,8 @@ load_diffusers_format_model = [
|
|
| 40 |
'rubbrband/realcartoonRealistic_v14',
|
| 41 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
|
| 42 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
|
| 43 |
-
'
|
|
|
|
| 44 |
'Eugeoter/artiwaifu-diffusion-1.0',
|
| 45 |
'Raelina/Rae-Diffusion-XL-V2',
|
| 46 |
'Raelina/Raemu-XL-V4',
|
|
|
|
| 40 |
'rubbrband/realcartoonRealistic_v14',
|
| 41 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev2',
|
| 42 |
'KBlueLeaf/Kohaku-XL-Epsilon-rev3',
|
| 43 |
+
'KBlueLeaf/Kohaku-XL-Zeta',
|
| 44 |
+
'kayfahaarukku/UrangDiffusion-1.2',
|
| 45 |
'Eugeoter/artiwaifu-diffusion-1.0',
|
| 46 |
'Raelina/Rae-Diffusion-XL-V2',
|
| 47 |
'Raelina/Raemu-XL-V4',
|
tagger.py
CHANGED
|
@@ -12,10 +12,15 @@ from pathlib import Path
|
|
| 12 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
| 13 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
| 21 |
return (
|
|
@@ -506,7 +511,7 @@ def gen_prompt(rating: list[str], character: list[str], general: list[str]):
|
|
| 506 |
return ", ".join(all_tags)
|
| 507 |
|
| 508 |
|
| 509 |
-
@spaces.GPU()
|
| 510 |
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
| 511 |
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
| 512 |
|
|
@@ -514,9 +519,11 @@ def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_t
|
|
| 514 |
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
| 515 |
|
| 516 |
# get probabilities
|
|
|
|
| 517 |
results = {
|
| 518 |
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
| 519 |
}
|
|
|
|
| 520 |
# rating, character, general
|
| 521 |
rating, character, general = postprocess_results(
|
| 522 |
results, general_threshold, character_threshold
|
|
|
|
| 12 |
WD_MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
|
| 13 |
WD_MODEL_NAME = WD_MODEL_NAMES[0]
|
| 14 |
|
| 15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
+
default_device = device
|
|
|
|
| 17 |
|
| 18 |
+
try:
|
| 19 |
+
wd_model = AutoModelForImageClassification.from_pretrained(WD_MODEL_NAME, trust_remote_code=True).to(default_device).eval()
|
| 20 |
+
wd_processor = AutoImageProcessor.from_pretrained(WD_MODEL_NAME, trust_remote_code=True)
|
| 21 |
+
except Exception as e:
|
| 22 |
+
print(e)
|
| 23 |
+
wd_model = wd_processor = None
|
| 24 |
|
| 25 |
def _people_tag(noun: str, minimum: int = 1, maximum: int = 5):
|
| 26 |
return (
|
|
|
|
| 511 |
return ", ".join(all_tags)
|
| 512 |
|
| 513 |
|
| 514 |
+
@spaces.GPU(duration=30)
|
| 515 |
def predict_tags(image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8):
|
| 516 |
inputs = wd_processor.preprocess(image, return_tensors="pt")
|
| 517 |
|
|
|
|
| 519 |
logits = torch.sigmoid(outputs.logits[0]) # take the first logits
|
| 520 |
|
| 521 |
# get probabilities
|
| 522 |
+
if device != default_device: wd_model.to(device=device)
|
| 523 |
results = {
|
| 524 |
wd_model.config.id2label[i]: float(logit.float()) for i, logit in enumerate(logits)
|
| 525 |
}
|
| 526 |
+
if device != default_device: wd_model.to(device=default_device)
|
| 527 |
# rating, character, general
|
| 528 |
rating, character, general = postprocess_results(
|
| 529 |
results, general_threshold, character_threshold
|