File size: 6,622 Bytes
7c7ef49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from google import genai
from google.genai import types
from typing import Union, List, Generator, Dict, Optional
from PIL import Image
from io import BytesIO
import base64
import requests
import asyncio
import os
from dotenv import load_dotenv
from .category_instructions import get_instruction_for_category
from .category_config import CATEGORY_CONFIGS
# Load settings from a .env file (python-dotenv) before any of them are read.
load_dotenv()

# Module-wide Google GenAI client shared by the async helpers in this file;
# its key comes from the API_KEY environment variable.
client = genai.Client(api_key=os.getenv("API_KEY"))

def bytes_to_base64(data: bytes, with_prefix: bool = True, *, mime_type: str = "image/png") -> str:
    """Encode raw bytes as base64 text, optionally wrapped as a data URI.

    Args:
        data: The raw bytes to encode.
        with_prefix: When true, return a ``data:<mime_type>;base64,...`` URI;
            when false, return only the bare base64 payload.
        mime_type: Media type written into the data-URI prefix. Defaults to
            ``"image/png"``, the value that was previously hard-coded, so
            existing callers are unaffected.

    Returns:
        The base64 text, with or without the data-URI prefix.
    """
    encoded = base64.b64encode(data).decode("utf-8")
    if not with_prefix:
        return encoded
    return f"data:{mime_type};base64,{encoded}"

def decode_base64_image(base64_str: str) -> Image.Image:
    """Turn a base64 string (bare or data-URI form) into a PIL image.

    Input beginning with ``data:image`` is treated as a data URI and only the
    segment after the first comma is decoded; any other input is decoded
    as-is.
    """
    payload = base64_str
    if payload.startswith("data:image"):
        # Drop the "data:image/...;base64," header that precedes the payload.
        payload = payload.split(",")[1]
    raw_bytes = base64.b64decode(payload)
    return Image.open(BytesIO(raw_bytes))

# Line prefixes ("1." .. "5.") that mark the start of a numbered section.
_SECTION_HEADER_PREFIXES = tuple(f"{i}." for i in range(1, 6))


def _parse_text_sections(text: str) -> Dict[str, str]:
    """Split model text into sections keyed by their numbered headers.

    A line starting with "1." to "5." opens a section. Its name is the text
    between the first "." and the first ":" (lower-cased, spaces replaced by
    underscores). Lines seen before any header are collected under
    "overview". Blank lines are skipped.
    """
    collected: Dict[str, List[str]] = {}
    current = "overview"
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if line.startswith(_SECTION_HEADER_PREFIXES):
            title, _, inline_body = line.split('.', 1)[1].partition(':')
            current = title.strip().lower().replace(' ', '_')
            # setdefault keeps lines already gathered under this name (for
            # example preamble filed under "overview") instead of wiping them.
            bucket = collected.setdefault(current, [])
            # Keep text that shares the header line ("1. Overview: ...");
            # it used to be discarded.
            inline_body = inline_body.strip()
            if inline_body:
                bucket.append(inline_body)
        else:
            collected.setdefault(current, []).append(line)

    return {name: '\n'.join(body).strip() for name, body in collected.items()}


def _iter_response_events(response):
    """Yield one event dict per text or inline-data part of the first candidate.

    Text parts become ``{'type': 'text', 'data': <sections dict>}`` and parts
    carrying inline data become ``{'type': 'image', 'data': <data URI>}``.
    A response with no candidates, no content or no parts yields nothing
    rather than raising.
    """
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return
    content = getattr(candidates[0], "content", None)
    parts = getattr(content, "parts", None) or []

    for part in parts:
        text = getattr(part, 'text', None)
        if text is not None:
            try:
                sections = _parse_text_sections(text)
            except Exception:
                # Deliberate best-effort: hand back the raw text if parsing fails.
                yield {'type': 'text', 'data': {'raw_text': text}}
            else:
                yield {'type': 'text', 'data': sections}
            continue

        inline_data = getattr(part, 'inline_data', None)
        if inline_data is not None:
            yield {'type': 'image', 'data': bytes_to_base64(inline_data.data)}


async def async_generate_text_and_image(prompt, category: Optional[str] = None):
    """Generate text and images for a text prompt, yielding them as events.

    The prompt sent to the model is the instruction returned by
    ``get_instruction_for_category(category)`` followed by the user request.
    When ``CATEGORY_CONFIGS`` has an entry for the lower-cased category, its
    "style_guide", "conventions" and "common_elements" values are inserted
    between the two.

    Args:
        prompt: The user's request text.
        category: Optional category name used to pick the instruction and the
            category configuration.

    Yields:
        ``{'type': 'text', 'data': {...}}`` for each text part (sections keyed
        by header name, or ``{'raw_text': ...}`` if parsing fails) and
        ``{'type': 'image', 'data': <base64 data URI>}`` for each inline-data
        part. Nothing is yielded when the response has no usable parts.
    """
    instruction = get_instruction_for_category(category)
    config = CATEGORY_CONFIGS.get(category.lower() if category else "", {})

    if config:
        style_guide = config.get("style_guide", "")
        conventions = "\n- ".join(config.get("conventions", []))
        common_elements = "\n- ".join(config.get("common_elements", []))
        enhanced_prompt = (
            f"{instruction}\n\n"
            f"Style Guide: {style_guide}\n"
            f"Drawing Conventions to Follow:\n- {conventions}\n"
            f"Consider Including These Elements:\n- {common_elements}\n\n"
            f"User Request: {prompt}"
        )
    else:
        enhanced_prompt = f"{instruction}\n\nUser Request: {prompt}"

    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=enhanced_prompt,
        config=types.GenerateContentConfig(
            response_modalities=['TEXT', 'IMAGE']
        )
    )
    for event in _iter_response_events(response):
        yield event


async def async_generate_with_image_input(text: Optional[str], image_path: str, category: Optional[str] = None):
    """Generate text and images from an input image plus optional text.

    Args:
        text: Optional user request. When given it is appended to the category
            instruction; otherwise the instruction is sent on its own.
        image_path: Despite the name, this must be a base64 data URI starting
            with ``data:image/``; it is decoded with ``decode_base64_image``
            and sent alongside the text.
        category: Optional category name passed to
            ``get_instruction_for_category``.

    Yields:
        The same event dicts as ``async_generate_text_and_image``.

    Raises:
        ValueError: If ``image_path`` is not a string beginning with
            ``data:image/``. Because this is an async generator, the error
            surfaces on first iteration, not at call time.
    """
    if not isinstance(image_path, str) or not image_path.startswith("data:image/"):
        raise ValueError("Invalid image input: expected a base64 Data URI starting with 'data:image/'")
    image = decode_base64_image(image_path)

    instruction = get_instruction_for_category(category)
    leading_text = f"{instruction}\n\nUser Request: {text}" if text else instruction
    contents = [leading_text, image]

    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=contents,
        config=types.GenerateContentConfig(
            response_modalities=['TEXT', 'IMAGE']
        )
    )
    for event in _iter_response_events(response):
        yield event