mgbam committed on
Commit
2e5de57
·
verified ·
1 Parent(s): faf8e43

Update core/image_services.py

Browse files
Files changed (1) hide show
  1. core/image_services.py +45 -79
core/image_services.py CHANGED
@@ -1,95 +1,63 @@
1
  # storyverse_weaver/core/image_services.py
2
  import os
3
- import requests
4
- import base64
5
  from io import BytesIO
6
  from PIL import Image
7
- from huggingface_hub import InferenceClient # Ensure this is imported
8
 
9
  # --- API Key Configuration ---
10
- STABILITY_API_KEY = os.getenv("STORYVERSE_STABILITY_API_KEY")
11
- OPENAI_API_KEY = os.getenv("STORYVERSE_OPENAI_API_KEY")
12
- HF_TOKEN = os.getenv("STORYVERSE_HF_TOKEN") # Reuse from llm_services or get it here
13
 
14
- STABILITY_API_CONFIGURED = bool(STABILITY_API_KEY and STABILITY_API_KEY.strip())
15
- OPENAI_DALLE_CONFIGURED = bool(OPENAI_API_KEY and OPENAI_API_KEY.strip())
16
- HF_IMAGE_API_CONFIGURED = bool(HF_TOKEN and HF_TOKEN.strip()) # New flag
17
 
18
- hf_inference_image_client = None # Separate client instance for image tasks if needed, or reuse
19
-
20
- class ImageGenResponse: # Keep this class
21
- def __init__(self, image: Image.Image = None, image_url: str = None, error: str = None, success: bool = True, provider: str = "unknown", model_id_used: str = None):
22
  self.image, self.image_url, self.error, self.success, self.provider, self.model_id_used = \
23
  image, image_url, error, success, provider, model_id_used
24
 
25
- def initialize_image_llms():
26
- global STABILITY_API_CONFIGURED, OPENAI_DALLE_CONFIGURED, HF_IMAGE_API_CONFIGURED, hf_inference_image_client
 
 
 
27
 
28
- print("INFO: image_services.py - Initializing Image Generation services...")
29
- # Stability AI (as before)
30
- if STABILITY_API_KEY and STABILITY_API_KEY.strip():
31
- STABILITY_API_CONFIGURED = True
32
- print("SUCCESS: image_services.py - Stability AI API Key detected.")
33
- else:
34
- STABILITY_API_CONFIGURED = False
35
- print("WARNING: image_services.py - STORYVERSE_STABILITY_API_KEY not found. Stability AI disabled.")
36
-
37
- # OpenAI DALL-E (as before)
38
- if OPENAI_API_KEY and OPENAI_API_KEY.strip():
39
- OPENAI_DALLE_CONFIGURED = True
40
- print("SUCCESS: image_services.py - OpenAI API Key detected (for DALL-E).")
41
- else:
42
- OPENAI_DALLE_CONFIGURED = False
43
- print("WARNING: image_services.py - STORYVERSE_OPENAI_API_KEY not found. DALL-E disabled.")
44
-
45
- # Hugging Face Image Models
46
  if HF_TOKEN and HF_TOKEN.strip():
47
  try:
48
- # You can use the same token for text and image clients if the permissions cover it
49
- # Or, if you want a dedicated client for image tasks (maybe different default model types)
50
- hf_inference_image_client = InferenceClient(token=HF_TOKEN) # Reusing the token
 
51
  HF_IMAGE_API_CONFIGURED = True
52
  print("SUCCESS: image_services.py - Hugging Face InferenceClient (for images) ready.")
53
  except Exception as e:
54
  HF_IMAGE_API_CONFIGURED = False
55
- print(f"ERROR: image_services.py - Failed to initialize HF InferenceClient for images: {e}")
 
56
  else:
57
  HF_IMAGE_API_CONFIGURED = False
58
- print("WARNING: image_services.py - STORYVERSE_HF_TOKEN not found. HF Image models disabled.")
59
 
60
- print("INFO: image_services.py - Image Service Init complete.")
61
 
 
 
 
62
 
