File size: 7,475 Bytes
f7f6fe3
 
ae9e00c
f7f6fe3
 
 
ae9e00c
f7f6fe3
ae9e00c
f7f6fe3
ae9e00c
 
f7f6fe3
 
 
ae9e00c
f7f6fe3
ae9e00c
f7f6fe3
ae9e00c
 
 
 
 
 
 
 
f7f6fe3
ae9e00c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f6fe3
 
ae9e00c
 
f7f6fe3
ae9e00c
 
f7f6fe3
ae9e00c
f7f6fe3
ae9e00c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f6fe3
ae9e00c
 
f7f6fe3
ae9e00c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# storyverse_weaver/core/image_services.py
import os
import requests
import base64
from io import BytesIO
from PIL import Image
from huggingface_hub import InferenceClient # Ensure this is imported

# --- API Key Configuration ---
# Keys are read once, at import time, from STORYVERSE_-prefixed environment variables.
STABILITY_API_KEY = os.getenv("STORYVERSE_STABILITY_API_KEY")
OPENAI_API_KEY = os.getenv("STORYVERSE_OPENAI_API_KEY")
HF_TOKEN = os.getenv("STORYVERSE_HF_TOKEN")  # Hugging Face access token

# A provider counts as configured only when its key is set and not blank.
# initialize_image_llms() recomputes these flags.
STABILITY_API_CONFIGURED = bool((STABILITY_API_KEY or "").strip())
OPENAI_DALLE_CONFIGURED = bool((OPENAI_API_KEY or "").strip())
HF_IMAGE_API_CONFIGURED = bool((HF_TOKEN or "").strip())

# Shared Hugging Face InferenceClient for image calls; built by initialize_image_llms().
hf_inference_image_client = None

class ImageGenResponse:
    """Result of a single image-generation call, successful or not.

    Attributes:
        image: Generated PIL image, or None.
        image_url: URL of a generated image, if the caller supplied one.
        error: Human-readable error message; set when ``success`` is False.
        success: Whether generation succeeded (defaults to True).
        provider: Label of the backend that produced this result.
        model_id_used: Model/engine identifier used for the call, if known.
        raw_response: Optional debug payload (e.g. the caught exception or the
            provider's raw response); None when not supplied.
    """

    def __init__(self, image: "Image.Image" = None, image_url: str = None, error: str = None,
                 success: bool = True, provider: str = "unknown", model_id_used: str = None,
                 raw_response=None):
        self.image = image
        self.image_url = image_url
        self.error = error
        self.success = success
        self.provider = provider
        self.model_id_used = model_id_used
        # Accepted so error paths can attach the underlying exception without
        # tripping a TypeError for an unexpected keyword argument.
        self.raw_response = raw_response

    def __repr__(self):
        return (f"{type(self).__name__}(success={self.success!r}, provider={self.provider!r}, "
                f"model_id_used={self.model_id_used!r}, error={self.error!r})")

def initialize_image_llms():
    """Refresh provider availability flags and build the HF image client.

    Re-derives STABILITY_API_CONFIGURED, OPENAI_DALLE_CONFIGURED and
    HF_IMAGE_API_CONFIGURED from the keys read at import time, and creates
    ``hf_inference_image_client`` when a Hugging Face token is present.
    Each provider's status is printed; nothing is returned, and a failure to
    construct the HF client is reported instead of raised.
    """
    global STABILITY_API_CONFIGURED, OPENAI_DALLE_CONFIGURED, HF_IMAGE_API_CONFIGURED, hf_inference_image_client

    def _usable(key):
        # A key is usable when it is set and not just whitespace.
        return bool(key and key.strip())

    print("INFO: image_services.py - Initializing Image Generation services...")

    # Stability AI
    STABILITY_API_CONFIGURED = _usable(STABILITY_API_KEY)
    print("SUCCESS: image_services.py - Stability AI API Key detected." if STABILITY_API_CONFIGURED
          else "WARNING: image_services.py - STORYVERSE_STABILITY_API_KEY not found. Stability AI disabled.")

    # OpenAI DALL-E
    OPENAI_DALLE_CONFIGURED = _usable(OPENAI_API_KEY)
    print("SUCCESS: image_services.py - OpenAI API Key detected (for DALL-E)." if OPENAI_DALLE_CONFIGURED
          else "WARNING: image_services.py - STORYVERSE_OPENAI_API_KEY not found. DALL-E disabled.")

    # Hugging Face image models
    if not _usable(HF_TOKEN):
        HF_IMAGE_API_CONFIGURED = False
        print("WARNING: image_services.py - STORYVERSE_HF_TOKEN not found. HF Image models disabled.")
    else:
        try:
            # One client, authenticated with HF_TOKEN, serves every HF image call.
            hf_inference_image_client = InferenceClient(token=HF_TOKEN)
            HF_IMAGE_API_CONFIGURED = True
            print("SUCCESS: image_services.py - Hugging Face InferenceClient (for images) ready.")
        except Exception as err:
            HF_IMAGE_API_CONFIGURED = False
            print(f"ERROR: image_services.py - Failed to initialize HF InferenceClient for images: {err}")

    print("INFO: image_services.py - Image Service Init complete.")


