imagencpu / inference.py
from huggingface_hub import InferenceClient
from PIL import Image
import io

import config


class DiffusionInference:
    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        # Set up the parameters dictionary
        params = {"prompt": prompt}
        if negative_prompt:
            params["negative_prompt"] = negative_prompt

        # Add any additional parameters
        params.update(kwargs)

        try:
            image = self.client.text_to_image(model=model, **params)
            return image
        except Exception as e:
            print(f"Error generating image: {e}")
            raise

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image or str): Input image or path to image
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        # Convert an image path to a PIL Image if needed
        if isinstance(image, str):
            image = Image.open(image)

        # Set up the parameters dictionary
        params = {"image": image}
        if prompt:
            params["prompt"] = prompt
        if negative_prompt:
            params["negative_prompt"] = negative_prompt

        # Add any additional parameters
        params.update(kwargs)

        try:
            result = self.client.image_to_image(model=model, **params)
            return result
        except Exception as e:
            print(f"Error transforming image: {e}")
            raise
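

# Usage sketch (not part of the original module): a minimal example of calling
# both methods, assuming config defines HF_TOKEN, DEFAULT_TEXT2IMG_MODEL, and
# DEFAULT_IMG2IMG_MODEL, and that the token has access to those models.
# The prompts and output file names below are illustrative only.
if __name__ == "__main__":
    inference = DiffusionInference()

    # Text-to-image: returns a PIL.Image that can be saved directly.
    generated = inference.text_to_image(
        prompt="a watercolor painting of a lighthouse at dusk",
        negative_prompt="blurry, low quality",
    )
    generated.save("text2img_output.png")

    # Image-to-image: reuse the generated image as the input image.
    transformed = inference.image_to_image(
        image=generated,
        prompt="the same lighthouse in a snowstorm",
    )
    transformed.save("img2img_output.png")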