import spaces  # keep first: ZeroGPU patches CUDA initialization before torch loads

import os
import random
import tempfile

import torch
from diffusers import AutoPipelineForText2Image
from huggingface_hub import InferenceClient
from PIL import Image

import config


class DiffusionInference:
    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )
        # Select the GPU only when one is actually available, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            seed (int, optional): Seed for reproducible generation
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        params = {
            "prompt": prompt,
        }

        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt

        # Forward any extra keyword arguments, skipping keys already handled above.
        for k, v in kwargs.items():
            if k not in ["prompt", "model", "negative_prompt"]:
                params[k] = v

        try:
            image = self.run_text_to_image_pipeline(model, seed, **params)
            return image
        except Exception as e:
            print(f"Error generating image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise
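
    # A hypothetical call sketch: the prompt and keyword arguments below are
    # illustrative, not values required by this class. Anything not consumed
    # above (e.g. num_inference_steps, guidance_scale) is forwarded to the
    # underlying diffusers pipeline call.
    #
    #     inference = DiffusionInference()
    #     image = inference.text_to_image(
    #         "A red fox in fresh snow",
    #         seed=123,
    #         num_inference_steps=25,
    #         guidance_scale=7.5,
    #     )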

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image or str): Input image or path to image
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        temp_file = None
        try:
            # Normalize the input to a file path the inference client can read.
            if isinstance(image, str):
                image_path = image
            elif isinstance(image, Image.Image):
                temp_dir = tempfile.gettempdir()
                temp_file = os.path.join(temp_dir, "temp_image.png")
                image.save(temp_file, format="PNG")
                image_path = temp_file
            else:
                # Fall back to treating the input as an array-like object.
                try:
                    pil_image = Image.fromarray(image)
                    temp_dir = tempfile.gettempdir()
                    temp_file = os.path.join(temp_dir, "temp_image.png")
                    pil_image.save(temp_file, format="PNG")
                    image_path = temp_file
                except Exception as e:
                    raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")

            # Reuse the client configured in __init__.
            client = self.client

            params = {}
            if model is not None:
                params["model"] = model
            if prompt is not None:
                params["prompt"] = prompt
            if negative_prompt is not None:
                params["negative_prompt"] = negative_prompt

            # Forward any extra keyword arguments, skipping keys already handled above.
            for k, v in kwargs.items():
                if v is not None and k not in ["image", "prompt", "model", "negative_prompt"]:
                    params[k] = v

            print("DEBUG: Calling image_to_image with:")
            print(f"- Image path: {image_path}")
            print(f"- Parameters: {params}")

            result = client.image_to_image(image_path, **params)
            return result

        except Exception as e:
            print(f"Error transforming image: {e}")
            print(f"Image type: {type(image)}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

        finally:
            # Remove the temporary file if we created one.
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {temp_file}: {e}")
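
    # A hypothetical call sketch: the file name and parameters are
    # illustrative. The input may be a path, a PIL.Image, or an array-like
    # object; extra kwargs are forwarded to the hosted endpoint.
    #
    #     result = inference.image_to_image(
    #         "input.png",
    #         prompt="Repaint this scene in watercolor",
    #         guidance_scale=7.5,
    #     )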

    @spaces.GPU
    def run_text_to_image_pipeline(self, model_name, seed, **kwargs):
        if seed is not None:
            try:
                generator = torch.Generator(device=self.device).manual_seed(seed)
            except (ValueError, TypeError):
                # The supplied seed was unusable; fall back to a random one.
                random_seed = random.randint(0, 3999999999)
                generator = torch.Generator(device=self.device).manual_seed(random_seed)
                print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
        else:
            random_seed = random.randint(0, 3999999999)
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
            print(f"Using random seed: {random_seed}")
        pipeline = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=torch.float16).to(self.device)
        # The generator seeds the sampler, so it is passed to the pipeline
        # call itself rather than to from_pretrained().
        image = pipeline(generator=generator, **kwargs).images[0]
        return image
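

if __name__ == "__main__":
    # Minimal usage sketch, assuming config.py defines HF_TOKEN,
    # DEFAULT_TEXT2IMG_MODEL, and DEFAULT_IMG2IMG_MODEL, and that a GPU (or a
    # Spaces ZeroGPU allocation) is available for the local pipeline. The
    # prompts and output file names are illustrative.
    inference = DiffusionInference()

    # Text-to-image with a fixed seed for reproducibility.
    generated = inference.text_to_image(
        "A watercolor painting of a lighthouse at dawn",
        seed=42,
    )
    generated.save("text2img_result.png")

    # Image-to-image: transform the image generated above.
    transformed = inference.image_to_image(
        generated,
        prompt="The same lighthouse at night under a starry sky",
    )
    transformed.save("img2img_result.png")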