Spaces:
Sleeping
Sleeping
File size: 6,622 Bytes
7c7ef49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from google import genai
from google.genai import types
from typing import Union, List, Generator, Dict, Optional
from PIL import Image
from io import BytesIO
import base64
import requests
import asyncio
import os
from dotenv import load_dotenv
from .category_instructions import get_instruction_for_category
from .category_config import CATEGORY_CONFIGS
# Pull settings from a .env file (when one is found) into os.environ, then
# create the single Gemini client shared by every function in this module.
load_dotenv()
client = genai.Client(api_key=os.environ.get("API_KEY"))
def bytes_to_base64(data: bytes, with_prefix: bool = True, mime_type: str = "image/png") -> str:
    """Encode raw bytes as base64 text, optionally wrapped in a data URI.

    Args:
        data: Raw bytes to encode (for example image bytes returned by the model).
        with_prefix: When True, return a ``data:<mime_type>;base64,...`` URI;
            otherwise return only the bare base64 payload.
        mime_type: MIME type written into the data-URI header. Defaults to
            ``"image/png"`` to preserve the previous behaviour; pass the real
            type (e.g. ``"image/jpeg"``) so consumers decode the payload correctly.

    Returns:
        The base64 string, with or without the data-URI prefix.
    """
    # base64 output is pure ASCII, so decoding it can never fail.
    encoded = base64.b64encode(data).decode("ascii")
    if not with_prefix:
        return encoded
    return f"data:{mime_type};base64,{encoded}"
def decode_base64_image(base64_str: str) -> Image.Image:
    """Decode a base64 string (bare or wrapped in a ``data:`` URI) into a PIL image.

    Args:
        base64_str: Base64-encoded image bytes, optionally prefixed with a
            data-URI header such as ``"data:image/png;base64,"``.

    Returns:
        The image opened with ``PIL.Image.open``. PIL opens lazily, so the
        pixel data is only decoded when it is first accessed.

    Raises:
        ValueError: If a ``data:`` URI has no ``,`` separating the header from
            the payload, or (as ``binascii.Error``, a ``ValueError`` subclass)
            if the payload has invalid base64 padding.
    """
    # Strip any data-URI header. b64decode's default mode silently discards
    # non-alphabet characters, so a header left in place would be decoded
    # as garbage rather than rejected.
    if base64_str.startswith("data:"):
        _header, sep, payload = base64_str.partition(",")
        if not sep:
            raise ValueError("Malformed data URI: missing ',' before the base64 payload")
        base64_str = payload
    image_data = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_data))
# Numbered section headers recognised in model text ("1." .. "5.").
_SECTION_PREFIXES = tuple(f"{i}." for i in range(1, 6))


def _parse_text_sections(text: str) -> Dict[str, str]:
    """Split model text into named sections keyed by numbered header lines.

    A line such as ``"2. Key Features: bold outlines"`` starts a section named
    ``"key_features"``. Any text after the first ``:`` on the header line is
    kept as the section's first line. Lines before the first header go into
    ``"overview"``. Blank lines are dropped. A repeated header appends to the
    existing section instead of discarding it.
    """
    sections: Dict[str, List[str]] = {}
    current = "overview"
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        # "1.5 inch margin" is a decimal number, not a header.
        if line.startswith(_SECTION_PREFIXES) and not line[2:3].isdigit():
            heading, _, remainder = line.split(".", 1)[1].partition(":")
            name = heading.strip().lower().replace(" ", "_")
            if name:
                current = name
                sections.setdefault(current, [])
                remainder = remainder.strip()
                if remainder:
                    sections[current].append(remainder)
                continue
            # An empty heading (e.g. a bare "1.") is not a usable key; keep it as body text.
        sections.setdefault(current, []).append(line)
    return {name: "\n".join(lines).strip() for name, lines in sections.items()}


