File size: 5,599 Bytes
8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 879971e 8247a04 879971e 7ddc847 879971e 7ddc847 879971e 7ddc847 879971e 8247a04 879971e 8247a04 7ddc847 8247a04 879971e 7ddc847 8247a04 7ddc847 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from huggingface_hub import InferenceClient
from PIL import Image
import io
import config
class DiffusionInference:
    """Wrapper around ``huggingface_hub.InferenceClient`` for diffusion tasks.

    Requests are sent through the ``hf-inference`` provider. The API token and
    default model names are read from the ``config`` module when not supplied.
    """

    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.

        Args:
            api_key (str, optional): Hugging Face API token. Falls back to
                ``config.HF_TOKEN`` when not provided (or falsy).
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference;
                defaults to ``config.DEFAULT_TEXT2IMG_MODEL``
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters forwarded unchanged to
                ``InferenceClient.text_to_image``. A ``model`` key here is
                ignored; use ``model_name`` instead.

        Returns:
            PIL.Image: The generated image

        Raises:
            Exception: Any error from the inference client is re-raised after
                printing the model and prompt for diagnosis.
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        params = {"prompt": prompt, "model": model}
        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt
        # The explicit arguments above always win over same-named kwargs.
        params.update(
            {k: v for k, v in kwargs.items() if k not in ("prompt", "model", "negative_prompt")}
        )

        try:
            return self.client.text_to_image(**params)
        except Exception as e:
            print(f"Error generating image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

    @staticmethod
    def _prepare_image_input(image):
        """Convert ``image`` into an input accepted by ``InferenceClient.image_to_image``.

        Paths/URLs (``str`` or ``os.PathLike``) and raw bytes are passed through.
        PIL images and array-likes (anything ``Image.fromarray`` accepts) are
        encoded to PNG bytes in memory. Nothing is written to disk, so
        concurrent calls cannot clobber each other's input.

        Raises:
            ValueError: If ``image`` is none of the supported types.
        """
        import os

        if isinstance(image, (str, os.PathLike)):
            return image
        if isinstance(image, (bytes, bytearray)):
            return bytes(image)
        if isinstance(image, Image.Image):
            pil_image = image
        else:
            try:
                pil_image = Image.fromarray(image)
            except Exception as e:
                raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}") from e

        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        return buffer.getvalue()

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image, array-like, str, os.PathLike or bytes): Input image.
                A str/path is handed to the client as-is (local path or URL);
                bytes are sent as-is; PIL images and arrays are PNG-encoded.
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference;
                defaults to ``config.DEFAULT_IMG2IMG_MODEL``
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters forwarded to
                ``InferenceClient.image_to_image``. ``None`` values and the keys
                ``image``, ``prompt``, ``model``, ``negative_prompt`` are dropped.

        Returns:
            PIL.Image: The generated image

        Raises:
            ValueError: If ``image`` is of an unsupported type.
            Exception: Any error from the inference client is re-raised after
                printing diagnostic context.
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        try:
            # Encode in memory instead of via a fixed-name temp file: a shared
            # path is overwritten/deleted by concurrent requests.
            image_input = self._prepare_image_input(image)

            # A fresh client per call, kept from the original implementation as
            # a guard against any state carried over between requests.
            client = InferenceClient(
                provider="hf-inference",
                api_key=self.api_key,
            )

            # Only send parameters that were actually provided.
            params = {}
            if model is not None:
                params["model"] = model
            if prompt is not None:
                params["prompt"] = prompt
            if negative_prompt is not None:
                params["negative_prompt"] = negative_prompt
            params.update(
                {
                    k: v
                    for k, v in kwargs.items()
                    if v is not None and k not in ("image", "prompt", "model", "negative_prompt")
                }
            )

            if isinstance(image_input, bytes):
                image_desc = f"<{len(image_input)} bytes of PNG data>"
            else:
                image_desc = image_input
            print("DEBUG: Calling image_to_image with:")
            print(f"- Image: {image_desc}")
            print(f"- Parameters: {params}")

            return client.image_to_image(image_input, **params)
        except Exception as e:
            print(f"Error transforming image: {e}")
            print(f"Image type: {type(image)}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise
|