|
import io
import os

from huggingface_hub import InferenceClient
from PIL import Image

import config
|
|
|
|
|
class DiffusionInference:
    """Convenience wrapper around ``InferenceClient`` for diffusion tasks.

    Exposes text-to-image and image-to-image generation, filling in the
    default model names and API token from the ``config`` module when the
    caller does not supply them.
    """

    def __init__(self, api_key=None):
        """Initialize the inference client with the Hugging Face API token.

        Args:
            api_key (str, optional): Token to authenticate with. When falsy,
                ``config.HF_TOKEN`` is used instead.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, **kwargs):
        """Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation.
            model_name (str, optional): The model to use for inference.
                Defaults to ``config.DEFAULT_TEXT2IMG_MODEL``.
            negative_prompt (str, optional): What not to include in the image.
                Omitted from the request when falsy.
            **kwargs: Additional parameters forwarded to the client call.
                They are applied last, so they take precedence over the
                named arguments above.

        Returns:
            The value returned by ``self.client.text_to_image`` (documented
            by the original author as a ``PIL.Image``).

        Raises:
            Exception: Any error raised by the client is printed and then
                re-raised unchanged.
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        params = {"prompt": prompt}
        if negative_prompt:
            params["negative_prompt"] = negative_prompt
        params.update(kwargs)

        try:
            return self.client.text_to_image(model=model, **params)
        except Exception as e:
            print(f"Error generating image: {e}")
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image, str or os.PathLike): Input image, or a
                filesystem path to one. Paths are opened with PIL before the
                request is made.
            prompt (str, optional): Text prompt to guide the transformation.
                Omitted from the request when falsy.
            model_name (str, optional): The model to use for inference.
                Defaults to ``config.DEFAULT_IMG2IMG_MODEL``.
            negative_prompt (str, optional): What not to include in the image.
                Omitted from the request when falsy.
            **kwargs: Additional parameters forwarded to the client call.
                They are applied last, so they take precedence over the
                named arguments above.

        Returns:
            The value returned by ``self.client.image_to_image`` (documented
            by the original author as a ``PIL.Image``).

        Raises:
            Exception: Errors from opening the image file propagate as-is;
                errors raised by the client are printed and then re-raised.
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        if isinstance(image, (str, os.PathLike)):
            # Image.open() is lazy and keeps the file handle open. Force the
            # pixel data into memory and let the context manager release the
            # handle so repeated calls do not leak file descriptors. Only the
            # first frame of a multi-frame file stays available afterwards.
            with Image.open(image) as opened:
                opened.load()
            image = opened

        params = {"image": image}
        if prompt:
            params["prompt"] = prompt
        if negative_prompt:
            params["negative_prompt"] = negative_prompt
        params.update(kwargs)

        try:
            return self.client.image_to_image(model=model, **params)
        except Exception as e:
            print(f"Error transforming image: {e}")
            raise
|
|