Sam3838 commited on
Commit
f06b6e7
·
verified ·
1 Parent(s): 12eff42

Update image_generator.py

Browse files
Files changed (1) hide show
  1. image_generator.py +178 -100
image_generator.py CHANGED
@@ -1,100 +1,178 @@
1
- import base64
2
- import io
3
- import os
4
- import time
5
- import logging
6
- from typing import List
7
- from PIL import Image
8
- from huggingface_hub import InferenceClient
9
- from config import config
10
- from models import ImageGenerationRequest, ImageData, ResponseFormat
11
-
12
- logger = logging.getLogger(__name__)
13
-
14
-
15
- class ImageGenerator:
16
- """Text-to-image generator using Hugging Face InferenceClient"""
17
-
18
- def __init__(self):
19
- self.client = None
20
- self._ensure_output_dir()
21
-
22
- def _ensure_output_dir(self):
23
- """Ensure output directory exists"""
24
- os.makedirs(config.OUTPUT_DIR, exist_ok=True)
25
-
26
- def _get_client(self):
27
- """Get or create the InferenceClient"""
28
- if self.client is None:
29
- self.client = InferenceClient(
30
- provider="replicate",
31
- api_key=config.HF_TOKEN,
32
- )
33
- return self.client
34
-
35
- def _image_to_base64(self, image: Image.Image) -> str:
36
- """Convert PIL Image to base64 string"""
37
- buffer = io.BytesIO()
38
- image.save(buffer, format="PNG")
39
- img_str = base64.b64encode(buffer.getvalue()).decode()
40
- return img_str
41
-
42
- def _save_image(self, image: Image.Image, filename: str) -> str:
43
- """Save image and return URL"""
44
- filepath = os.path.join(config.OUTPUT_DIR, filename)
45
- image.save(filepath)
46
- return f"{config.BASE_URL}/images/{filename}"
47
-
48
- async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
49
- """Generate images based on the request"""
50
- client = self._get_client()
51
-
52
- # Generate images
53
- results = []
54
-
55
- for i in range(request.n):
56
- try:
57
- logger.info(f"Generating image {i+1}/{request.n} for prompt: {request.prompt[:50]}...")
58
-
59
- # Generate the image using HuggingFace InferenceClient
60
- image = client.text_to_image(
61
- request.prompt,
62
- model=config.DEFAULT_MODEL,
63
- )
64
-
65
- # Create response based on format
66
- if request.response_format == ResponseFormat.B64_JSON:
67
- image_data = ImageData(
68
- b64_json=self._image_to_base64(image),
69
- revised_prompt=request.prompt
70
- )
71
- else:
72
- # Save image and return URL
73
- timestamp = int(time.time())
74
- filename = f"generated_{timestamp}_{i}.png"
75
- url = self._save_image(image, filename)
76
- image_data = ImageData(
77
- url=url,
78
- revised_prompt=request.prompt
79
- )
80
-
81
- results.append(image_data)
82
- logger.info(f"Successfully generated image {i+1}/{request.n}")
83
-
84
- except Exception as e:
85
- logger.error(f"Failed to generate image {i+1}: {e}")
86
- # Continue with other images
87
- continue
88
-
89
- if not results:
90
- raise Exception("Failed to generate any images")
91
-
92
- return results
93
-
94
- def cleanup(self):
95
- """Cleanup resources"""
96
- self.client = None
97
-
98
-
99
- # Global instance
100
- image_generator = ImageGenerator()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+ import sys
5
+ import time
6
+ import logging
7
+ import tempfile
8
+ import subprocess
9
+ from typing import List
10
+ from enum import Enum
11
+
12
+ # Install required packages
13
+ def install_packages():
14
+ """Install required packages using pip"""
15
+ packages = [
16
+ "pillow",
17
+ "huggingface_hub",
18
+ "pydantic"
19
+ ]
20
+
21
+ for package in packages:
22
+ try:
23
+ __import__(package.replace("-", "_"))
24
+ print(f"{package} already installed")
25
+ except ImportError:
26
+ print(f"Installing {package}...")
27
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
28
+
29
+ # Install packages before importing
30
+ install_packages()
31
+
32
+ from PIL import Image
33
+ from huggingface_hub import InferenceClient
34
+ from pydantic import BaseModel
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Define models directly in the file
39
+ class ResponseFormat(str, Enum):
40
+ URL = "url"
41
+ B64_JSON = "b64_json"
42
+
43
+ class ImageGenerationRequest(BaseModel):
44
+ prompt: str
45
+ model: str = "black-forest-labs/flux-schnell"
46
+ n: int = 1
47
+ size: str = "1024x1024"
48
+ quality: str = "standard"
49
+ response_format: ResponseFormat = ResponseFormat.URL
50
+
51
+ class ImageData(BaseModel):
52
+ url: str = None
53
+ b64_json: str = None
54
+ revised_prompt: str = None
55
+
56
+ class ImageGenerator:
57
+ """Text-to-image generator using Hugging Face InferenceClient"""
58
+
59
+ def __init__(self, hf_token: str = None):
60
+ self.client = None
61
+ self.hf_token = hf_token or os.getenv("HF_TOKEN")
62
+ self.output_dir = tempfile.mkdtemp(prefix="image_gen_")
63
+ self.base_url = "http://localhost:8000" # Default base URL
64
+ self.default_model = "black-forest-labs/flux-schnell"
65
+ self._ensure_output_dir()
66
+
67
+ def _ensure_output_dir(self):
68
+ """Ensure output directory exists"""
69
+ os.makedirs(self.output_dir, exist_ok=True)
70
+ print(f"Using temporary directory: {self.output_dir}")
71
+
72
+ def _get_client(self):
73
+ """Get or create the InferenceClient"""
74
+ if self.client is None:
75
+ if not self.hf_token:
76
+ raise ValueError("HuggingFace token is required. Set HF_TOKEN environment variable or pass it to constructor.")
77
+
78
+ self.client = InferenceClient(
79
+ token=self.hf_token,
80
+ )
81
+ return self.client
82
+
83
+ def _image_to_base64(self, image: Image.Image) -> str:
84
+ """Convert PIL Image to base64 string"""
85
+ buffer = io.BytesIO()
86
+ image.save(buffer, format="PNG")
87
+ img_str = base64.b64encode(buffer.getvalue()).decode()
88
+ return img_str
89
+
90
+ def _save_image(self, image: Image.Image, filename: str) -> str:
91
+ """Save image and return URL"""
92
+ filepath = os.path.join(self.output_dir, filename)
93
+ image.save(filepath)
94
+ return f"{self.base_url}/images/{filename}"
95
+
96
+ def set_config(self, hf_token: str = None, base_url: str = None, default_model: str = None):
97
+ """Set configuration parameters"""
98
+ if hf_token:
99
+ self.hf_token = hf_token
100
+ self.client = None # Reset client to use new token
101
+ if base_url:
102
+ self.base_url = base_url
103
+ if default_model:
104
+ self.default_model = default_model
105
+
106
+ async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
107
+ """Generate images based on the request"""
108
+ client = self._get_client()
109
+
110
+ # Generate images
111
+ results = []
112
+
113
+ for i in range(request.n):
114
+ try:
115
+ logger.info(f"Generating image {i+1}/{request.n} for prompt: {request.prompt[:50]}...")
116
+
117
+ # Generate the image using HuggingFace InferenceClient
118
+ image = client.text_to_image(
119
+ request.prompt,
120
+ model=request.model or self.default_model,
121
+ )
122
+
123
+ # Create response based on format
124
+ if request.response_format == ResponseFormat.B64_JSON:
125
+ image_data = ImageData(
126
+ b64_json=self._image_to_base64(image),
127
+ revised_prompt=request.prompt
128
+ )
129
+ else:
130
+ # Save image and return URL
131
+ timestamp = int(time.time())
132
+ filename = f"generated_{timestamp}_{i}.png"
133
+ url = self._save_image(image, filename)
134
+ image_data = ImageData(
135
+ url=url,
136
+ revised_prompt=request.prompt
137
+ )
138
+
139
+ results.append(image_data)
140
+ logger.info(f"Successfully generated image {i+1}/{request.n}")
141
+
142
+ except Exception as e:
143
+ logger.error(f"Failed to generate image {i+1}: {e}")
144
+ # Continue with other images
145
+ continue
146
+
147
+ if not results:
148
+ raise Exception("Failed to generate any images")
149
+
150
+ return results
151
+
152
+ def cleanup(self):
153
+ """Cleanup resources and temporary directory"""
154
+ self.client = None
155
+ # Clean up temporary directory
156
+ import shutil
157
+ if os.path.exists(self.output_dir):
158
+ shutil.rmtree(self.output_dir)
159
+ print(f"Cleaned up temporary directory: {self.output_dir}")
160
+
161
+ # Example usage
162
+ if __name__ == "__main__":
163
+ # Create generator instance
164
+ generator = ImageGenerator()
165
+
166
+ # Set HuggingFace token (replace with your actual token)
167
+ generator.set_config(hf_token="your_hf_token_here")
168
+
169
+ # Example request
170
+ request = ImageGenerationRequest(
171
+ prompt="A beautiful sunset over mountains",
172
+ n=1,
173
+ response_format=ResponseFormat.URL
174
+ )
175
+
176
+ # Note: This would need to be run in an async context
177
+ # results = await generator.generate_images(request)
178
+ print("Image generator setup complete!")