def _build_enhanced_prompt(instruction: str, config: Dict, prompt) -> str:
    """Combine the category instruction, optional category guidance and the user prompt.

    Guidance headings are only emitted when the config supplies content for
    them, so empty lists do not produce empty bullet sections.
    """
    guidance: List[str] = []
    style_guide = config.get("style_guide", "") if config else ""
    if style_guide:
        guidance.append(f"Style Guide: {style_guide}")
    for title, key in (
        ("Drawing Conventions to Follow", "conventions"),
        ("Consider Including These Elements", "common_elements"),
    ):
        items = config.get(key, []) if config else []
        if items:
            guidance.append(f"{title}:\n- " + "\n- ".join(str(item) for item in items))
    if not guidance:
        return f"{instruction}\n\nUser Request: {prompt}"
    return f"{instruction}\n\n" + "\n".join(guidance) + f"\n\nUser Request: {prompt}"


async def _generate(contents):
    """Call the model named by the MODEL env var, requesting both text and image output."""
    return await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=contents,
        config=types.GenerateContentConfig(response_modalities=['TEXT', 'IMAGE']),
    )


def _iter_response_parts(response):
    """Yield one UI event per text or inline-data part of the first candidate.

    Raises:
        RuntimeError: If the response has no candidate content, for example
            because the prompt was blocked.
    """
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = content.parts if content is not None else None
    if parts is None:
        finish_reason = getattr(candidates[0], "finish_reason", None) if candidates else None
        feedback = getattr(response, "prompt_feedback", None)
        raise RuntimeError(
            f"Model returned no content (finish_reason={finish_reason!r}, "
            f"prompt_feedback={feedback!r})"
        )
    for part in parts:
        if getattr(part, 'text', None) is not None:
            yield {'type': 'text', 'data': _parse_text_sections(part.text)}
        elif getattr(part, 'inline_data', None) is not None:
            # Label the payload with its real MIME type instead of assuming PNG.
            mime_type = getattr(part.inline_data, "mime_type", None) or "image/png"
            encoded = bytes_to_base64(part.inline_data.data, with_prefix=False)
            yield {'type': 'image', 'data': f"data:{mime_type};base64,{encoded}"}


async def async_generate_text_and_image(prompt, category: Optional[str] = None):
    """Generate text and image output for a text prompt.

    Args:
        prompt: The user's request, interpolated into the final prompt.
        category: Optional category name. Selects the instruction via
            ``get_instruction_for_category`` and (case-insensitively) a
            ``CATEGORY_CONFIGS`` entry whose style guide, conventions and common
            elements are added to the prompt.

    Yields:
        ``{'type': 'text', 'data': {section: text}}`` for each text part and
        ``{'type': 'image', 'data': data_uri}`` for each inline-data part, in
        response order.

    Raises:
        RuntimeError: If the model response carries no candidate content.
    """
    instruction = get_instruction_for_category(category)
    config = CATEGORY_CONFIGS.get(category.lower() if category else "", {})
    response = await _generate(_build_enhanced_prompt(instruction, config, prompt))
    # `yield from` is not allowed inside async generators.
    for event in _iter_response_parts(response):
        yield event


async def async_generate_with_image_input(text: Optional[str], image_path: str, category: Optional[str] = None):
    """Generate text and image output from an input image plus optional text.

    Args:
        text: Optional user request. When empty or None, only the category
            instruction is sent alongside the image.
        image_path: The input image as a base64 data URI (``data:image/...``),
            despite the parameter name.
        category: Optional category name used to pick the instruction.

    Yields:
        The same event dicts as ``async_generate_text_and_image``.

    Raises:
        ValueError: If ``image_path`` is not a ``data:image/`` URI string.
        RuntimeError: If the model response carries no candidate content.
    """
    if not isinstance(image_path, str) or not image_path.startswith("data:image/"):
        raise ValueError("Invalid image input: expected a base64 Data URI starting with 'data:image/'")
    image = decode_base64_image(image_path)
    instruction = get_instruction_for_category(category)
    prompt_text = f"{instruction}\n\nUser Request: {text}" if text else instruction
    response = await _generate([prompt_text, image])
    for event in _iter_response_parts(response):
        yield event