|
import io
import logging
import os
import tempfile

from huggingface_hub import InferenceClient
from PIL import Image

import config
|
|
|
|
|
class DiffusionInference:
    """
    Wrapper around ``huggingface_hub.InferenceClient`` (provider
    ``"hf-inference"``) for text-to-image and image-to-image generation.
    """

    # Used instead of print(): debug tracing is opt-in, and errors are logged
    # with their traceback before being re-raised.
    _logger = logging.getLogger(__name__)

    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.

        Args:
            api_key (str, optional): Hugging Face API token. When omitted
                (or falsy), ``config.HF_TOKEN`` is used instead.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference;
                defaults to ``config.DEFAULT_TEXT2IMG_MODEL``
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters forwarded to
                ``InferenceClient.text_to_image``. A ``model`` key here is
                ignored; use ``model_name`` instead.

        Returns:
            PIL.Image: The generated image

        Raises:
            Exception: Any error from the inference call is logged and
                re-raised unchanged.
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        params = {"prompt": prompt, "model": model}
        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt
        # Explicit arguments take precedence over same-named entries in kwargs.
        for key, value in kwargs.items():
            if key not in ("prompt", "model", "negative_prompt"):
                params[key] = value

        try:
            return self.client.text_to_image(**params)
        except Exception:
            self._logger.exception(
                "Error generating image (model=%s, prompt=%r)", model, prompt
            )
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image, str, or array-like): Input image. A ``str`` is
                handed to the client unchanged (e.g. a file path). A PIL image,
                or anything ``PIL.Image.fromarray`` accepts, is written to a
                uniquely named temporary PNG that is deleted after the call.
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference;
                defaults to ``config.DEFAULT_IMG2IMG_MODEL``
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters forwarded to
                ``InferenceClient.image_to_image``. ``None`` values and the
                keys ``image``/``prompt``/``model``/``negative_prompt`` are
                dropped.

        Returns:
            PIL.Image: The generated image

        Raises:
            ValueError: If ``image`` is neither a ``str`` nor a PIL image and
                ``PIL.Image.fromarray`` cannot convert it.
            Exception: Any other error (including from the inference call) is
                logged and re-raised unchanged.
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        temp_file = None
        try:
            image_path, temp_file = self._resolve_image_input(image)

            # Forward only the arguments actually supplied so that client /
            # server defaults apply to everything else.
            explicit = {"model": model, "prompt": prompt, "negative_prompt": negative_prompt}
            params = {key: value for key, value in explicit.items() if value is not None}
            for key, value in kwargs.items():
                if value is not None and key not in ("image", "prompt", "model", "negative_prompt"):
                    params[key] = value

            self._logger.debug(
                "Calling image_to_image with image path %s and parameters %s",
                image_path,
                params,
            )
            # Reuse the client built in __init__ (same provider and API key)
            # instead of constructing a new one on every call.
            return self.client.image_to_image(image_path, **params)
        except Exception:
            self._logger.exception(
                "Error transforming image (image type=%s, model=%s, prompt=%r)",
                type(image),
                model,
                prompt,
            )
            raise
        finally:
            if temp_file is not None:
                self._remove_temp_file(temp_file)

    @classmethod
    def _resolve_image_input(cls, image):
        """
        Turn ``image`` into something to pass to the client.

        Returns:
            tuple: ``(image_path, temp_path)``. ``temp_path`` is the temporary
            file the caller must delete, or ``None`` if no file was written.

        Raises:
            ValueError: If ``image`` cannot be converted to a PIL image.
        """
        if isinstance(image, str):
            return image, None

        if isinstance(image, Image.Image):
            pil_image = image
        else:
            # Only the conversion is wrapped, so I/O errors while saving are
            # not misreported as an unsupported input type.
            try:
                pil_image = Image.fromarray(image)
            except Exception as e:
                raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}") from e

        temp_path = cls._write_temp_png(pil_image)
        return temp_path, temp_path

    @classmethod
    def _write_temp_png(cls, pil_image):
        """Save ``pil_image`` as PNG to a new, uniquely named temp file and return its path."""
        # mkstemp atomically creates a file with a random name. Concurrent
        # calls (or other users of the shared temp dir) therefore cannot
        # overwrite or delete each other's input, as they could with a fixed
        # "temp_image.png".
        fd, path = tempfile.mkstemp(suffix=".png")
        os.close(fd)  # PIL reopens the path itself; close our handle first.
        try:
            pil_image.save(path, format="PNG")
        except Exception:
            # The caller never receives this path, so clean up here.
            cls._remove_temp_file(path)
            raise
        return path

    @classmethod
    def _remove_temp_file(cls, path):
        """Best-effort deletion of ``path``; failures are logged, never raised."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            cls._logger.warning("Could not delete temporary file %s: %s", path, e)
|
|