# Spaces: Sleeping  (hosting-page status banner captured along with the file; not part of the module)
# storyverse_weaver/core/image_services.py
import base64
import os
from io import BytesIO

import requests
from huggingface_hub import InferenceClient  # Ensure this is imported
from PIL import Image
# --- API Key Configuration ---
# Credentials are read once, at import time, from the process environment.
STABILITY_API_KEY = os.getenv("STORYVERSE_STABILITY_API_KEY")
OPENAI_API_KEY = os.getenv("STORYVERSE_OPENAI_API_KEY")
HF_TOKEN = os.getenv("STORYVERSE_HF_TOKEN")  # Reuse from llm_services or get it here

# A provider counts as configured only when its key is set and is not just
# whitespace. initialize_image_llms() recomputes these flags when it runs.
STABILITY_API_CONFIGURED = bool(STABILITY_API_KEY and STABILITY_API_KEY.strip())
OPENAI_DALLE_CONFIGURED = bool(OPENAI_API_KEY and OPENAI_API_KEY.strip())
HF_IMAGE_API_CONFIGURED = bool(HF_TOKEN and HF_TOKEN.strip())

# Hugging Face client used for image tasks; stays None until
# initialize_image_llms() creates it.
hf_inference_image_client = None
class ImageGenResponse:
    """Result of a single image-generation request.

    Every constructor argument is stored unchanged on the instance under the
    same name:

        image:          the generated image object, or None.
        image_url:      URL of the generated image, or None.
        error:          human-readable error message, or None.
        success:        whether the request succeeded (defaults to True).
        provider:       label of the backend that handled the request.
        model_id_used:  identifier of the model that was asked for, or None.
        raw_response:   optional underlying response/exception object kept
                        for debugging; the generator functions in this module
                        pass the caught exception here on failure.
    """

    # The image annotation is quoted so that it is not evaluated when the
    # class body is executed.
    def __init__(self, image: "Image.Image" = None, image_url: str = None,
                 error: str = None, success: bool = True,
                 provider: str = "unknown", model_id_used: str = None,
                 raw_response=None):
        self.image = image
        self.image_url = image_url
        self.error = error
        self.success = success
        self.provider = provider
        self.model_id_used = model_id_used
        # Accepting this keyword keeps the error paths of the generator
        # functions from failing with TypeError while building the response.
        self.raw_response = raw_response
def initialize_image_llms():
    """Recompute the provider availability flags and create the HF image client.

    Re-checks the three module-level credentials, updates
    STABILITY_API_CONFIGURED, OPENAI_DALLE_CONFIGURED and
    HF_IMAGE_API_CONFIGURED accordingly, and, when a Hugging Face token is
    present, tries to build hf_inference_image_client. Progress is reported
    with print(); nothing is returned.
    """
    global STABILITY_API_CONFIGURED, OPENAI_DALLE_CONFIGURED, HF_IMAGE_API_CONFIGURED, hf_inference_image_client

    print("INFO: image_services.py - Initializing Image Generation services...")

    # Stability AI: only the presence of a non-blank key is checked.
    STABILITY_API_CONFIGURED = bool(STABILITY_API_KEY and STABILITY_API_KEY.strip())
    print(
        "SUCCESS: image_services.py - Stability AI API Key detected."
        if STABILITY_API_CONFIGURED
        else "WARNING: image_services.py - STORYVERSE_STABILITY_API_KEY not found. Stability AI disabled."
    )

    # OpenAI DALL-E: same presence check.
    OPENAI_DALLE_CONFIGURED = bool(OPENAI_API_KEY and OPENAI_API_KEY.strip())
    print(
        "SUCCESS: image_services.py - OpenAI API Key detected (for DALL-E)."
        if OPENAI_DALLE_CONFIGURED
        else "WARNING: image_services.py - STORYVERSE_OPENAI_API_KEY not found. DALL-E disabled."
    )

    # Hugging Face image models: a token alone is not enough, the client
    # must also be constructed successfully.
    hf_token_present = bool(HF_TOKEN and HF_TOKEN.strip())
    if not hf_token_present:
        HF_IMAGE_API_CONFIGURED = False
        print("WARNING: image_services.py - STORYVERSE_HF_TOKEN not found. HF Image models disabled.")
    else:
        try:
            # The same token may serve both text and image clients if its
            # permissions cover both.
            hf_inference_image_client = InferenceClient(token=HF_TOKEN)
            HF_IMAGE_API_CONFIGURED = True
            print("SUCCESS: image_services.py - Hugging Face InferenceClient (for images) ready.")
        except Exception as init_error:
            HF_IMAGE_API_CONFIGURED = False
            print(f"ERROR: image_services.py - Failed to initialize HF InferenceClient for images: {init_error}")

    print("INFO: image_services.py - Image Service Init complete.")
