Spaces: Running on Zero
	| import gradio as gr | |
| import torch | |
| from diffusers import AutoPipelineForText2Image | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| from pathlib import Path | |
| import stone | |
| import requests | |
| import io | |
| import os | |
| from PIL import Image | |
| import spaces | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.colors import hex2color | |
# SDXL-Turbo text-to-image pipeline, loaded once at import time with the
# fp16 weight variant in half precision and moved to the GPU.
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
def getimgen(prompt):
    """Generate a single image for ``prompt`` with the SDXL-Turbo pipeline.

    The pipeline is called with 2 inference steps and ``guidance_scale=0.0``.

    Args:
        prompt: Text prompt for the image.

    Returns:
        The first image of the pipeline output.
    """
    output = pipeline_text2image(
        prompt=prompt,
        num_inference_steps=2,
        guidance_scale=0.0,
    )
    first_image = output.images[0]
    return first_image
# BLIP (large) image-captioning processor and model, loaded once at import
# time; the model is kept in half precision on the GPU.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
)
blip_model = blip_model.to("cuda")
def blip_caption_image(image, prefix):
    """Caption ``image`` with BLIP, passing ``prefix`` as the text prompt.

    Args:
        image: Image handed to ``blip_processor`` together with the text.
        prefix: Text prompt given to the processor (callers pass "photo of a ").

    Returns:
        The decoded caption of the first generated sequence, without
        special tokens.
    """
    batch = blip_processor(image, prefix, return_tensors="pt")
    # Move inputs to the GPU and match the model's float16 dtype.
    batch = batch.to("cuda", torch.float16)
    generated_ids = blip_model.generate(**batch)
    first_sequence = generated_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
def genderfromcaption(caption):
    """Classify a caption as "Man", "Woman" or "Unsure" by keyword.

    The caption is lower-cased and split on every non-letter character, so
    "Man", "man," and "man." all count as the word "man" (previously only an
    exact, lower-case, whitespace-delimited "man" matched). Singular and
    plural forms are recognised. Masculine keywords are checked first, as
    before, so a caption mentioning both a man and a woman yields "Man".

    Args:
        caption: Caption text, e.g. the output of ``blip_caption_image``.

    Returns:
        One of "Man", "Woman" or "Unsure".
    """
    # Turn punctuation/digits into separators so they never glue onto a word.
    cleaned = "".join(ch if ch.isalpha() else " " for ch in caption.lower())
    words = set(cleaned.split())
    if words & {"man", "men", "boy", "boys"}:
        return "Man"
    if words & {"woman", "women", "girl", "girls"}:
        return "Woman"
    return "Unsure"
def genderplot(genlist):
    """Draw gender labels as a 2x5 grid of coloured squares.

    Labels are grouped in the order Man, Woman, Unsure (any other label sorts
    last and is drawn light grey) and drawn row by row. Cells beyond the
    number of labels are left blank, so fewer than 10 labels no longer raise
    IndexError; labels beyond the 10th are not drawn.

    Args:
        genlist: Iterable of labels as produced by ``genderfromcaption``.

    Returns:
        A ``matplotlib.figure.Figure`` for ``gr.Plot``.
    """
    # Build the figure without pyplot: pyplot keeps every figure in a global
    # registry (never closed here, so memory grows per request) and its
    # "current figure" is shared by all threads serving requests.
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    rank = {label: position for position, label in enumerate(order)}
    words = sorted(genlist, key=lambda word: rank.get(word, len(order)))
    word_colors = [colors.get(word, "lightgrey") for word in words]

    fig = Figure(figsize=(5, 5))
    axes = fig.subplots(2, 5)
    # Reduced spacing between the grid cells.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(word_colors):
            ax.add_patch(Rectangle((0, 0), 1, 1, color=word_colors[i]))
    return fig
