import spaces
from huggingface_hub import InferenceClient
from PIL import Image
import io
import os
import random
import tempfile
import config
from diffusers import DiffusionPipeline, AutoPipelineForText2Image
import torch


class DiffusionInference:
    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            seed (int, optional): Seed for reproducible generation
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        # Create parameters dictionary for all keyword arguments
        params = {
            "prompt": prompt,
        }

        # Add negative prompt if provided
        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt

        # Add any other parameters, skipping keys that are handled explicitly
        for k, v in kwargs.items():
            if k not in ["prompt", "model", "negative_prompt"]:
                params[k] = v

        try:
            # Run the local diffusers pipeline with all parameters as kwargs
            image = self.run_text_to_image_pipeline(model, seed, **params)
            return image
        except Exception as e:
            print(f"Error generating image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image or str): Input image or path to image
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        # Create a temporary file for the image if it's a PIL Image
        temp_file = None
        try:
            # Handle different image input types
            if isinstance(image, str):
                # If it's already a file path, use it directly
                image_path = image
            elif isinstance(image, Image.Image):
                # If it's a PIL Image, save it to a temporary file
                temp_dir = tempfile.gettempdir()
                temp_file = os.path.join(temp_dir, "temp_image.png")
                image.save(temp_file, format="PNG")
                image_path = temp_file
            else:
                # If it's something else, try to convert it to a PIL Image first
                try:
                    pil_image = Image.fromarray(image)
                    temp_dir = tempfile.gettempdir()
                    temp_file = os.path.join(temp_dir, "temp_image.png")
                    pil_image.save(temp_file, format="PNG")
                    image_path = temp_file
                except Exception as e:
                    raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")

            # Create a NEW InferenceClient for this call to avoid any potential state issues
            client = InferenceClient(
                provider="hf-inference",
                api_key=self.api_key,
            )

            # Create the parameter dict with only the non-None values
            params = {}
            if model is not None:
                params["model"] = model
            if prompt is not None:
                params["prompt"] = prompt
            if negative_prompt is not None:
                params["negative_prompt"] = negative_prompt

            # Add additional kwargs, but filter out any that might create conflicts
            for k, v in kwargs.items():
                if v is not None and k not in ["image", "prompt", "model", "negative_prompt"]:
                    params[k] = v

            # Debug the parameters we're sending
            print("DEBUG: Calling image_to_image with:")
            print(f"- Image path: {image_path}")
            print(f"- Parameters: {params}")

            # Make the API call
            result = client.image_to_image(image_path, **params)
            return result
        except Exception as e:
            print(f"Error transforming image: {e}")
            print(f"Image type: {type(image)}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise
        finally:
            # Clean up the temporary file if it was created
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {temp_file}: {e}")

    @spaces.GPU
    def run_text_to_image_pipeline(self, model_name, seed, **kwargs):
        if seed is not None:
            try:
                # Convert the seed to an integer and seed the generator with it
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            except (ValueError, TypeError):
                # Fall back to a random seed if conversion fails
                random_seed = random.randint(0, 3999999999)  # upper bound fits in 32 bits
                generator = torch.Generator(device=self.device).manual_seed(random_seed)
                print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
        else:
            # Generate a random seed when none is provided
            random_seed = random.randint(0, 3999999999)  # upper bound fits in 32 bits
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
            print(f"Using random seed: {random_seed}")

        pipeline = AutoPipelineForText2Image.from_pretrained(
            model_name, torch_dtype=torch.float16
        ).to(self.device)
        # The generator is a call-time argument, not a from_pretrained argument
        image = pipeline(generator=generator, **kwargs).images[0]
        return image