File size: 5,599 Bytes
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ddc847
 
 
 
 
8247a04
7ddc847
 
8247a04
7ddc847
 
 
 
 
8247a04
 
7ddc847
 
8247a04
 
 
7ddc847
 
8247a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ddc847
 
 
8247a04
 
7ddc847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8247a04
879971e
 
 
 
 
8247a04
879971e
 
 
 
 
7ddc847
 
 
 
879971e
 
7ddc847
879971e
7ddc847
879971e
 
 
 
 
8247a04
879971e
 
8247a04
7ddc847
8247a04
 
879971e
7ddc847
 
8247a04
7ddc847
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import io
import os
import tempfile

from huggingface_hub import InferenceClient
from PIL import Image

import config


class DiffusionInference:
    """Thin wrapper around the Hugging Face InferenceClient for diffusion-model
    text-to-image and image-to-image calls."""

    def __init__(self, api_key=None):
        """
        Initialize the inference client with the Hugging Face API token.

        Args:
            api_key (str, optional): Hugging Face API token. Falls back to
                config.HF_TOKEN when not provided.
        """
        self.api_key = api_key or config.HF_TOKEN
        self.client = InferenceClient(
            provider="hf-inference",
            api_key=self.api_key,
        )

    def text_to_image(self, prompt, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to guide image generation
            model_name (str, optional): The model to use for inference;
                falls back to config.DEFAULT_TEXT2IMG_MODEL
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image

        Raises:
            Exception: re-raises whatever the underlying API call raises,
                after printing diagnostic context.
        """
        model = model_name or config.DEFAULT_TEXT2IMG_MODEL

        params = {
            "prompt": prompt,
            "model": model,
        }
        if negative_prompt is not None:
            params["negative_prompt"] = negative_prompt

        # Merge extra options, but never let kwargs clobber the core keys.
        for key, value in kwargs.items():
            if key not in ("prompt", "model", "negative_prompt"):
                params[key] = value

        try:
            return self.client.text_to_image(**params)
        except Exception as e:
            print(f"Error generating image: {e}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

    @staticmethod
    def _materialize_image(image):
        """
        Ensure `image` exists as a file on disk.

        Args:
            image (str | PIL.Image.Image | array-like): a path (used as-is),
                a PIL image, or anything Image.fromarray accepts.

        Returns:
            tuple[str, str | None]: (image_path, temp_path). temp_path is a
            file created here that the caller must delete, or None when the
            input was already a path.

        Raises:
            ValueError: if `image` is neither a path, a PIL image, nor
                convertible via Image.fromarray.
        """
        if isinstance(image, str):
            # Already a file path; nothing to clean up.
            return image, None

        if isinstance(image, Image.Image):
            pil_image = image
        else:
            try:
                pil_image = Image.fromarray(image)
            except Exception as e:
                raise ValueError(f"Unsupported image type: {type(image)}. Error: {e}")

        # mkstemp yields a unique per-call name, so concurrent calls cannot
        # overwrite (and then delete) each other's files — the previous fixed
        # "temp_image.png" name raced between calls.
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        pil_image.save(temp_path, format="PNG")
        return temp_path, temp_path

    def image_to_image(self, image, prompt=None, model_name=None, negative_prompt=None, **kwargs):
        """
        Generate a new image from an input image and optional prompt.

        Args:
            image (PIL.Image or str): Input image or path to image
            prompt (str, optional): Text prompt to guide the transformation
            model_name (str, optional): The model to use for inference;
                falls back to config.DEFAULT_IMG2IMG_MODEL
            negative_prompt (str, optional): What not to include in the image
            **kwargs: Additional parameters to pass to the model

        Returns:
            PIL.Image: The generated image

        Raises:
            ValueError: if `image` cannot be converted to a file on disk.
            Exception: re-raises whatever the underlying API call raises,
                after printing diagnostic context.
        """
        model = model_name or config.DEFAULT_IMG2IMG_MODEL

        temp_file = None
        try:
            image_path, temp_file = self._materialize_image(image)

            # Fresh client per call to avoid any potential state issues in
            # the shared self.client (kept from the original implementation).
            client = InferenceClient(
                provider="hf-inference",
                api_key=self.api_key,
            )

            # Only forward non-None parameters, and filter kwargs that would
            # conflict with the explicit arguments.
            params = {}
            if model is not None:
                params["model"] = model
            if prompt is not None:
                params["prompt"] = prompt
            if negative_prompt is not None:
                params["negative_prompt"] = negative_prompt
            for key, value in kwargs.items():
                if value is not None and key not in ("image", "prompt", "model", "negative_prompt"):
                    params[key] = value

            # Debug the parameters we're sending
            print("DEBUG: Calling image_to_image with:")
            print(f"- Image path: {image_path}")
            print(f"- Parameters: {params}")

            return client.image_to_image(image_path, **params)

        except Exception as e:
            print(f"Error transforming image: {e}")
            print(f"Image type: {type(image)}")
            print(f"Model: {model}")
            print(f"Prompt: {prompt}")
            raise

        finally:
            # Clean up the temporary file if it was created
            if temp_file and os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {temp_file}: {e}")