def skintoneplot(hex_codes):
    """Draw detected skin tones as a 2x5 grid of coloured squares.

    ``None`` entries (skin-tone detection failed for that image) are skipped
    instead of crashing ``hex2color``. The remaining colours are ordered by
    luma, lightest first, and drawn row by row; unused cells are left blank,
    so fewer than 10 colours no longer raise IndexError.

    Args:
        hex_codes: Iterable of colour strings accepted by ``hex2color``
            (e.g. "#a1b2c3") or ``None``.

    Returns:
        A ``matplotlib.figure.Figure`` for ``gr.Plot``.
    """
    # Build the figure without pyplot: pyplot keeps every figure in a global
    # registry (never closed here, so memory grows per request) and its
    # "current figure" is shared by all threads serving requests.
    from matplotlib.figure import Figure
    from matplotlib.patches import Rectangle

    valid_codes = [code for code in hex_codes if code is not None]
    # Rec. 601 luma weights applied to RGB components in [0, 1].
    luminance_values = [
        0.299 * r + 0.587 * g + 0.114 * b
        for r, g, b in (hex2color(code) for code in valid_codes)
    ]
    # reverse=True puts the highest luma first, i.e. light to dark.
    sorted_hex_codes = [
        code for _, code in sorted(zip(luminance_values, valid_codes), reverse=True)
    ]

    fig = Figure(figsize=(5, 5))
    axes = fig.subplots(2, 5)
    # Reduced spacing between the grid cells.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(sorted_hex_codes):
            ax.add_patch(Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
    return fig
def _dominant_skin_tone(image_path):
    """Return the first dominant colour of the first face stone reports, or None.

    Detection is best-effort: any exception from ``stone.process`` or from
    indexing its report (e.g. an empty ``faces`` list) yields ``None`` so one
    bad image does not abort the whole batch.
    """
    try:
        report = stone.process(str(image_path), return_report_image=False)
        return report["faces"][0]["dominant_colors"][0]["color"]
    except Exception as err:  # deliberate best-effort; not a bare except
        print(f"Skin tone detection failed for {image_path}: {err!r}")
        return None


@spaces.GPU  # required for CUDA work on ZeroGPU Spaces; no-op elsewhere
def generate_images_plots(prompt):
    """Generate 10 images for ``prompt`` and plot their skin tones and genders.

    Each image is captioned with BLIP (prefix "photo of a ") to infer a gender
    label and analysed with stone for a skin tone (``None`` on failure).
    Images are written to a per-call temporary directory, so concurrent
    requests cannot overwrite each other's files and nothing is left behind.

    Args:
        prompt: Text prompt for the image generator.

    Returns:
        Tuple ``(images, skin_tone_figure, gender_figure)``.
    """
    import tempfile

    images = [getimgen(prompt) for _ in range(10)]
    genders = []
    skintones = []
    with tempfile.TemporaryDirectory() as workdir:
        for i, image in enumerate(images):
            caption = blip_caption_image(image, prefix="photo of a ")
            genders.append(genderfromcaption(caption))
            # stone.process reads from disk, so the image must be saved first.
            image_path = Path(workdir) / f"image_{i}.png"
            image.save(str(image_path))
            skintones.append(_dominant_skin_tone(image_path))
    print(genders, skintones)
    return images, skintoneplot(skintones), genderplot(genders)
# Gradio UI: prompt box, 2x5 image gallery, generate button, and the two
# bias plots side by side. Components appear in the order they are created.
with gr.Blocks(title="Skin Tone and Gender bias in SDXL Demo - Inference API") as demo:
    gr.Markdown("# Skin Tone and Gender bias in SDXL Demo")
    prompt = gr.Textbox(label="Enter the Prompt")
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate images", scale=0)
    with gr.Row(equal_height=True):
        skinplot = gr.Plot(label="Skin Tone")
        genplot = gr.Plot(label="Gender")
    btn.click(
        fn=generate_images_plots,
        inputs=prompt,
        outputs=[gallery, skinplot, genplot],
    )

demo.launch(debug=True)