baconnier committed
Commit 05f2374 · verified · 1 Parent(s): f9a0b92

Upload 10 files

Files changed (10)
  1. README.md +5 -4
  2. api.py +145 -0
  3. app.py +30 -0
  4. gitattributes +35 -0
  5. models.py +50 -0
  6. prompts.py +186 -0
  7. requirements.txt +11 -0
  8. schemas.py +79 -0
  9. ui.py +480 -0
  10. ui_old.py +346 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Paint
- emoji: 🚀
- colorFrom: indigo
- colorTo: pink
+ title: Prompt Image
+ emoji: 🐨
+ colorFrom: blue
+ colorTo: yellow
  sdk: gradio
  sdk_version: 5.5.0
  app_file: app.py
  pinned: false
+ hf_oauth: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
api.py ADDED
@@ -0,0 +1,146 @@
+ import json
+ import logging
+ from openai import OpenAI
+ from typing import Dict, Any, Optional
+ import gradio as gr
+ from prompts import PROMPT_ANALYZER_TEMPLATE
+ from schemas import create_error_response
+ import time
+
+ logger = logging.getLogger(__name__)
+
+ FALLBACK_MODELS = [
+     "mixtral-8x7b-32768",
+     "llama-3.1-70b-versatile",
+     "llama-3.1-8b-instant",
+     "llama3-70b-8192",
+     "llama3-8b-8192"
+ ]
+
+ class ModelManager:
+     def __init__(self):
+         self.current_model_index = 0
+         self.max_retries = len(FALLBACK_MODELS)
+
+     @property
+     def current_model(self) -> str:
+         return FALLBACK_MODELS[self.current_model_index]
+
+     def next_model(self) -> str:
+         self.current_model_index = (self.current_model_index + 1) % len(FALLBACK_MODELS)
+         logger.info(f"Switching to model: {self.current_model}")
+         return self.current_model
+
+ class PromptEnhancementAPI:
+     def __init__(self, api_key: str, base_url: Optional[str] = None):
+         self.client = OpenAI(
+             api_key=api_key,
+             base_url=base_url or "https://api.groq.com/openai/v1"
+         )
+         self.model_manager = ModelManager()
+
+     def _try_parse_json(self, content: str, retries: int = 0) -> Dict[str, Any]:
+         try:
+             result = json.loads(content.strip())
+             if not isinstance(result, dict):
+                 raise ValueError("Response is not a valid JSON object")
+             return result
+         except (json.JSONDecodeError, ValueError) as e:
+             if retries < self.model_manager.max_retries - 1:
+                 logger.warning(f"JSON parsing failed with model {self.model_manager.current_model}. Switching models...")
+                 self.model_manager.next_model()
+                 raise e
+             logger.error(f"JSON parsing failed with all models: {str(e)}")
+             raise
+
+     def generate_enhancement(self, system_prompt: str, user_prompt: str, user_directive: str = "", state: Optional[Dict] = None) -> Dict[str, Any]:
+         retries = 0
+         last_error = None
+
+         while retries < self.model_manager.max_retries:
+             try:
+                 messages = [
+                     {"role": "system", "content": system_prompt},
+                     {"role": "user", "content": user_prompt}
+                 ]
+
+                 if user_directive:
+                     messages.append({"role": "user", "content": f"User directive: {user_directive}"})
+
+                 if state:
+                     messages.append({
+                         "role": "assistant",
+                         "content": json.dumps(state)
+                     })
+
+                 response = self.client.chat.completions.create(
+                     model=self.model_manager.current_model,
+                     messages=messages,
+                     temperature=0.7,
+                     max_tokens=4000,
+                     response_format={"type": "json_object"}
+                 )
+
+                 result = self._try_parse_json(response.choices[0].message.content, retries)
+                 return result
+
+             except (json.JSONDecodeError, ValueError) as e:
+                 last_error = e
+                 retries += 1
+                 if retries < self.model_manager.max_retries:
+                     logger.warning(f"Attempt {retries} failed. Switching models and retrying...")
+                     time.sleep(1)  # Brief pause before retry
+                     continue
+                 break
+
+             except Exception as e:
+                 logger.error(f"API error: {str(e)}")
+                 if "rate limit" in str(e).lower():
+                     if retries < self.model_manager.max_retries - 1:
+                         self.model_manager.next_model()
+                         retries += 1
+                         time.sleep(1)
+                         continue
+                 raise gr.Error(f"API request failed: {str(e)}")
+
+         logger.error(f"All models failed to generate valid JSON: {str(last_error)}")
+         return create_error_response(user_prompt, user_directive)
+
+ class PromptEnhancementSystem:
+     def __init__(self, api_key: str, base_url: Optional[str] = None):
+         self.api = PromptEnhancementAPI(api_key, base_url)
+         self.current_state = None
+         self.history = []
+
+     def start_session(self, prompt: str, user_directive: str = "") -> Dict[str, Any]:
+         formatted_system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
+             input_prompt=prompt,
+             user_directive=user_directive
+         )
+
+         result = self.api.generate_enhancement(
+             system_prompt=formatted_system_prompt,
+             user_prompt=prompt,
+             user_directive=user_directive
+         )
+
+         self.current_state = result
+         self.history = [result]
+         return result
+
+     def apply_enhancement(self, choice: str, user_directive: str = "") -> Dict[str, Any]:
+         formatted_system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
+             input_prompt=choice,
+             user_directive=user_directive
+         )
+
+         result = self.api.generate_enhancement(
+             system_prompt=formatted_system_prompt,
+             user_prompt=choice,
+             user_directive=user_directive,
+             state=self.current_state
+         )
+
+         self.current_state = result
+         self.history.append(result)
+         return result
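A minimal usage sketch for PromptEnhancementSystem, assuming GROQ_API_KEY is set in the environment; the example prompt and directive are placeholders:

import os
from api import PromptEnhancementSystem

system = PromptEnhancementSystem(api_key=os.environ["GROQ_API_KEY"])

# First pass: analyze the raw prompt and collect the proposed enhancements.
state = system.start_session("a cat on a windowsill", user_directive="watercolor style")
for axis in state.get("improvement_axes", []):
    print(axis.get("focus_area"), "->", axis.get("enhanced_prompt"))

# Iterate: feed a chosen enhanced prompt back in; results accumulate in system.history.
if state.get("improvement_axes"):
    state = system.apply_enhancement(state["improvement_axes"][0]["enhanced_prompt"])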
app.py ADDED
@@ -0,0 +1,30 @@
+ import os
+ import logging
+ from ui import create_interface
+ from huggingface_hub import login
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Check required environment variables
+ required_vars = ["HF_TOKEN", "GROQ_API_KEY"]
+ missing_vars = [var for var in required_vars if not os.getenv(var)]
+ if missing_vars:
+     raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
+
+ # Hugging Face login
+ try:
+     login(token=os.getenv("HF_TOKEN"))
+     logger.info("Successfully logged in to Hugging Face")
+ except Exception as e:
+     logger.error(f"Failed to log in to Hugging Face: {str(e)}")
+     raise
+
+ if __name__ == "__main__":
+     try:
+         demo = create_interface()
+         demo.queue(max_size=5).launch()
+     except Exception as e:
+         logger.error(f"Application startup error: {str(e)}")
+         raise
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
models.py ADDED
@@ -0,0 +1,50 @@
+ from pydantic import BaseModel, Field, field_validator
+ from typing import List, Dict, Any
+
+ class ProgressMeters(BaseModel):
+     technical_detail: int = Field(default=0, ge=0, le=100)
+     artistic_style: int = Field(default=0, ge=0, le=100)
+     composition: int = Field(default=0, ge=0, le=100)
+     context: int = Field(default=0, ge=0, le=100)
+
+ class SubjectAnalysis(BaseModel):
+     clarity: int = Field(default=0, ge=0, le=100)
+     details_present: List[str] = []
+     details_missing: List[str] = []
+
+ class StyleEvaluation(BaseModel):
+     defined_elements: List[str] = []
+     missing_elements: List[str] = []
+     style_score: int = Field(default=0, ge=0, le=100)
+
+ class TechnicalAssessment(BaseModel):
+     specified_elements: List[str] = []
+     missing_elements: List[str] = []
+     technical_score: int = Field(default=0, ge=0, le=100)
+
+ class CompositionReview(BaseModel):
+     strengths: List[str] = []
+     weaknesses: List[str] = []
+     composition_score: int = Field(default=0, ge=0, le=100)
+
+ class InitialAnalysis(BaseModel):
+     subject_analysis: SubjectAnalysis = SubjectAnalysis()
+     style_evaluation: StyleEvaluation = StyleEvaluation()
+     technical_assessment: TechnicalAssessment = TechnicalAssessment()
+     composition_review: CompositionReview = CompositionReview()
+
+ class EnhancedVersion(BaseModel):
+     focus_area: str = ""
+     enhanced_prompt: str = ""
+     improvement_score: int = Field(default=0, ge=0, le=100)
+
+ class PromptAnalysis(BaseModel):
+     initial_analysis: InitialAnalysis = InitialAnalysis()
+     enhanced_versions: List[EnhancedVersion] = []
+     session_state: Dict[str, Any] = {}
+
+     @field_validator('enhanced_versions', mode='before')
+     def validate_enhanced_versions(cls, v):
+         if not isinstance(v, list):
+             return []
+         return v
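models.py is a standalone set of Pydantic models; a short sketch of how its mode='before' validator coerces a malformed payload (the sample data is illustrative):

from models import PromptAnalysis, EnhancedVersion

# enhanced_versions is not a list, so the mode='before' validator
# replaces it with [] instead of raising a validation error.
analysis = PromptAnalysis.model_validate({
    "initial_analysis": {"subject_analysis": {"clarity": 80}},
    "enhanced_versions": "not-a-list",
})
print(analysis.enhanced_versions)  # []

# Scores outside 0-100 are rejected by the Field(ge=0, le=100) constraints.
try:
    EnhancedVersion(improvement_score=150)
except ValueError as e:
    print("validation error:", e)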
prompts.py ADDED
@@ -0,0 +1,186 @@
+ PROMPT_ANALYZER_TEMPLATE = '''You are a Prompt Enhancement Specialist for image generation. Your task is to analyze a given prompt and dynamically determine the most relevant improvement axes based on the current analysis, while ensuring compliance with specific user directives.
+
+ For the following prompt and user directive:
+ <input_prompt>
+ {input_prompt}
+ </input_prompt>
+
+ <user_directive>
+ {user_directive}
+ </user_directive>
+
+ 1. Initial Analysis (Comprehensive evaluation of current elements):
+
+ Subject Analysis:
+ - Main subject identification and clarity
+ - Subject details and characteristics
+ - Secondary elements and their relationship
+ - Scale and proportions
+
+ Style Elements:
+ - Artistic style presence/absence
+ - Medium specification
+ - Art movement references
+ - Artist influences
+ - Historical or cultural context
+
+ Technical Specifications:
+ - Lighting details
+ - Color palette
+ - Texture information
+ - Resolution indicators
+ - Camera angle/perspective
+ - Shot type/framing
+
+ Compositional Elements:
+ - Spatial arrangement
+ - Foreground/background balance
+ - Rule of thirds consideration
+ - Leading lines
+ - Focal point clarity
+
+ Environmental Context:
+ - Setting details
+ - Time period
+ - Weather/atmospheric conditions
+ - Environmental interaction
+ - Scene depth
+
+ Mood and Atmosphere:
+ - Emotional tone
+ - Atmospheric qualities
+ - Dynamic vs. static elements
+ - Story/narrative elements
+ - Symbolic elements
+
+ 2. Limitations Assessment:
+ - Missing critical details
+ - Ambiguous elements
+ - Technical omissions
+ - Stylistic gaps
+ - Compositional weaknesses
+ - Context deficiencies
+ - Undefined mood/atmosphere areas
+
+ 3. Improvement Axes (Select the 4 most impactful):
+ For each axis, consider:
+ - Impact on visual outcome
+ - Technical feasibility
+ - AI model capabilities
+ - Balance between specificity and creativity
+ - Enhancement of original vision
+ - Visual interest addition
+ - Technical precision improvement
+ - User directive compliance and integration
+ - ...
+
+ 4. Enhancement Strategy:
+ For each improvement axis:
+ - Specific terminology to add
+ - Technical parameters to include
+ - Stylistic elements to incorporate
+ - Compositional guidance
+ - Atmospheric elements
+ - Reference points (artists, styles, techniques)
+ - User directive implementation methods
+
+ Now provide your analysis in this JSON structure:
+
+ {{
+     "initial_analysis": {{
+         "initial_prompt": {input_prompt},
+         "user_directive": {user_directive},
+         "directive_impact_assessment": {{
+             "feasibility": string,
+             "integration_approach": string,
+             "potential_conflicts": [string],
+             "resolution_strategy": string
+         }},
+         "subject_analysis": {{
+             "score": integer(0-100),
+             "strengths": [string],
+             "weaknesses": [string]
+         }},
+         "style_evaluation": {{
+             "score": integer(0-100),
+             "strengths": [string],
+             "weaknesses": [string]
+         }},
+         "technical_assessment": {{
+             "score": integer(0-100),
+             "strengths": [string],
+             "weaknesses": [string]
+         }},
+         "composition_review": {{
+             "score": integer(0-100),
+             "strengths": [string],
+             "weaknesses": [string]
+         }},
+         "context_evaluation": {{
+             "score": integer(0-100),
+             "strengths": [string],
+             "weaknesses": [string]
+         }},
+         "mood_assessment": {{
+             "score": integer(0-100),
+             "strengths": [string],
+             "weaknesses": [string]
+         }}
+     }},
+     "improvement_axes": [
+         {{
+             "axis_name": string,
+             "focus_area": string,
+             "version": integer,
+             "score": integer(0-100),
+             "current_state": string,
+             "directive_alignment": string,
+             "recommended_additions": [string],
+             "expected_impact": string,
+             "technical_considerations": [string],
+             "enhanced_prompt": string,
+             "expected_improvements": [string]
+         }}
+     ],
+     "technical_recommendations": {{
+         "style_keywords": [string],
+         "composition_tips": [string],
+         "negative_prompt_suggestions": [string],
+         "directive_specific_adjustments": [string]
+     }}
+ }}
+
+ Guidelines for Dynamic Enhancement:
+ 1. Analyze current scores to identify the weakest areas
+ 2. Ensure all improvements align with the user directive (if provided)
+ 3. Consider the improvement potential of each axis
+ 4. Select the 4 most impactful axes based on:
+    - User directive compliance (highest priority if provided)
+    - Current analysis scores
+    - Previous improvements
+    - Remaining potential
+    - Overall image quality goals
+ 5. Generate targeted enhancements for the selected axes
+
+ Remember to:
+ - Prioritize user directive implementation while maintaining prompt integrity
+ - Keep improvements relevant to image generation
+ - Maintain the original intent of the prompt
+ - Be specific and detailed in suggestions
+ - Ensure each enhanced version builds on the original
+ - Focus on visual elements that AI image generators understand
+ - Consider technical aspects like lighting, composition, and style
+ - Add specific artistic references when relevant
+ - Balance detail with creativity
+ - Consider AI model capabilities and limitations
+ - Provide practical composition guidance
+ - Include relevant style keywords
+ - Specify negative prompt elements
+
+ Each iteration should:
+ 1. Verify user directive compliance
+ 2. Reassess the current state
+ 3. Identify new priority areas
+ 4. Generate fresh improvement approaches
+ 5. Build upon previous enhancements while maintaining user directive alignment
+ 6. Maintain coherence with the original concept'''
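Since the template is filled with str.format, the doubled braces in the JSON skeleton are escapes that render as single braces; only {input_prompt} and {user_directive} are substituted. A quick sketch (placeholder values):

from prompts import PROMPT_ANALYZER_TEMPLATE

system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
    input_prompt="a foggy harbor at dawn",    # placeholder prompt
    user_directive="cinematic lighting",      # placeholder directive
)
assert '"initial_analysis": {' in system_prompt  # escaped braces come out single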
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ accelerate
+ git+https://github.com/huggingface/diffusers.git
+ invisible_watermark
+ torch
+ transformers==4.42.4
+ xformers
+ sentencepiece
+ gradio==5.5.0
+ numpy==1.24.3
+ openai==1.3.0
+ huggingface-hub>=0.19.0
schemas.py ADDED
@@ -0,0 +1,79 @@
+ from typing import List, Dict, Any
+ from pydantic import BaseModel, Field, ConfigDict
+
+ class DirectiveImpactAssessment(BaseModel):
+     feasibility: str = Field(default="Not assessed")
+     integration_approach: str = Field(default="Not determined")
+     potential_conflicts: List[str] = Field(default_factory=lambda: ["None identified"])
+     resolution_strategy: str = Field(default="Not required")
+
+ class AnalysisScore(BaseModel):
+     score: int = Field(default=0, ge=0, le=100)
+     strengths: List[str] = Field(default_factory=lambda: ["Not analyzed"])
+     weaknesses: List[str] = Field(default_factory=lambda: ["Not analyzed"])
+
+ class ImprovementAxis(BaseModel):
+     axis_name: str = Field(default="Default")
+     focus_area: str = Field(default="Not specified")
+     version: int = Field(default=1)
+     score: int = Field(default=0, ge=0, le=100)
+     current_state: str = Field(default="Not evaluated")
+     directive_alignment: str = Field(default="Not aligned")
+     recommended_additions: List[str] = Field(default_factory=lambda: ["No recommendations"])
+     expected_impact: str = Field(default="Not determined")
+     technical_considerations: List[str] = Field(default_factory=lambda: ["None specified"])
+     enhanced_prompt: str = Field(default="")
+     expected_improvements: List[str] = Field(default_factory=lambda: ["None specified"])
+
+ class TechnicalRecommendations(BaseModel):
+     style_keywords: List[str] = Field(default_factory=lambda: ["None"])
+     composition_tips: List[str] = Field(default_factory=lambda: ["None"])
+     negative_prompt_suggestions: List[str] = Field(default_factory=lambda: ["None"])
+     directive_specific_adjustments: List[str] = Field(default_factory=lambda: ["None"])
+
+ class InitialAnalysis(BaseModel):
+     initial_prompt: str
+     user_directive: str = Field(default="")
+     directive_impact_assessment: DirectiveImpactAssessment = Field(default_factory=DirectiveImpactAssessment)
+     subject_analysis: AnalysisScore = Field(default_factory=AnalysisScore)
+     style_evaluation: AnalysisScore = Field(default_factory=AnalysisScore)
+     technical_assessment: AnalysisScore = Field(default_factory=AnalysisScore)
+     composition_review: AnalysisScore = Field(default_factory=AnalysisScore)
+     context_evaluation: AnalysisScore = Field(default_factory=AnalysisScore)
+     mood_assessment: AnalysisScore = Field(default_factory=AnalysisScore)
+
+ class APIResponse(BaseModel):
+     model_config = ConfigDict(populate_by_name=True)
+     initial_analysis: InitialAnalysis
+     improvement_axes: List[ImprovementAxis] = Field(default_factory=list)
+     technical_recommendations: TechnicalRecommendations = Field(default_factory=TechnicalRecommendations)
+
+ def create_error_response(user_prompt: str, user_directive: str = "") -> Dict[str, Any]:
+     """Create a standardized error response that complies with the APIResponse model."""
+     return APIResponse(
+         initial_analysis=InitialAnalysis(
+             initial_prompt=user_prompt,
+             user_directive=user_directive
+         ),
+         improvement_axes=[
+             ImprovementAxis(
+                 axis_name="Error",
+                 focus_area="Error occurred",
+                 version=1,
+                 score=0,
+                 current_state="Failed",
+                 directive_alignment="Failed to assess",
+                 recommended_additions=["Error processing prompt"],
+                 expected_impact="None",
+                 technical_considerations=["Error occurred"],
+                 enhanced_prompt=user_prompt,
+                 expected_improvements=["Error processing prompt"]
+             )
+         ],
+         technical_recommendations=TechnicalRecommendations(
+             style_keywords=["Error"],
+             composition_tips=["Error"],
+             negative_prompt_suggestions=["Error"],
+             directive_specific_adjustments=["Error"]
+         )
+     ).model_dump()
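A sketch of validating a parsed LLM reply against APIResponse, with create_error_response as the fallback; the validate_reply helper is illustrative, not part of this file:

from pydantic import ValidationError
from schemas import APIResponse, create_error_response

def validate_reply(reply: dict, user_prompt: str, user_directive: str = "") -> dict:
    """Return a schema-conforming dict, falling back to the standard error payload."""
    try:
        return APIResponse.model_validate(reply).model_dump()
    except ValidationError:
        return create_error_response(user_prompt, user_directive)

# An empty reply fails validation (initial_analysis is required) and falls back.
print(validate_reply({}, "a red bicycle")["improvement_axes"][0]["axis_name"])  # Error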
ui.py ADDED
@@ -0,0 +1,479 @@
+ import spaces
+ import os
+ import gradio as gr
+ import random
+ import torch
+ import logging
+ import numpy as np
+ from typing import Dict, Any, List
+ from diffusers import DiffusionPipeline
+ from api import PromptEnhancementSystem
+
+ # Constants
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 2048
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ MODEL_ID = "black-forest-labs/FLUX.1-schnell"
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ print(f"Using device: {DEVICE}")
+ logger = logging.getLogger(__name__)
+
+ # Initialize model
+ try:
+     print("Loading model...")
+     pipe = DiffusionPipeline.from_pretrained(
+         MODEL_ID,
+         torch_dtype=DTYPE
+     ).to(DEVICE)
+     print("Model loaded successfully")
+     logger.info("Model loaded successfully")
+ except Exception as e:
+     print(f"Failed to load model: {str(e)}")
+     logger.error(f"Failed to load model: {str(e)}")
+     raise
+
+ @spaces.GPU()
+ def generate_multiple_images_batch(
+     improvement_axes,
+     current_gallery,
+     seed=42,
+     randomize_seed=False,
+     width=512,
+     height=512,
+     num_inference_steps=4,
+     current_prompt="",
+     initial_prompt="",
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     try:
+         # Use current_prompt if not empty, otherwise fall back to initial_prompt
+         input_prompt = current_prompt if current_prompt.strip() else initial_prompt
+
+         # Extract prompts from the improvement axes, or use the input prompt if there are no axes
+         prompts = [axis["enhanced_prompt"] for axis in improvement_axes if axis.get("enhanced_prompt")]
+         if not prompts and input_prompt:
+             prompts = [input_prompt]
+
+         if not prompts:
+             return [None] * 4 + [current_gallery] + [seed]
+
+         if randomize_seed:
+             current_seed = random.randint(0, MAX_SEED)
+         else:
+             current_seed = seed
+
+         print(f"Generating images with prompt: {input_prompt}")
+         print(f"Using seed: {current_seed}")
+
+         # Generate images with the selected prompts
+         generator = torch.Generator().manual_seed(current_seed)
+         images = pipe(
+             prompt=prompts,
+             width=width,
+             height=height,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             guidance_scale=0.0
+         ).images
+
+         # Pad with None if we have fewer than 4 images
+         while len(images) < 4:
+             images.append(None)
+
+         # Update the gallery with the new images
+         current_gallery = current_gallery or []
+         new_gallery = current_gallery + [(img, f"Prompt: {prompt}") for img, prompt in zip(images, prompts) if img is not None]
+
+         print("All images generated successfully")
+         return images[:4] + [new_gallery] + [current_seed]
+
+     except Exception as e:
+         print(f"Image generation error: {str(e)}")
+         logger.error(f"Image generation error: {str(e)}")
+         raise
+
+ def handle_image_select(evt: gr.SelectData, improvement_axes_data):
+     try:
+         if improvement_axes_data and isinstance(improvement_axes_data, list):
+             selected_index = evt.index[1] if isinstance(evt.index, tuple) else evt.index
+             if selected_index < len(improvement_axes_data):
+                 selected_prompt = improvement_axes_data[selected_index].get("enhanced_prompt", "")
+                 return selected_prompt
+         return ""
+     except Exception as e:
+         print(f"Error in handle_image_select: {str(e)}")
+         return ""
+
+ def handle_gallery_select(evt: gr.SelectData, gallery_data):
+     try:
+         if gallery_data and isinstance(evt.index, int) and evt.index < len(gallery_data):
+             image, prompt = gallery_data[evt.index]
+             # Remove the "Prompt: " prefix if it exists
+             prompt = prompt.replace("Prompt: ", "") if prompt else ""
+             return {"prompt": prompt}, prompt
+         return None, ""
+     except Exception as e:
+         print(f"Error in handle_gallery_select: {str(e)}")
+         return None, ""
+
+ def clear_gallery():
+     return [], None, None, None, None  # Empty gallery plus clears for the four image slots
+
+ def zip_gallery_images(gallery):
+     try:
+         if not gallery:
+             return None
+
+         import io
+         import zipfile
+         from datetime import datetime
+         import numpy as np
+         from PIL import Image
+
+         # Create the zip file in memory
+         zip_buffer = io.BytesIO()
+         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+         filename = f"gallery_images_{timestamp}.zip"
+
+         with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
+             for i, (img_data, prompt) in enumerate(gallery):
+                 try:
+                     if img_data is not None:
+                         # Convert numpy array to PIL Image if needed
+                         if isinstance(img_data, np.ndarray):
+                             img = Image.fromarray(np.uint8(img_data))
+                         elif isinstance(img_data, Image.Image):
+                             img = img_data
+                         else:
+                             print(f"Skipping image {i}: invalid type {type(img_data)}")
+                             continue
+
+                         # Save the image to bytes
+                         img_buffer = io.BytesIO()
+                         img.save(img_buffer, format='PNG')
+                         img_buffer.seek(0)
+
+                         # Build a filename from the prompt
+                         safe_prompt = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
+                         img_filename = f"image_{i+1}_{safe_prompt}.png"
+
+                         # Add to the zip
+                         zip_file.writestr(img_filename, img_buffer.getvalue())
+                 except Exception as img_error:
+                     print(f"Error processing image {i}: {str(img_error)}")
+                     continue
+
+         # Prepare the zip for download
+         zip_buffer.seek(0)
+
+         # Return the file data and name
+         return {
+             "name": filename,
+             "data": zip_buffer.getvalue()
+         }
+
+     except Exception as e:
+         print(f"Error creating zip: {str(e)}")
+         return None
+
+
+ def create_interface():
+     print("Creating interface...")
+     api_key = os.getenv("GROQ_API_KEY")
+     base_url = os.getenv("API_BASE_URL")
+
+     if not api_key:
+         print("GROQ_API_KEY not found in environment variables")
+         raise ValueError("GROQ_API_KEY not found in environment variables")
+
+     system = PromptEnhancementSystem(api_key, base_url)
+     print("PromptEnhancementSystem initialized")
+
+     def update_interface(prompt, user_directive):
+         try:
+             print(f"\n=== Processing prompt: {prompt}")
+             print(f"User directive: {user_directive}")
+             state = system.start_session(prompt, user_directive)
+             improvement_axes = state.get("improvement_axes", [])
+             initial_analysis = state.get("initial_analysis", {})
+             enhanced_prompt = ""
+             if improvement_axes:
+                 enhanced_prompt = improvement_axes[0].get("enhanced_prompt", prompt)
+
+             button_updates = []
+             for i in range(4):
+                 if i < len(improvement_axes):
+                     focus_area = improvement_axes[i].get("focus_area", f"Option {i+1}")
+                     button_updates.append(gr.update(visible=True, value=focus_area))
+                 else:
+                     button_updates.append(gr.update(visible=False))
+
+             return [prompt, enhanced_prompt] + [
+                 initial_analysis.get(key, {}) for key in [
+                     "subject_analysis",
+                     "style_evaluation",
+                     "technical_assessment",
+                     "composition_review",
+                     "context_evaluation",
+                     "mood_assessment"
+                 ]
+             ] + [
+                 improvement_axes,
+                 state.get("technical_recommendations", {}),
+                 state
+             ] + button_updates
+
+         except Exception as e:
+             print(f"Error in update_interface: {str(e)}")
+             logger.error(f"Error in update_interface: {str(e)}")
+             empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
+             return [prompt, prompt] + [empty_analysis] * 6 + [{}, {}, {}] + [gr.update(visible=False)] * 4
+
+     def handle_option_click(option_num, input_prompt, current_text, user_directive):
+         try:
+             print(f"\n=== Processing option {option_num}")
+             state = system.current_state
+             if state and "improvement_axes" in state:
+                 improvement_axes = state["improvement_axes"]
+                 if option_num < len(improvement_axes):
+                     selected_prompt = improvement_axes[option_num]["enhanced_prompt"]
+                     return [
+                         input_prompt,
+                         selected_prompt,
+                         state.get("initial_analysis", {}).get("subject_analysis", {}),
+                         state.get("initial_analysis", {}).get("style_evaluation", {}),
+                         state.get("initial_analysis", {}).get("technical_assessment", {}),
+                         state.get("initial_analysis", {}).get("composition_review", {}),
+                         state.get("initial_analysis", {}).get("context_evaluation", {}),
+                         state.get("initial_analysis", {}).get("mood_assessment", {}),
+                         improvement_axes,
+                         state.get("technical_recommendations", {}),
+                         state
+                     ]
+             return handle_error()
+         except Exception as e:
+             print(f"Error in handle_option_click: {str(e)}")
+             logger.error(f"Error in handle_option_click: {str(e)}")
+             return handle_error()
+
+     def handle_error():
+         empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
+         return ["", "", empty_analysis, empty_analysis, empty_analysis, empty_analysis, empty_analysis, empty_analysis, [], {}, {}]
+
+     with gr.Blocks(
+         title="AI Prompt Enhancement System",
+         theme=gr.themes.Soft(),
+         css="footer {visibility: hidden}"
+     ) as interface:
+         gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")
+
+         with gr.TabItem("Image Generation"):
+             with gr.Row():
+                 input_prompt = gr.Textbox(
+                     label="Initial Prompt",
+                     placeholder="Enter your prompt here...",
+                     lines=3,
+                     scale=1
+                 )
+
+             with gr.Row():
+                 user_directive = gr.Textbox(
+                     label="User Directive",
+                     placeholder="Enter specific requirements...",
+                     lines=2,
+                     scale=1
+                 )
+
+             with gr.Row():
+                 start_btn = gr.Button("Start Enhancement", variant="primary")
+             with gr.Row():
+                 current_prompt = gr.Textbox(
+                     label="Current Prompt",
+                     lines=3,
+                     scale=1,
+                     interactive=True
+                 )
+             with gr.Row():
+                 option_buttons = [gr.Button("", visible=False) for _ in range(4)]
+             with gr.Row():
+                 finalize_btn = gr.Button("Generate Images", variant="primary")
+             with gr.Row():
+                 generated_images = [
+                     gr.Image(
+                         label=f"Image {i+1}",
+                         type="pil",
+                         show_label=False,
+                         height=256,
+                         width=256,
+                         interactive=False,
+                         show_download_button=False,
+                         elem_id=f"image_{i}"
+                     ) for i in range(4)
+                 ]
+
+         with gr.TabItem("Image Gallery"):
+             with gr.Row():
+                 image_gallery = gr.Gallery(
+                     label="Generated Images History",
+                     show_label=False,
+                     columns=4,
+                     rows=None,
+                     height=800,
+                     object_fit="contain"
+                 )
+             with gr.Row():
+                 clear_gallery_btn = gr.Button("Clear Gallery", variant="secondary")
+             with gr.Row():
+                 selected_image_data = gr.JSON(label="Selected Image Data", visible=True)
+                 copy_to_prompt_btn = gr.Button("Copy Prompt to Current", visible=True)
+         with gr.TabItem("Image Generation Settings"):
+             with gr.Row():
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=42
+                 )
+                 randomize_seed = gr.Checkbox(
+                     label="Randomize seed",
+                     value=True
+                 )
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=256,
+                     value=512
+                 )
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=256,
+                     value=512
+                 )
+                 num_inference_steps = gr.Slider(
+                     label="Steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=4
+                 )
+         with gr.TabItem("Initial Analysis"):
+             with gr.Row():
+                 with gr.Column():
+                     subject_analysis = gr.JSON(label="Subject Analysis")
+                 with gr.Column():
+                     style_evaluation = gr.JSON(label="Style Evaluation")
+                 with gr.Column():
+                     technical_assessment = gr.JSON(label="Technical Assessment")
+
+             with gr.Row():
+                 with gr.Column():
+                     composition_review = gr.JSON(label="Composition Review")
+                 with gr.Column():
+                     context_evaluation = gr.JSON(label="Context Evaluation")
+                 with gr.Column():
+                     mood_assessment = gr.JSON(label="Mood Assessment")
+
+         with gr.Accordion("Additional Information", open=False):
+             improvement_axes = gr.JSON(label="Improvement Axes")
+             technical_recommendations = gr.JSON(label="Technical Recommendations")
+             full_llm_response = gr.JSON(label="Full LLM Response")
+
+         # Add event handlers
+         for img in generated_images:
+             img.select(
+                 fn=handle_image_select,
+                 inputs=[improvement_axes],
+                 outputs=[current_prompt],
+                 show_progress=False
+             )
+
+         start_btn.click(
+             update_interface,
+             inputs=[input_prompt, user_directive],
+             outputs=[
+                 input_prompt,
+                 current_prompt,
+                 subject_analysis,
+                 style_evaluation,
+                 technical_assessment,
+                 composition_review,
+                 context_evaluation,
+                 mood_assessment,
+                 improvement_axes,
+                 technical_recommendations,
+                 full_llm_response
+             ] + option_buttons
+         )
+
+         for i, btn in enumerate(option_buttons):
+             btn.click(
+                 handle_option_click,
+                 inputs=[
+                     gr.Slider(value=i, visible=False),  # hidden component that carries this button's index
+                     input_prompt,
+                     current_prompt,
+                     user_directive
+                 ],
+                 outputs=[
+                     input_prompt,
+                     current_prompt,
+                     subject_analysis,
+                     style_evaluation,
+                     technical_assessment,
+                     composition_review,
+                     context_evaluation,
+                     mood_assessment,
+                     improvement_axes,
+                     technical_recommendations,
+                     full_llm_response
+                 ]
+             )
+
+         finalize_btn.click(
+             generate_multiple_images_batch,
+             inputs=[
+                 improvement_axes,
+                 image_gallery,
+                 seed,
+                 randomize_seed,
+                 width,
+                 height,
+                 num_inference_steps,
+                 current_prompt,
+                 input_prompt
+             ],
+             outputs=generated_images + [image_gallery] + [seed]
+         )
+
+         clear_gallery_btn.click(
+             clear_gallery,
+             inputs=[],
+             outputs=[image_gallery] + generated_images
+         )
+
+         # Add gallery selection handler
+         image_gallery.select(
+             fn=handle_gallery_select,
+             inputs=[image_gallery],
+             outputs=[selected_image_data, current_prompt]
+         )
+
+         # Copy button handler with a null check on the selected data
+         copy_to_prompt_btn.click(
+             lambda x: x["prompt"] if x and isinstance(x, dict) and "prompt" in x else "",
+             inputs=[selected_image_data],
+             outputs=[current_prompt]
+         )
+     print("Interface setup complete")
+     return interface
+
+ if __name__ == "__main__":
+     interface = create_interface()
+     interface.launch()
ui_old.py ADDED
@@ -0,0 +1,346 @@
+ import spaces
+ import os
+ import gradio as gr
+ import random
+ import torch
+ import logging
+ import numpy as np
+ from typing import Dict, Any, List
+ from diffusers import DiffusionPipeline
+ from api import PromptEnhancementSystem
+
+ # Constants
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 2048
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ MODEL_ID = "black-forest-labs/FLUX.1-schnell"
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+
+ print(f"Using device: {DEVICE}")
+ logger = logging.getLogger(__name__)
+
+ # Initialize model
+ try:
+     print("Loading model...")
+     pipe = DiffusionPipeline.from_pretrained(
+         MODEL_ID,
+         torch_dtype=DTYPE
+     ).to(DEVICE)
+     print("Model loaded successfully")
+     logger.info("Model loaded successfully")
+ except Exception as e:
+     print(f"Failed to load model: {str(e)}")
+     logger.error(f"Failed to load model: {str(e)}")
+     raise
+
+ @spaces.GPU()
+ def generate_multiple_images_batch(
+     improvement_axes,
+     seed=42,
+     randomize_seed=False,
+     width=512,
+     height=512,
+     num_inference_steps=4,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     try:
+         # Extract prompts from the improvement axes
+         prompts = [axis["enhanced_prompt"] for axis in improvement_axes if axis.get("enhanced_prompt")]
+
+         if not prompts:
+             return [None] * 4 + [seed]
+
+         if randomize_seed:
+             current_seed = random.randint(0, MAX_SEED)
+         else:
+             current_seed = seed
+
+         print(f"Generating images with {len(prompts)} prompts")
+         print(f"Using seed: {current_seed}")
+
+         # Generate all images in a single batch
+         generator = torch.Generator().manual_seed(current_seed)
+         images = pipe(
+             prompt=prompts,  # Pass the list of prompts directly
+             width=width,
+             height=height,
+             num_inference_steps=num_inference_steps,
+             generator=generator,
+             guidance_scale=0.0
+         ).images
+
+         # Pad with None if we have fewer than 4 images
+         while len(images) < 4:
+             images.append(None)
+
+         print("All images generated successfully")
+         return images[:4] + [current_seed]
+
+     except Exception as e:
+         print(f"Image generation error: {str(e)}")
+         logger.error(f"Image generation error: {str(e)}")
+         raise
+
+ def handle_image_select(evt: gr.SelectData, improvement_axes_data):
+     """Handle image selection event"""
+     try:
+         if improvement_axes_data and isinstance(improvement_axes_data, list):
+             selected_index = evt.index[1] if isinstance(evt.index, tuple) else evt.index
+             if selected_index < len(improvement_axes_data):
+                 selected_prompt = improvement_axes_data[selected_index].get("enhanced_prompt", "")
+                 return selected_prompt
+         return ""
+     except Exception as e:
+         print(f"Error in handle_image_select: {str(e)}")
+         return ""
+
+ def create_interface():
+     print("Creating interface...")
+     api_key = os.getenv("GROQ_API_KEY")
+     base_url = os.getenv("API_BASE_URL")
+
+     if not api_key:
+         print("GROQ_API_KEY not found in environment variables")
+         raise ValueError("GROQ_API_KEY not found in environment variables")
+
+     system = PromptEnhancementSystem(api_key, base_url)
+     print("PromptEnhancementSystem initialized")
+
+     def update_interface(prompt):
+         try:
+             print(f"\n=== Processing prompt: {prompt}")
+             state = system.start_session(prompt)
+
+             improvement_axes = state.get("improvement_axes", [])
+             initial_analysis = state.get("initial_analysis", {})
+
+             enhanced_prompt = ""
+             if improvement_axes:
+                 enhanced_prompt = improvement_axes[0].get("enhanced_prompt", prompt)
+
+             button_updates = []
+             for i in range(4):
+                 if i < len(improvement_axes):
+                     focus_area = improvement_axes[i].get("focus_area", f"Option {i+1}")
+                     button_updates.append(gr.update(visible=True, value=focus_area))
+                 else:
+                     button_updates.append(gr.update(visible=False))
+
+             return [prompt, enhanced_prompt] + [
+                 initial_analysis.get(key, {}) for key in [
+                     "subject_analysis",
+                     "style_evaluation",
+                     "technical_assessment",
+                     "composition_review",
+                     "context_evaluation",
+                     "mood_assessment"
+                 ]
+             ] + [
+                 improvement_axes,
+                 state.get("technical_recommendations", {}),
+                 None, None, None, None,  # Four None values for the four image outputs
+                 state
+             ] + button_updates
+         except Exception as e:
+             print(f"Error in update_interface: {str(e)}")
+             logger.error(f"Error in update_interface: {str(e)}")
+             empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
+             return [prompt, prompt] + [empty_analysis] * 6 + [{}, {}, None, None, None, None, {}] + [gr.update(visible=False)] * 4
+
+     def handle_option_click(option_num, input_prompt, current_text):
+         try:
+             print(f"\n=== Processing option {option_num}")
+             state = system.current_state
+             if state and "improvement_axes" in state:
+                 improvement_axes = state["improvement_axes"]
+                 if option_num < len(improvement_axes):
+                     selected_prompt = improvement_axes[option_num]["enhanced_prompt"]
+                     return [
+                         input_prompt,
+                         selected_prompt,
+                         state.get("initial_analysis", {}).get("subject_analysis", {}),
+                         state.get("initial_analysis", {}).get("style_evaluation", {}),
+                         state.get("initial_analysis", {}).get("technical_assessment", {}),
+                         state.get("initial_analysis", {}).get("composition_review", {}),
+                         state.get("initial_analysis", {}).get("context_evaluation", {}),
+                         state.get("initial_analysis", {}).get("mood_assessment", {}),
+                         improvement_axes,
+                         state.get("technical_recommendations", {}),
+                         state
+                     ]
+             return handle_error()
+         except Exception as e:
+             print(f"Error in handle_option_click: {str(e)}")
+             logger.error(f"Error in handle_option_click: {str(e)}")
+             return handle_error()
+
+     def handle_error():
+         empty_analysis = {"score": 0, "strengths": [], "weaknesses": ["Error occurred"]}
+         return ["", "", empty_analysis, empty_analysis, empty_analysis, empty_analysis, empty_analysis, empty_analysis, [], {}, {}]
+
+     with gr.Blocks(
+         title="AI Prompt Enhancement System",
+         theme=gr.themes.Soft(),
+         css="footer {visibility: hidden}"
+     ) as interface:
+         gr.Markdown("# 🎨 AI Prompt Enhancement & Image Generation System")
+
+         with gr.Row():
+             input_prompt = gr.Textbox(
+                 label="Initial Prompt",
+                 placeholder="Enter your prompt here...",
+                 lines=3,
+                 scale=1
+             )
+             current_prompt = gr.Textbox(
+                 label="Current Prompt",
+                 lines=3,
+                 scale=1,
+                 interactive=True
+             )
+
+         with gr.Row():
+             start_btn = gr.Button("Start Enhancement", variant="primary")
+
+         with gr.Row():
+             option_buttons = [gr.Button("", visible=False) for _ in range(4)]
+
+         with gr.Tabs():
+             with gr.TabItem("Initial Analysis"):
+                 with gr.Row():
+                     with gr.Column():
+                         subject_analysis = gr.JSON(label="Subject Analysis")
+                     with gr.Column():
+                         style_evaluation = gr.JSON(label="Style Evaluation")
+                     with gr.Column():
+                         technical_assessment = gr.JSON(label="Technical Assessment")
+                 with gr.Row():
+                     with gr.Column():
+                         composition_review = gr.JSON(label="Composition Review")
+                     with gr.Column():
+                         context_evaluation = gr.JSON(label="Context Evaluation")
+                     with gr.Column():
+                         mood_assessment = gr.JSON(label="Mood Assessment")
+
+             with gr.TabItem("Generated Images"):
+                 with gr.Row():
+                     generated_images = [
+                         gr.Image(
+                             label=f"Image {i+1}",
+                             type="pil",
+                             show_label=True,
+                             height=256,
+                             width=256,
+                             interactive=True,
+                             elem_id=f"image_{i}"
+                         ) for i in range(4)
+                     ]
+
+                 with gr.Row():
+                     finalize_btn = gr.Button("Generate All Images", variant="primary")
+
+         with gr.Accordion("Image Generation Settings", open=False):
+             with gr.Row():
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=2048,
+                     step=1,
+                     value=42
+                 )
+                 randomize_seed = gr.Checkbox(
+                     label="Randomize seed",
+                     value=True
+                 )
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=2048,
+                     step=256,
+                     value=512
+                 )
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=2048,
+                     step=256,
+                     value=512
+                 )
+                 num_inference_steps = gr.Slider(
+                     label="Steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=4
+                 )
+
+         with gr.Accordion("Additional Information", open=False):
+             improvement_axes = gr.JSON(label="Improvement Axes")
+             technical_recommendations = gr.JSON(label="Technical Recommendations")
+             full_llm_response = gr.JSON(label="Full LLM Response")
+
+         # Add select events for each image
+         for img in generated_images:
+             img.select(
+                 fn=handle_image_select,
+                 inputs=[improvement_axes],
+                 outputs=[input_prompt]
+             )
+
+         start_btn.click(
+             update_interface,
+             inputs=[input_prompt],
+             outputs=[
+                 input_prompt,
+                 current_prompt,
+                 subject_analysis,
+                 style_evaluation,
+                 technical_assessment,
+                 composition_review,
+                 context_evaluation,
+                 mood_assessment,
+                 improvement_axes,
+                 technical_recommendations
+             ] + generated_images + [full_llm_response] + option_buttons
+         )
+
+         for i, btn in enumerate(option_buttons):
+             btn.click(
+                 handle_option_click,
+                 inputs=[
+                     gr.Slider(value=i, visible=False),  # hidden component that carries this button's index
+                     input_prompt,
+                     current_prompt
+                 ],
+                 outputs=[
+                     input_prompt,
+                     current_prompt,
+                     subject_analysis,
+                     style_evaluation,
+                     technical_assessment,
+                     composition_review,
+                     context_evaluation,
+                     mood_assessment,
+                     improvement_axes,
+                     technical_recommendations,
+                     full_llm_response
+                 ]
+             )
+
+         finalize_btn.click(
+             generate_multiple_images_batch,
+             inputs=[
+                 improvement_axes,
+                 seed,
+                 randomize_seed,
+                 width,
+                 height,
+                 num_inference_steps
+             ],
+             outputs=generated_images + [seed]
+         )
+
+     print("Interface setup complete")
+     return interface