# --- Stability AI (automatically disabled when no API key is configured) ---
def generate_image_stabilityai(prompt: str,
                               engine_id: str = "stable-diffusion-xl-1024-v1-0",
                               steps: int = 30) -> ImageGenResponse:
    """Generate one image through the Stability AI v1 text-to-image endpoint.

    NOTE(review): the original parameter list was elided in this copy of the
    file; ``engine_id`` and ``steps`` are the two names the body referred to,
    with the fallback values it used as their defaults. Confirm against
    callers that no other keyword is expected.

    Args:
        prompt: text prompt sent as the single entry of ``text_prompts``.
        engine_id: engine name inserted into the request URL.
        steps: value sent as ``steps`` in the request payload.

    Returns:
        An ImageGenResponse carrying the decoded image on success. On any
        failure (missing key, HTTP error, malformed response, undecodable
        image) the response has ``success=False`` and an ``error`` message;
        for caught exceptions the exception object is also stored on the
        response as ``raw_response``. This function does not raise for
        request or decoding errors.
    """
    if not STABILITY_API_CONFIGURED:
        return ImageGenResponse(error="Stability AI API key not configured.", success=False, provider="StabilityAI")

    api_host = os.getenv('API_HOST', 'https://api.stability.ai')
    request_url = f"{api_host}/v1/generation/{engine_id}/text-to-image"
    payload = {"text_prompts": [{"text": prompt}], "steps": steps}
    headers = {
        "Authorization": f"Bearer {STABILITY_API_KEY}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(request_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        artifacts = response.json().get("artifacts")
        if not artifacts:
            # Without this check a missing/empty list surfaces as an opaque
            # "'NoneType' object is not subscriptable" message.
            raise ValueError("response contained no image artifacts")
        img_data = base64.b64decode(artifacts[0]["base64"])
        img = Image.open(BytesIO(img_data))
        return ImageGenResponse(image=img, provider="StabilityAI")
    except Exception as e:
        failure = ImageGenResponse(error=f"Stability AI Error: {str(e)}", success=False, provider="StabilityAI")
        # Attached as an attribute rather than passed as a keyword, so building
        # the failure response cannot itself raise TypeError.
        failure.raw_response = e
        return failure
# --- DALL-E (conceptual: no real API call is made) ---
def generate_image_dalle(prompt: str, **_ignored_options) -> ImageGenResponse:
    """Return a placeholder image standing in for a DALL-E generation.

    No request is sent to OpenAI; when the DALL-E key is configured a solid
    512x512 sky-blue image is returned and the provider is labelled
    "DALL-E (Simulated)".

    NOTE(review): the original parameter list was elided in this copy of the
    file. Extra keyword options are accepted and ignored so existing callers
    keep working; replace ``**_ignored_options`` with the real parameters
    once they are known.

    Args:
        prompt: text prompt; currently unused by the simulation.

    Returns:
        An ImageGenResponse with the placeholder image, or one with
        ``success=False`` when the OpenAI key is not configured.
    """
    if not OPENAI_DALLE_CONFIGURED:
        return ImageGenResponse(error="OpenAI DALL-E API key not configured.", success=False, provider="DALL-E")

    dummy_image = Image.new('RGB', (512, 512), color='skyblue')
    return ImageGenResponse(image=dummy_image, provider="DALL-E (Simulated)")
# --- NEW: Hugging Face Image Model via Inference API ---
def generate_image_hf_model(prompt: str,
                            model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",  # A popular choice
                            # Other options: "runwayml/stable-diffusion-v1-5",
                            # "prompthero/openjourney" (Midjourney-like style)
                            negative_prompt: str = None,
                            height: int = 768,  # Adjust for different models
                            width: int = 768,
                            num_inference_steps: int = 25,
                            guidance_scale: float = 7.5
                            ) -> ImageGenResponse:
    """Generate an image with a Hugging Face hosted model via InferenceClient.

    Args:
        prompt: text prompt passed to ``text_to_image``.
        model_id: model repository id passed as ``model``.
        negative_prompt: optional negative prompt; omitted from the call when None.
        height: requested image height.
        width: requested image width.
        num_inference_steps: forwarded to ``text_to_image``.
        guidance_scale: forwarded to ``text_to_image``.

    Returns:
        An ImageGenResponse with the generated image on success. If the HF
        client is not configured, or the call raises, the response has
        ``success=False`` and an ``error`` message; for caught exceptions the
        exception object is also stored on the response as ``raw_response``.
        Exceptions from the API call are not propagated.
    """
    if not HF_IMAGE_API_CONFIGURED or not hf_inference_image_client:
        return ImageGenResponse(error="Hugging Face API (for images) not configured.", success=False, provider="HF Image API", model_id_used=model_id)

    params = {
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    # Remove None params as some models might not like them
    params = {k: v for k, v in params.items() if v is not None}

    print(f"DEBUG: image_services.py - Calling HF Image API ({model_id}) with prompt: {prompt[:50]}...")
    try:
        # The `text_to_image` method of InferenceClient returns a PIL Image directly
        image_result: Image.Image = hf_inference_image_client.text_to_image(
            prompt,
            model=model_id,
            **params
        )
        print(f"DEBUG: image_services.py - HF Image API ({model_id}) image generated successfully.")
        return ImageGenResponse(image=image_result, provider="HF Image API", model_id_used=model_id)
    except Exception as e:
        error_text = str(e)
        error_msg = f"HF Image API Error ({model_id}): {type(e).__name__} - {error_text}"
        # Append a hint for the common HF API failure modes.
        if "Rate limit reached" in error_text:
            error_msg += " You may have hit free tier limits."
        elif "Model is currently loading" in error_text:
            error_msg += " The model might be loading, please try again in a moment."
        elif "Authorization header is correct" in error_text or "401" in error_text:
            error_msg += " Issue with your HF_TOKEN authentication."
        print(f"ERROR: image_services.py - {error_msg}")

        failure = ImageGenResponse(error=error_msg, success=False, provider="HF Image API", model_id_used=model_id)
        # Attached as an attribute rather than passed as a keyword, so building
        # the failure response cannot itself raise TypeError.
        failure.raw_response = e
        return failure
# Import-time trace confirming that this module has been fully defined.
print("DEBUG: core.image_services (for StoryVerseWeaver) - Module defined with HF Image support.")