File size: 5,158 Bytes
8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 8247a04 7ddc847 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
from huggingface_hub import InferenceClient
from PIL import Image
import io
import config
class DiffusionInference:
    """
    Convenience wrapper around ``huggingface_hub.InferenceClient`` for
    text-to-image and image-to-image generation via the "hf-inference" provider.
    """

    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.

        Args:
            api_key (str, optional): Hugging Face API token. When omitted (or
                falsy), ``config.HF_TOKEN`` is used instead.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )

    @staticmethod
    def _build_params(base, negative_prompt, extra, reserved):
        """
        Build the keyword-argument dict for an ``InferenceClient`` call.

        Args:
            base (dict): Arguments that are always sent; copied, never mutated.
            negative_prompt (str, optional): Added only when not ``None``.
            extra (dict): Caller-supplied extra parameters.
            reserved (iterable of str): Keys in ``extra`` to drop so callers
                cannot override arguments this class sets explicitly.

        Returns:
            dict: A new dict containing ``base``, then ``negative_prompt`` (if
            given), then the non-reserved entries of ``extra``.
        """
        params = dict(base)
        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt
        params.update(
            (key, value) for key, value in extra.items() if key not in reserved
        )
        return params

    @staticmethod
    def _new_temp_png_path():
        """
        Create an empty, uniquely named ``.png`` file in the system temp
        directory and return its path. The caller owns (and must delete) it.
        """
        import os
        import tempfile

        # mkstemp picks a fresh name atomically, so concurrent calls cannot
        # overwrite or delete each other's file (a fixed name could).
        handle, path = tempfile.mkstemp(suffix=".png")
        os.close(handle)  # PIL reopens the file by path; don't leak the descriptor.
        return path

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference;
                defaults to ``config.DEFAULT_TEXT2IMG_MODEL``
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model. A ``model``
                key is ignored; use ``model_name`` instead.

        Returns:
            PIL.Image: The generated image

        Raises:
            Exception: Any error from ``InferenceClient.text_to_image`` is
                re-raised after the model and prompt are printed.
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL
        params = self._build_params(
            {"prompt": prompt, "model": model},
            negative_prompt,
            kwargs,
            reserved=("prompt", "model", "negative_prompt"),
        )
        try:
            # Every argument goes by keyword, independent of the client's
            # positional parameter order.
            return self.client.text_to_image(**params)
        except Exception as e:
            print(f"Error generating image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image, str or os.PathLike): Input image or path to image.
                Any other object is passed to ``PIL.Image.fromarray`` (e.g. a
                NumPy array); if that conversion fails a ``ValueError`` is raised.
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference;
                defaults to ``config.DEFAULT_IMG2IMG_MODEL``
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model. ``prompt``,
                ``model``, ``negative_prompt`` and ``image`` keys are ignored.

        Returns:
            PIL.Image: The generated image

        Raises:
            ValueError: If ``image`` is of an unsupported type.
            Exception: Any error from ``InferenceClient.image_to_image`` is
                re-raised after the model and prompt are printed.
        """
        import os

        model = model_name or config.DEFAULT_IMG2IMG_MODEL
        # Temporary PNG created by this call, if any; removed in ``finally``.
        temp_file = None
        try:
            if isinstance(image, (str, os.PathLike)):
                # Already a file on disk: hand the path straight to the client.
                image_path = os.fspath(image)
            else:
                if isinstance(image, Image.Image):
                    pil_image = image
                else:
                    # Anything else must be convertible by PIL.
                    try:
                        pil_image = Image.fromarray(image)
                    except Exception as e:
                        raise ValueError(
                            f"Unsupported image type: {type(image)}. Error: {e}"
                        ) from e
                # The client is given a file path, so serialize the image to a
                # private, uniquely named temp file first. temp_file is set
                # before saving so a failed save is still cleaned up.
                temp_file = self._new_temp_png_path()
                pil_image.save(temp_file, format="PNG")
                image_path = temp_file

            base = {"model": model}
            if prompt is not None:
                # The prompt MUST be passed by keyword, never positionally.
                base["prompt"] = prompt
            params = self._build_params(
                base,
                negative_prompt,
                kwargs,
                reserved=("prompt", "model", "negative_prompt", "image"),
            )
            # The image is the only positional argument; all others are keywords.
            return self.client.image_to_image(image_path, **params)
        except Exception as e:
            print(f"Error transforming image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise
        finally:
            # Best-effort cleanup: a failed delete only warns.
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {temp_file}: {e}")
|