63
- # --- Stability AI (Keep as is, it will just be disabled if no key) ---
64
- def generate_image_stabilityai(prompt: str, ...) -> ImageGenResponse:
65
- # ... (your existing generate_image_stabilityai function)
66
- if not STABILITY_API_CONFIGURED: return ImageGenResponse(error="Stability AI API key not configured.", success=False, provider="StabilityAI")
67
- # ... rest of the function
68
- api_host = os.getenv('API_HOST', 'https://api.stability.ai'); request_url = f"{api_host}/v1/generation/{engine_id if 'engine_id' in locals() else 'stable-diffusion-xl-1024-v1-0'}/text-to-image"; payload = {"text_prompts": [{"text": prompt}], "steps": steps if 'steps' in locals() else 30}; headers = {"Authorization": f"Bearer {STABILITY_API_KEY}", "Accept":"application/json", "Content-Type":"application/json"}; try: response = requests.post(request_url, headers=headers, json=payload, timeout=60); response.raise_for_status(); artifacts = response.json().get("artifacts"); img_data = base64.b64decode(artifacts[0]["base64"]); img = Image.open(BytesIO(img_data)); return ImageGenResponse(image=img, provider="StabilityAI")
69
- except Exception as e: return ImageGenResponse(error=f"Stability AI Error: {str(e)}", success=False, provider="StabilityAI", raw_response=e)
70
-
71
-
72
- # --- DALL-E (Keep as is, conceptual) ---
73
- def generate_image_dalle(prompt: str, ...) -> ImageGenResponse:
74
- # ... (your existing generate_image_dalle function)
75
- if not OPENAI_DALLE_CONFIGURED: return ImageGenResponse(error="OpenAI DALL-E API key not configured.", success=False, provider="DALL-E")
76
- dummy_image = Image.new('RGB', (512, 512), color = 'skyblue'); return ImageGenResponse(image=dummy_image, provider="DALL-E (Simulated)")
77
-
78
-
79
- # --- NEW: Hugging Face Image Model via Inference API ---
80
  def generate_image_hf_model(prompt: str,
81
- model_id: str = "stabilityai/stable-diffusion-xl-base-1.0", # A popular choice
82
- # model_id: str = "runwayml/stable-diffusion-v1-5", # Another option
83
- # model_id: str = "prompthero/openjourney", # Midjourney-like style
84
  negative_prompt: str = None,
85
- height: int = 768, # Adjust for different models
86
- width: int = 768,
87
  num_inference_steps: int = 25,
88
- guidance_scale: float = 7.5
89
- ) -> ImageGenResponse:
90
- global hf_inference_image_client # Use the initialized client
91
- if not HF_IMAGE_API_CONFIGURED or not hf_inference_image_client:
92
- return ImageGenResponse(error="Hugging Face API (for images) not configured.", success=False, provider="HF Image API", model_id_used=model_id)
93
 
