# TTI / image_generator.py
# Author: Sam3838
# Last commit: f06b6e7 "Update image_generator.py" (verified)
import base64
import io
import logging
import os
import subprocess
import sys
import tempfile
import time
from enum import Enum
from typing import List, Optional
# Install required packages
# pip distribution name -> top-level module it installs, for packages whose
# import name differs from the name passed to pip (e.g. ``pillow`` -> ``PIL``).
PACKAGE_IMPORT_NAMES = {"pillow": "PIL"}


def install_packages(packages=None):
    """Install any required third-party packages that are not importable.

    Each package is probed by importing its top-level module (resolved via
    ``PACKAGE_IMPORT_NAMES``, otherwise the pip name with ``-`` replaced by
    ``_``). If that import fails, ``pip install <package>`` is run with the
    current interpreter.

    Args:
        packages: Optional iterable of pip package names to check. Defaults to
            the packages this module needs: ``pillow``, ``huggingface_hub``
            and ``pydantic``.

    Raises:
        subprocess.CalledProcessError: if a pip install exits non-zero.
    """
    import importlib

    if packages is None:
        packages = ["pillow", "huggingface_hub", "pydantic"]

    installed_any = False
    for package in packages:
        module_name = PACKAGE_IMPORT_NAMES.get(package, package.replace("-", "_"))
        try:
            importlib.import_module(module_name)
            print(f"{package} already installed")
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            installed_any = True

    if installed_any:
        # The import system caches directory listings; without this, modules
        # installed while the interpreter is running may not be found.
        importlib.invalidate_caches()
# Install any missing third-party dependencies before the imports below.
# NOTE: this runs at import time and may invoke pip (a module-level side effect).
install_packages()
from PIL import Image
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# Module-wide logger named after this module; this file does not attach any
# handlers itself, so output depends on the host's logging configuration.
logger = logging.getLogger(name=__name__)
# Define models directly in the file
class ResponseFormat(str, Enum):
    """How generated images are handed back to the caller.

    Mixing in ``str`` makes each member compare equal to its raw value
    (``ResponseFormat.URL == "url"``).
    """
    URL = "url"  # image is saved to the output directory and a URL is returned
    B64_JSON = "b64_json"  # image is returned inline as base64-encoded PNG
class ImageGenerationRequest(BaseModel):
    """Parameters for one ``ImageGenerator.generate_images`` call."""
    prompt: str  # text prompt passed to the model
    # Hugging Face model id used for inference.
    # NOTE(review): the Hub repo for FLUX schnell appears to be
    # "black-forest-labs/FLUX.1-schnell" -- confirm this id resolves.
    model: str = "black-forest-labs/flux-schnell"
    n: int = 1  # number of images; each is a separate inference call (not range-checked here)
    # NOTE(review): size and quality are accepted but not read by
    # ImageGenerator.generate_images in this file.
    size: str = "1024x1024"
    quality: str = "standard"
    response_format: ResponseFormat = ResponseFormat.URL  # URL to saved file, or inline base64
class ImageData(BaseModel):
    """One generated image as returned to API callers.

    ``ImageGenerator.generate_images`` fills exactly one of ``url`` /
    ``b64_json`` depending on the requested response format; the other stays
    ``None``. The fields are ``Optional`` so that ``None`` is also accepted
    when passed explicitly (pydantic v2 no longer treats ``str = None`` as
    nullable).
    """

    url: Optional[str] = None  # URL of the saved PNG (URL response format)
    b64_json: Optional[str] = None  # base64-encoded PNG (b64_json response format)
    revised_prompt: Optional[str] = None  # currently always the unmodified input prompt
