|
|
import os |
|
|
import argparse |
|
|
from PIL import Image |
|
|
from util import load_models, classify_image, segment_image, apply_mask |
|
|
|
|
|
# File extensions (compared case-insensitively) that are treated as images.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.webp')

# Classification labels that trigger a warning when no mask is available.
_SENSITIVE_CATEGORIES = ('porn', 'hentai')


def _extract_masks(segmentation_results):
    """Return the mask array of the first segmentation result, or ``None``.

    ``None`` is returned when the ``.masks.data`` attribute chain cannot be
    followed, i.e. when accessing it raises ``AttributeError``.
    """
    try:
        return segmentation_results[0].masks.data.cpu().numpy()
    except AttributeError:
        # NOTE(review): presumably the segmentation library leaves `.masks`
        # as None when nothing was detected -- confirm against its docs.
        return None


def process_images(input_folder, output_folder, pattern_image_path, head_image_path):
    """Classify, segment and mask every image file found in ``input_folder``.

    For each file whose name ends with one of ``_IMAGE_EXTENSIONS`` the image
    is classified, segmented, and ``apply_mask`` is called once per returned
    mask. The result is written to ``output_folder`` as
    ``processed_<original filename>``.

    Args:
        input_folder: Directory that is listed (non-recursively) for images.
        output_folder: Directory the processed images are written to. It is
            not created here; the caller must make sure it exists.
        pattern_image_path: Path of the pattern image handed to ``apply_mask``.
        head_image_path: Path of the head image; it is converted to RGBA
            before being handed to ``apply_mask``.

    Raises:
        OSError: Propagated from listing the folder or opening/saving images.
    """
    classification_model, segmentation_model = load_models()

    # Both overlay images are loaded once and reused for every input file.
    pattern_image = Image.open(pattern_image_path)
    head_image = Image.open(head_image_path).convert("RGBA")

    # Sorted so that the processing (and log) order is deterministic;
    # os.listdir() returns entries in arbitrary order.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith(_IMAGE_EXTENSIONS):
            continue

        image_path = os.path.join(input_folder, filename)

        # The context manager releases the underlying file handle after each
        # image instead of leaving it open for the rest of the batch.
        with Image.open(image_path) as image:
            category = classify_image(image, classification_model)
            category_name = classification_model.names[category[0]]

            masks = _extract_masks(segment_image(image, segmentation_model))
            if masks is None:
                if category_name in _SENSITIVE_CATEGORIES:
                    print(f"Warning: {filename} is classified as sensitive content, but no mask was found.")
                continue
            if len(masks) == 0:
                continue

            # NOTE(review): images without any mask produce no output file --
            # confirm this is the intended behavior for "clean" inputs.
            # copy() yields an image independent of the file being closed.
            image_with_fill = image.copy()
            for mask in masks:
                image_with_fill = apply_mask(image_with_fill, mask, pattern_image, head_image)

        # NOTE(review): the data is always encoded as PNG while the output
        # name keeps the input's extension (e.g. "processed_a.jpg").
        output_path = os.path.join(output_folder, f"processed_{filename}")
        image_with_fill.save(output_path, format="PNG")
        print(f"Processed and saved: {output_path}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command line, make sure the output
    # directory exists, then run the batch over the input directory.
    cli = argparse.ArgumentParser(description="Batch process images with YOLO models.")

    # (flags, keyword options) for every command-line argument, in the order
    # they are registered on the parser.
    argument_specs = (
        (
            ("input_folder",),
            {"type": str, "help": "Path to the input folder containing images."},
        ),
        (
            ("output_folder",),
            {"type": str, "help": "Path to the output folder to save processed images."},
        ),
        (
            ("--pattern_image",),
            {"type": str, "default": "assets/pattern.png", "help": "Path to the pattern image."},
        ),
        (
            ("--head_image",),
            {"type": str, "default": "assets/head.png", "help": "Path to the head image."},
        ),
    )
    for flags, settings in argument_specs:
        cli.add_argument(*flags, **settings)

    options = cli.parse_args()

    # exist_ok=True: an already existing output directory is not an error.
    os.makedirs(options.output_folder, exist_ok=True)

    process_images(
        options.input_folder,
        options.output_folder,
        options.pattern_image,
        options.head_image,
    )
|
|
|