# --- Stability AI (disabled automatically when no key is configured) ---
def generate_image_stabilityai(prompt: str,
                               engine_id: str = "stable-diffusion-xl-1024-v1-0",
                               steps: int = 30,
                               negative_prompt: str = None,
                               timeout: float = 60) -> ImageGenResponse:
    """Generate an image with the Stability AI v1 text-to-image REST endpoint.

    Args:
        prompt: Text description of the image.
        engine_id: Stability engine to call (path segment of the endpoint).
        steps: Number of diffusion steps requested.
        negative_prompt: Optional text to steer away from; sent as a
            negatively weighted text prompt. Omitted when empty/None.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        ImageGenResponse with ``image`` set on success. On any failure
        ``success`` is False, ``error`` starts with "Stability AI Error:",
        and the caught exception is attached as ``raw_response``.
    """
    if not STABILITY_API_CONFIGURED:
        return ImageGenResponse(error="Stability AI API key not configured.", success=False,
                                provider="StabilityAI", model_id_used=engine_id)

    api_host = os.getenv('API_HOST', 'https://api.stability.ai')
    request_url = f"{api_host}/v1/generation/{engine_id}/text-to-image"

    text_prompts = [{"text": prompt}]
    if negative_prompt:
        # The v1 API treats prompts with a negative weight as things to avoid.
        text_prompts.append({"text": negative_prompt, "weight": -1})
    payload = {"text_prompts": text_prompts, "steps": steps}
    headers = {
        "Authorization": f"Bearer {STABILITY_API_KEY}",
        "Accept": "application/json",
        "Content-Type": "application/json",
    }

    try:
        response = requests.post(request_url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        artifacts = response.json().get("artifacts") or []
        if not artifacts:
            raise ValueError("response contained no image artifacts")
        img_data = base64.b64decode(artifacts[0]["base64"])
        img = Image.open(BytesIO(img_data))
        img.load()  # Image.open is lazy; decode now so corrupt data is reported here.
        return ImageGenResponse(image=img, provider="StabilityAI", model_id_used=engine_id)
    except Exception as e:
        result = ImageGenResponse(error=f"Stability AI Error: {str(e)}", success=False,
                                  provider="StabilityAI", model_id_used=engine_id)
        # Set as an attribute, not a constructor keyword, so this handler cannot itself raise TypeError.
        result.raw_response = e
        return result


# --- DALL-E (simulated placeholder; no OpenAI request is made yet) ---
def generate_image_dalle(prompt: str, **kwargs) -> ImageGenResponse:
    """Return a placeholder "DALL-E" image when an OpenAI key is configured.

    This is a simulation: ``prompt`` and any extra keyword options (accepted
    so callers can pass provider settings) are currently ignored.

    Returns:
        An error ImageGenResponse (provider "DALL-E") when no OpenAI key is
        configured; otherwise a 512x512 solid 'skyblue' RGB image labelled
        "DALL-E (Simulated)".
    """
    if not OPENAI_DALLE_CONFIGURED:
        return ImageGenResponse(error="OpenAI DALL-E API key not configured.", success=False, provider="DALL-E")
    # TODO: replace with a real OpenAI Images API call.
    placeholder = Image.new('RGB', (512, 512), color='skyblue')
    return ImageGenResponse(image=placeholder, provider="DALL-E (Simulated)")


# --- Hugging Face Image Model via Inference API ---
def generate_image_hf_model(prompt: str,
                            model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",
                            negative_prompt: str = None,
                            height: int = 768,
                            width: int = 768,
                            num_inference_steps: int = 25,
                            guidance_scale: float = 7.5
                            ) -> ImageGenResponse:
    """Generate an image with a Hugging Face text-to-image model.

    Uses the shared ``hf_inference_image_client`` created by
    initialize_image_llms().

    Args:
        prompt: Text description of the image.
        model_id: Hub model id. Alternatives used with this service include
            "runwayml/stable-diffusion-v1-5" and "prompthero/openjourney".
        negative_prompt: Text to steer away from; omitted when None.
        height, width: Requested output size in pixels (supported sizes vary
            by model).
        num_inference_steps: Number of denoising steps.
        guidance_scale: How strongly the output should follow the prompt.

    Returns:
        ImageGenResponse with ``image`` set on success. On failure ``success``
        is False, ``error`` describes the problem (with a hint for common HF
        errors), and the caught exception is attached as ``raw_response``.
    """
    if not HF_IMAGE_API_CONFIGURED or hf_inference_image_client is None:
        return ImageGenResponse(error="Hugging Face API (for images) not configured.", success=False, provider="HF Image API", model_id_used=model_id)

    params = {
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    # Drop unset options; some models reject explicit None values.
    params = {k: v for k, v in params.items() if v is not None}

    print(f"DEBUG: image_services.py - Calling HF Image API ({model_id}) with prompt: {prompt[:50]}...")
    try:
        # InferenceClient.text_to_image returns a PIL Image directly.
        image_result = hf_inference_image_client.text_to_image(prompt, model=model_id, **params)
        print(f"DEBUG: image_services.py - HF Image API ({model_id}) image generated successfully.")
        return ImageGenResponse(image=image_result, provider="HF Image API", model_id_used=model_id)
    except Exception as e:
        error_text = str(e)
        error_msg = f"HF Image API Error ({model_id}): {type(e).__name__} - {error_text}"
        # Append a hint for the most common Inference API failure modes.
        if "Rate limit reached" in error_text:
            error_msg += " You may have hit free tier limits."
        elif "Model is currently loading" in error_text:
            error_msg += " The model might be loading, please try again in a moment."
        elif "Authorization header is correct" in error_text or "401" in error_text:
            error_msg += " Issue with your HF_TOKEN authentication."

        print(f"ERROR: image_services.py - {error_msg}")
        response = ImageGenResponse(error=error_msg, success=False, provider="HF Image API", model_id_used=model_id)
        # Set as an attribute, not a constructor keyword, so this handler cannot itself raise TypeError.
        response.raw_response = e
        return response

# Import-time marker so startup logs show this module (with HF image support) loaded.
print(
    "DEBUG: core.image_services (for StoryVerseWeaver) - "
    "Module defined with HF Image support."
)