class ImageGenerator:
    """Text-to-image generator backed by the Hugging Face ``InferenceClient``.

    Images are either returned inline as base64 PNG data or written to a
    private temporary directory and returned as ``{base_url}/images/{name}``
    URLs. This class only writes the files; it does not serve that directory.
    Call ``cleanup()`` when done to delete the directory.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Create a generator.

        Args:
            hf_token: Hugging Face API token. Falls back to the ``HF_TOKEN``
                environment variable when omitted or empty.
        """
        self.client = None  # created lazily by _get_client()
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        # mkdtemp() creates the directory; it is only removed by cleanup().
        self.output_dir = tempfile.mkdtemp(prefix="image_gen_")
        self.base_url = "http://localhost:8000"  # Default base URL
        # NOTE(review): the Hub repo for FLUX schnell appears to be
        # "black-forest-labs/FLUX.1-schnell" -- confirm this id resolves.
        self.default_model = "black-forest-labs/flux-schnell"
        self._ensure_output_dir()

    def _ensure_output_dir(self):
        """Create the output directory if missing and report its path."""
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"Using temporary directory: {self.output_dir}")

    def _get_client(self):
        """Return the cached ``InferenceClient``, creating it on first use.

        Raises:
            ValueError: if no Hugging Face token is configured.
        """
        if self.client is None:
            if not self.hf_token:
                raise ValueError("HuggingFace token is required. Set HF_TOKEN environment variable or pass it to constructor.")
            self.client = InferenceClient(
                token=self.hf_token,
            )
        return self.client

    def _image_to_base64(self, image: Image.Image) -> str:
        """PNG-encode ``image`` in memory and return it as a base64 string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode()

    def _save_image(self, image: Image.Image, filename: str) -> str:
        """Write ``image`` into the output directory and return its URL."""
        filepath = os.path.join(self.output_dir, filename)
        image.save(filepath)  # format inferred from the filename extension
        return f"{self.base_url}/images/{filename}"

    def set_config(
        self,
        hf_token: Optional[str] = None,
        base_url: Optional[str] = None,
        default_model: Optional[str] = None,
    ):
        """Update configuration; empty/``None`` arguments are ignored.

        Changing the token discards the cached client so that the next
        generation builds a new one with the new token.
        """
        if hf_token:
            self.hf_token = hf_token
            self.client = None  # Reset client to use new token
        if base_url:
            self.base_url = base_url
        if default_model:
            self.default_model = default_model

    async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
        """Generate ``request.n`` images for ``request.prompt``.

        Each image is generated independently. A failure on one image is
        logged with its traceback, and the remaining images are still attempted.

        Args:
            request: Prompt, model, image count and response format.

        Returns:
            One ``ImageData`` per successfully generated image, carrying
            ``b64_json`` or ``url`` according to ``request.response_format``.
            ``revised_prompt`` is the unmodified input prompt.

        Raises:
            ValueError: if ``request.n`` is less than 1 or no token is set.
            RuntimeError: if every image failed; the last underlying error is
                chained as ``__cause__``.
        """
        import asyncio
        import functools
        import uuid

        if request.n < 1:
            raise ValueError(f"n must be at least 1, got {request.n}")

        client = self._get_client()
        # NOTE(review): request.model has a non-empty default, so
        # self.default_model (and set_config(default_model=...)) only applies
        # when a caller explicitly passes an empty model.
        model = request.model or self.default_model
        # NOTE(review): request.size and request.quality are not forwarded to
        # the inference call.
        loop = asyncio.get_running_loop()

        results = []
        last_error = None
        for i in range(request.n):
            logger.info(
                "Generating image %d/%d for prompt: %s...",
                i + 1, request.n, request.prompt[:50],
            )
            try:
                # text_to_image() is a blocking HTTP call; run it in the default
                # executor so the event loop keeps serving other tasks.
                image = await loop.run_in_executor(
                    None,
                    functools.partial(client.text_to_image, request.prompt, model=model),
                )
                if request.response_format == ResponseFormat.B64_JSON:
                    image_data = ImageData(
                        b64_json=self._image_to_base64(image),
                        revised_prompt=request.prompt,
                    )
                else:
                    # The random suffix keeps filenames unique when several
                    # requests start within the same second.
                    filename = (
                        f"generated_{int(time.time())}_{i}_"
                        f"{uuid.uuid4().hex[:8]}.png"
                    )
                    image_data = ImageData(
                        url=self._save_image(image, filename),
                        revised_prompt=request.prompt,
                    )
            except Exception as exc:  # best-effort: continue with the other images
                logger.exception("Failed to generate image %d: %s", i + 1, exc)
                last_error = exc
                continue
            results.append(image_data)
            logger.info("Successfully generated image %d/%d", i + 1, request.n)

        if not results:
            raise RuntimeError("Failed to generate any images") from last_error
        return results

    def cleanup(self):
        """Drop the client and delete the temporary output directory.

        URLs returned earlier point into this directory and will no longer
        resolve to files afterwards.
        """
        import shutil

        self.client = None
        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)
            print(f"Cleaned up temporary directory: {self.output_dir}")
# Example usage
if __name__ == "__main__":
    # Create generator instance (picks up HF_TOKEN from the environment if set).
    generator = ImageGenerator()
    try:
        if not generator.hf_token:
            # Placeholder only -- replace with a real Hugging Face token or set
            # the HF_TOKEN environment variable. Skipped when a token is already
            # configured so a real token is never overwritten.
            generator.set_config(hf_token="your_hf_token_here")

        # Example request
        request = ImageGenerationRequest(
            prompt="A beautiful sunset over mountains",
            n=1,
            response_format=ResponseFormat.URL,
        )
        # generate_images is a coroutine; from synchronous code run it with:
        #     results = asyncio.run(generator.generate_images(request))
        print("Image generator setup complete!")
    finally:
        # mkdtemp() directories are not removed automatically.
        generator.cleanup()