# imagencpu/inference.py
import spaces
from huggingface_hub import InferenceClient
from PIL import Image
import config
import random
from diffusers import AutoPipelineForText2Image
import torch
class DiffusionInference:
def __init__(self, api_key=None):
"""
Initialize the inference client with the Hugging Face API token.
"""
self.api_key = api_key or config.HF_TOKEN
self.client = InferenceClient(
provider="hf-inference",
api_key=self.api_key,
)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
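    # Usage sketch (token value is hypothetical; config.HF_TOKEN is read when no key is passed):
    #     inference = DiffusionInference()                   # token from config.HF_TOKEN
    #     inference = DiffusionInference(api_key="hf_xxx")   # explicit token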
def text_to_image(self, prompt, model_name=None, negative_prompt=None, seed=None, **kwargs):
"""
Generate an image from a text prompt.
Args:
prompt (str): The text prompt to guide image generation
model_name (str, optional): The model to use for inference
            negative_prompt (str, optional): What not to include in the image
            seed (int, optional): Random seed for reproducible generation
            **kwargs: Additional parameters to pass to the model
Returns:
PIL.Image: The generated image
"""
model = model_name or config.DEFAULT_TEXT2IMG_MODEL
# Create parameters dictionary for all keyword arguments
params = {
"prompt": prompt,
}
        # The seed is handled separately in run_text_to_image_pipeline
        # Add negative prompt if provided
if negative_prompt is not None:
params["negative_prompt"] = negative_prompt
# Add any other parameters
for k, v in kwargs.items():
if k not in ["prompt", "model", "negative_prompt"]:
params[k] = v
try:
# Call the API with all parameters as kwargs
image = self.run_text_to_image_pipeline(model, seed, **params)
return image
except Exception as e:
print(f"Error generating image: {e}")
print(f"Model: {model}")
print(f"Prompt: {prompt}")
raise
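    # Example call (a minimal sketch; the prompt and extra parameters are illustrative,
    # and the model falls back to config.DEFAULT_TEXT2IMG_MODEL when not given):
    #     img = inference.text_to_image(
    #         "a watercolor fox in a forest",
    #         negative_prompt="blurry, low quality",
    #         seed=42,
    #         num_inference_steps=25,
    #     )
    #     img.save("fox.png")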
def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
"""
Generate a new image from an input image and optional prompt.
Args:
image (PIL.Image or str): Input image or path to image
prompt (str, optional): Text prompt to guide the transformation
model_name (str, optional): The model to use for inference
negative_prompt (str, optional): What not to include in the image
**kwargs: Additional parameters to pass to the model
Returns:
PIL.Image: The generated image
"""
import tempfile
import os
model = model_name or config.DEFAULT_IMG2IMG_MODEL
# Create a temporary file for the image if it's a PIL Image
temp_file = None
try:
# Handle different image input types
if isinstance(image, str):
# If it's already a file path, use it directly
image_path = image
elif isinstance(image, Image.Image):
# If it's a PIL Image, save it to a temporary file
temp_dir = tempfile.gettempdir()
temp_file = os.path.join(temp_dir, "temp_image.png")
image.save(temp_file, format="PNG")
image_path = temp_file
else:
# If it's something else, try to convert it to a PIL Image first
try:
pil_image = Image.fromarray(image)
temp_dir = tempfile.gettempdir()
temp_file = os.path.join(temp_dir, "temp_image.png")
pil_image.save(temp_file, format="PNG")
image_path = temp_file
except Exception as e:
raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")
# Create a NEW InferenceClient for this call to avoid any potential state issues
client = InferenceClient(
provider="hf-inference",
api_key=self.api_key,
)
# Create the parameter dict with only the non-None values
params = {}
# Only add parameters that are not None
if model is not None:
params["model"] = model
if prompt is not None:
params["prompt"] = prompt
if negative_prompt is not None:
params["negative_prompt"] = negative_prompt
# Add additional kwargs, but filter out any that might create conflicts
for k, v in kwargs.items():
if v is not None and k not in ["image", "prompt", "model", "negative_prompt"]:
params[k] = v
# Debug the parameters we're sending
            print("DEBUG: Calling image_to_image with:")
print(f"- Image path: {image_path}")
print(f"- Parameters: {params}")
# Make the API call
result = client.image_to_image(image_path, **params)
return result
except Exception as e:
print(f"Error transforming image: {e}")
print(f"Image type: {type(image)}")
print(f"Model: {model}")
print(f"Prompt: {prompt}")
raise
finally:
# Clean up the temporary file if it was created
if temp_file and os.path.exists(temp_file):
try:
os.remove(temp_file)
except Exception as e:
print(f"Warning: Could not delete temporary file {temp_file}: {e}")
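    # Example call (a sketch; image_to_image accepts a PIL.Image, a numpy array, or a file
    # path, and falls back to config.DEFAULT_IMG2IMG_MODEL when model_name is omitted):
    #     src = Image.open("photo.png")
    #     out = inference.image_to_image(src, prompt="turn this into a pencil sketch")
    #     out.save("sketch.png")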
@spaces.GPU
def run_text_to_image_pipeline(self, model_name, seed, **kwargs):
        if seed is not None:
            try:
                # Convert to integer and build a generator from the given seed
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            except (ValueError, TypeError):
                # Fall back to a random seed if conversion fails
                random_seed = random.randint(0, 3999999999)  # Fits in an unsigned 32-bit range
                generator = torch.Generator(device=self.device).manual_seed(random_seed)
                print(f"Warning: Invalid seed value: {seed}, using random seed {random_seed} instead")
        else:
            # Generate a random seed when none is provided
            random_seed = random.randint(0, 3999999999)  # Fits in an unsigned 32-bit range
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
            print(f"Using random seed: {random_seed}")
        # Use float16 only on GPU; many CPU ops lack half-precision support
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        pipeline = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=dtype).to(self.device)
        # The generator is a call-time argument, not a from_pretrained argument
        image = pipeline(generator=generator, **kwargs).images[0]
        return image
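

# Minimal smoke test, a sketch only: it assumes config.HF_TOKEN and the default model
# names in config are valid, and it will download model weights on first run.
if __name__ == "__main__":
    inference = DiffusionInference()
    # Local diffusers pipeline path (runs on GPU if available, otherwise CPU)
    image = inference.text_to_image("a lighthouse at sunset", seed=123)
    image.save("text2img_demo.png")
    # Hosted inference path for image-to-image via the InferenceClient
    transformed = inference.image_to_image(image, prompt="the same lighthouse in winter")
    transformed.save("img2img_demo.png")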