94
  params = {
95
  "negative_prompt": negative_prompt,
@@ -98,30 +66,28 @@ def generate_image_hf_model(prompt: str,
98
  "num_inference_steps": num_inference_steps,
99
  "guidance_scale": guidance_scale
100
  }
101
- # Remove None params as some models might not like them
102
- params = {k: v for k, v in params.items() if v is not None}
103
 
104
- print(f"DEBUG: image_services.py - Calling HF Image API ({model_id}) with prompt: {prompt[:50]}...")
105
  try:
106
- # The `text_to_image` method of InferenceClient returns a PIL Image directly
107
  image_result: Image.Image = hf_inference_image_client.text_to_image(
108
  prompt,
109
  model=model_id,
110
  **params
111
  )
 
 
112
  print(f"DEBUG: image_services.py - HF Image API ({model_id}) image generated successfully.")
113
  return ImageGenResponse(image=image_result, provider="HF Image API", model_id_used=model_id)
114
  except Exception as e:
115
  error_msg = f"HF Image API Error ({model_id}): {type(e).__name__} - {str(e)}"
116
- # Check for common HF API errors
117
- if "Rate limit reached" in str(e):
118
- error_msg += " You may have hit free tier limits."
119
- elif "Model is currently loading" in str(e):
120
- error_msg += " The model might be loading, please try again in a moment."
121
- elif "Authorization header is correct" in str(e) or "401" in str(e):
122
- error_msg += " Issue with your HF_TOKEN authentication."
123
-
124
  print(f"ERROR: image_services.py - {error_msg}")
125
  return ImageGenResponse(error=error_msg, success=False, provider="HF Image API", model_id_used=model_id, raw_response=e)
126
 
127
- print("DEBUG: core.image_services (for StoryVerseWeaver) - Module defined with HF Image support.")
 
1
  # storyverse_weaver/core/image_services.py
2
  import os
3
+ import base64 # Still useful if HF API ever returns b64
 
4
  from io import BytesIO
5
  from PIL import Image
6
+ from huggingface_hub import InferenceClient # Main client for HF models
7
 
8
  # --- API Key Configuration ---
9
+ HF_TOKEN = os.getenv("STORYVERSE_HF_TOKEN") # This is the key we'll use
 
 
10
 
11
+ HF_IMAGE_API_CONFIGURED = False
12
+ hf_inference_image_client = None
 
13
 
14
class ImageGenResponse:
    """Result wrapper returned by every image-generation call in this module.

    Attributes:
        image: Generated PIL image, or None on failure.
        image_url: Optional URL to a hosted image instead of raw pixels.
        error: Human-readable error description, or None on success.
        success: True when generation succeeded.
        provider: Backend name that produced (or failed to produce) the image.
        model_id_used: Model identifier that was invoked, if any.
        raw_response: Underlying exception/response object for debugging.
    """

    def __init__(self, image: Image.Image = None, image_url: str = None,
                 error: str = None, success: bool = True,
                 provider: str = "HF Image API", model_id_used: str = None,
                 raw_response=None):
        # `raw_response` is accepted (default None, backward compatible) because
        # generate_image_hf_model's error path passes `raw_response=e`; without
        # this parameter that call raised TypeError instead of reporting the
        # original failure.
        self.image = image
        self.image_url = image_url
        self.error = error
        self.success = success
        self.provider = provider
        self.model_id_used = model_id_used
        self.raw_response = raw_response
20
 
21
def initialize_image_llms():
    """Set up the Hugging Face image-generation client.

    Re-reads STORYVERSE_HF_TOKEN from the environment, constructs an
    InferenceClient when a non-empty token is present, and records the outcome
    in the module-level flag consumed by is_hf_image_api_ready().
    """
    global HF_IMAGE_API_CONFIGURED, hf_inference_image_client, HF_TOKEN

    # Re-read the token at call time so env vars loaded after import are honored.
    HF_TOKEN = os.getenv("STORYVERSE_HF_TOKEN")
    print("INFO: image_services.py - Initializing Image Generation services (HF Focus)...")

    has_token = bool(HF_TOKEN and HF_TOKEN.strip())
    if not has_token:
        HF_IMAGE_API_CONFIGURED = False
        print("WARNING: image_services.py - STORYVERSE_HF_TOKEN not found or empty. HF Image models disabled.")
    else:
        try:
            hf_inference_image_client = InferenceClient(token=HF_TOKEN)
            HF_IMAGE_API_CONFIGURED = True
            print("SUCCESS: image_services.py - Hugging Face InferenceClient (for images) ready.")
        except Exception as e:
            # Client construction failed (bad token, network, etc.): disable cleanly.
            HF_IMAGE_API_CONFIGURED = False
            print(f"ERROR: image_services.py - Failed to initialize HF InferenceClient for images: {type(e).__name__} - {e}")
            hf_inference_image_client = None

    print(f"INFO: image_services.py - Image Service Init complete. HF Image API Configured: {HF_IMAGE_API_CONFIGURED}")
44
 
45
def is_hf_image_api_ready() -> bool:
    """Report whether the Hugging Face image client was initialized.

    Getter used by app.py so UI code does not reach into module globals
    directly. The original `global` declaration was unnecessary for a
    read-only access and has been removed.
    """
    return HF_IMAGE_API_CONFIGURED
48
 
49
+ # --- Hugging Face Image Model via Inference API ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
def generate_image_hf_model(prompt: str,
                            model_id: str = "stabilityai/stable-diffusion-xl-base-1.0",  # Default popular model
                            negative_prompt: str = None,
                            height: int = 768,   # Common for SDXL
                            width: int = 768,    # Common for SDXL
                            num_inference_steps: int = 25,
                            guidance_scale: float = 7.0  # Lower = more creative, higher = more prompt-adherent
                            ) -> ImageGenResponse:
    """Generate an image with a Hugging Face text-to-image model.

    Args:
        prompt: Text description of the desired image.
        model_id: HF Hub model repo id to invoke.
        negative_prompt: Optional text describing what to avoid.
        height, width: Output dimensions in pixels.
        num_inference_steps: Diffusion step count.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        ImageGenResponse with `image` set on success, or `error`/`success=False`
        on failure. Never raises: all exceptions are folded into the response.
    """
    global hf_inference_image_client
    if not is_hf_image_api_ready() or not hf_inference_image_client:
        return ImageGenResponse(error="Hugging Face API (for images) not configured.",
                                success=False, model_id_used=model_id)

    params = {
        "negative_prompt": negative_prompt,
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    # Drop None-valued options: some endpoints reject explicit nulls.
    params = {k: v for k, v in params.items() if v is not None}

    print(f"DEBUG: image_services.py - Calling HF Image API ({model_id}) with prompt: {prompt[:70]}...")
    try:
        # InferenceClient's text_to_image method returns a PIL Image directly.
        image_result: Image.Image = hf_inference_image_client.text_to_image(
            prompt,
            model=model_id,
            **params
        )
        print(f"DEBUG: image_services.py - HF Image API ({model_id}) image generated successfully.")
        return ImageGenResponse(image=image_result, provider="HF Image API", model_id_used=model_id)
    except Exception as e:
        detail = str(e)  # hoisted: inspected several times below
        error_msg = f"HF Image API Error ({model_id}): {type(e).__name__} - {detail}"
        if "Rate limit reached" in detail:
            error_msg += " You may have hit free tier limits."
        elif "Model is currently loading" in detail or "estimated_time" in detail.lower():
            error_msg += " Model may be loading, try again in a moment."
        elif "Authorization" in detail or "401" in detail:
            error_msg += " Authentication issue with your HF_TOKEN."
        elif "does not seem to support task text-to-image" in detail:
            error_msg = f"Model {model_id} may not support text-to-image or is misconfigured."

        print(f"ERROR: image_services.py - {error_msg}")
        # BUG FIX: the ImageGenResponse defined in this file's __init__ takes no
        # `raw_response` parameter, so the original `raw_response=e` made this
        # handler itself raise TypeError, masking the real API failure. The
        # exception details are already captured in error_msg above.
        return ImageGenResponse(error=error_msg, success=False,
                                provider="HF Image API", model_id_used=model_id)
92
 
93
+ print("DEBUG: core.image_services (HF Focus for StoryVerseWeaver) - Module defined.")