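"""Gradio app for deki: detects UI elements in screenshots (YOLO), optionally
runs OCR and super-resolution, and exposes LLM-backed "Action" and "Generate"
endpoints on top of the structured JSON description."""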
import gradio as gr
import os
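# Hide all GPUs from CUDA so the pipeline runs on CPU; set before the ML libraries are imported.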
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import base64
import time
import json
import logging
import tempfile
import uuid
import io
from PIL import Image
from openai import OpenAI
from ultralytics import YOLO
from wrapper import process_image_description
from utils.pills import preprocess_image
import cv2
try:
    import cv2.dnn_superres as dnn_superres  # requires opencv-contrib-python
except ImportError:
    dnn_superres = None
import easyocr
from spellchecker import SpellChecker
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
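# Heavy models are loaded once at startup (see load_models()) and shared
# across requests via these globals.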
GLOBAL_SR = None
GLOBAL_READER = None
GLOBAL_SPELL = None
YOLO_MODEL = None
def load_models():
    """
    Called once to load all necessary models into memory.
    """
    global GLOBAL_SR, GLOBAL_READER, GLOBAL_SPELL, YOLO_MODEL
    logging.info("Loading all models...")
    start_time_total = time.perf_counter()

    # Super-resolution
    logging.info("Loading super-resolution model...")
    start_time = time.perf_counter()
    sr = None
    model_path = "EDSR_x4.pb"
    if os.path.exists(model_path):
        if dnn_superres is not None:
            try:
                sr = dnn_superres.DnnSuperResImpl_create()
            except AttributeError:
                # Older OpenCV builds expose the constructor directly.
                sr = dnn_superres.DnnSuperResImpl()
            sr.readModel(model_path)
            sr.setModel('edsr', 4)
            GLOBAL_SR = sr
            logging.info("Super-resolution model loaded.")
        else:
            logging.warning("cv2.dnn_superres module not available.")
    else:
        logging.warning(f"Super-resolution model file not found: {model_path}. Skipping SR.")
    logging.info(f"Super-resolution init took {time.perf_counter()-start_time:.3f}s.")

    # EasyOCR + SpellChecker
    logging.info("Loading OCR + SpellChecker...")
    start_time = time.perf_counter()
    GLOBAL_READER = easyocr.Reader(['en'], gpu=False)  # CUDA is hidden above, so stay on CPU
    GLOBAL_SPELL = SpellChecker()
    logging.info(f"OCR + SpellChecker init took {time.perf_counter()-start_time:.3f}s.")

    # YOLO Model
    logging.info("Loading YOLO model...")
    start_time = time.perf_counter()
    yolo_weights = "best.pt"
    if os.path.exists(yolo_weights):
        YOLO_MODEL = YOLO(yolo_weights)
        logging.info("YOLO model loaded.")
    else:
        logging.error(f"YOLO weights file '{yolo_weights}' not found! Endpoints will fail.")
    logging.info(f"YOLO init took {time.perf_counter()-start_time:.3f}s.")

    logging.info(f"Total model loading time: {time.perf_counter()-start_time_total:.3f}s.")
def pil_to_base64_str(pil_image, format="PNG"):
    """Converts a PIL Image to a base64 string with a data URI header."""
    buffered = io.BytesIO()
    pil_image.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/{format.lower()};base64,{img_str}"
def save_base64_image(image_data: str, file_path: str):
    """Saves a base64 encoded image to a file."""
    if image_data.startswith("data:image"):
        _, image_data = image_data.split(",", 1)
    img_bytes = base64.b64decode(image_data)
    with open(file_path, "wb") as f:
        f.write(img_bytes)
    return img_bytes
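# The deki pipeline writes its output to <output_dir>/result/<image_name>.json;
# run_wrapper invokes it with the preloaded models and returns that JSON as a string.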
def run_wrapper(image_path: str, output_dir: str, skip_ocr: bool = False, skip_spell: bool = False, json_mini=False) -> str:
    """Calls the main processing script and returns the result."""
    process_image_description(
        input_image=image_path,
        weights_file="best.pt",
        output_dir=output_dir,
        no_captioning=True,
        output_json=True,
        json_mini=json_mini,
        model_obj=YOLO_MODEL,
        sr=GLOBAL_SR,
        spell=None if skip_ocr else GLOBAL_SPELL,
        reader=None if skip_ocr else GLOBAL_READER,
        skip_ocr=skip_ocr,
        skip_spell=skip_spell,
    )
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    result_dir = os.path.join(output_dir, "result")
    json_file = os.path.join(result_dir, f"{base_name}.json")
    if os.path.exists(json_file):
        with open(json_file, "r", encoding="utf-8") as f:
            return f.read()
    else:
        raise FileNotFoundError(f"Result file not generated: {json_file}")
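# Gradio handlers. "Action" and "Generate" send the YOLO-annotated screenshot
# plus the JSON description to OpenAI; the "Analyze" handlers return the
# description (and optionally the annotated image) directly.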
def handle_action(openai_key, image, prompt):
    if not openai_key:
        return "Error: OpenAI API Key is required for /action."
    if image is None:
        return "Error: Please upload an image."
    if not prompt:
        return "Error: Please provide a prompt."
    try:
        llm_client = OpenAI(api_key=openai_key)
        image_b64 = pil_to_base64_str(image)
        with tempfile.TemporaryDirectory() as temp_dir:
            request_id = str(uuid.uuid4())
            original_image_path = os.path.join(temp_dir, f"{request_id}.png")
            yolo_updated_image_path = os.path.join(temp_dir, f"{request_id}_yolo_updated.png")
            save_base64_image(image_b64, original_image_path)
            image_description = run_wrapper(original_image_path, temp_dir, skip_ocr=False, skip_spell=True, json_mini=True)
            # The pipeline saves the annotated screenshot next to the original.
            with open(yolo_updated_image_path, "rb") as f:
                yolo_updated_img_bytes = f.read()
            _, new_b64 = preprocess_image(yolo_updated_img_bytes, threshold=2000, scale=0.5, fmt="png")
            base64_image_url = f"data:image/png;base64,{new_b64}"
            prompt_text = f"""You are an AI agent... (rest of your long prompt)
The user said: "{prompt}"
Description: "{image_description}" """
            messages = [
                {"role": "user", "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image_url", "image_url": {"url": base64_image_url, "detail": "high"}}
                ]}
            ]
            response = llm_client.chat.completions.create(model="gpt-4.1", messages=messages, temperature=0.2)
            return response.choices[0].message.content.strip()
    except Exception as e:
        logging.error(f"Error in /action endpoint: {e}", exc_info=True)
        return f"An error occurred: {e}"
def handle_analyze(image, output_style):
    if image is None:
        return "Error: Please upload an image."
    try:
        image_b64 = pil_to_base64_str(image)
        with tempfile.TemporaryDirectory() as temp_dir:
            image_path = os.path.join(temp_dir, "image_to_analyze.png")
            save_base64_image(image_b64, image_path)
            is_mini = (output_style == "Mini JSON")
            description_str = run_wrapper(image_path=image_path, output_dir=temp_dir, json_mini=is_mini)
            parsed_json = json.loads(description_str)
            return json.dumps(parsed_json, indent=2)
    except Exception as e:
        logging.error(f"Error in /analyze endpoint: {e}", exc_info=True)
        return f"An error occurred: {e}"
def handle_analyze_yolo(image, output_style):
    if image is None:
        return None, "Error: Please upload an image."
    try:
        image_b64 = pil_to_base64_str(image)
        with tempfile.TemporaryDirectory() as temp_dir:
            request_id = str(uuid.uuid4())
            image_path = os.path.join(temp_dir, f"{request_id}.png")
            yolo_image_path = os.path.join(temp_dir, f"{request_id}_yolo_updated.png")
            save_base64_image(image_b64, image_path)
            is_mini = (output_style == "Mini JSON")
            description_str = run_wrapper(image_path=image_path, output_dir=temp_dir, json_mini=is_mini)
            parsed_json = json.loads(description_str)
            description_output = json.dumps(parsed_json, indent=2)
            yolo_image_result = Image.open(yolo_image_path)
            # PIL opens files lazily; force-load the pixels before the
            # temporary directory (and the file) is deleted on return.
            yolo_image_result.load()
            return yolo_image_result, description_output
    except Exception as e:
        logging.error(f"Error in /analyze_and_get_yolo: {e}", exc_info=True)
        return None, f"An error occurred: {e}"
def handle_generate(openai_key, image, prompt):
    if not openai_key:
        return "Error: OpenAI API Key is required for /generate."
    if image is None:
        return "Error: Please upload an image."
    if not prompt:
        return "Error: Please provide a prompt."
    try:
        llm_client = OpenAI(api_key=openai_key)
        image_b64 = pil_to_base64_str(image)
        with tempfile.TemporaryDirectory() as temp_dir:
            request_id = str(uuid.uuid4())
            original_image_path = os.path.join(temp_dir, f"{request_id}.png")
            yolo_updated_image_path = os.path.join(temp_dir, f"{request_id}_yolo_updated.png")
            save_base64_image(image_b64, original_image_path)
            image_description = run_wrapper(image_path=original_image_path, output_dir=temp_dir, json_mini=False)
            with open(yolo_updated_image_path, "rb") as f:
                yolo_updated_img_bytes = f.read()
            _, new_b64 = preprocess_image(yolo_updated_img_bytes, threshold=1500, scale=0.5, fmt="png")
            base64_image_url = f"data:image/png;base64,{new_b64}"
            messages = [
                {"role": "user", "content": [
                    {"type": "text", "text": f'"Prompt: {prompt}"\nImage description:\n"{image_description}"'},
                    {"type": "image_url", "image_url": {"url": base64_image_url, "detail": "high"}}
                ]}
            ]
            response = llm_client.chat.completions.create(model="gpt-4.1", messages=messages, temperature=0.2)
            return response.choices[0].message.content.strip()
    except Exception as e:
        logging.error(f"Error in /generate endpoint: {e}", exc_info=True)
        return f"An error occurred: {e}"
default_image_1 = Image.open("./res/bb_1.jpeg")
default_image_2 = Image.open("./res/mfa_1.jpeg")
def load_example_action_1(): return default_image_1, "Open and read Umico partner"
def load_example_action_2(): return default_image_2, "Sign up in the application"
def load_example_analyze_1(): return default_image_1
def load_example_analyze_2(): return default_image_2
def load_example_yolo_1(): return default_image_1
def load_example_yolo_2(): return default_image_2
def load_example_generate_1(): return default_image_1, "Generate the code for this screen for Android XML. Try to use constraint layout"
def load_example_generate_2(): return default_image_2, "Generate the code for this screen for Android XML. Try to use constraint layout"
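# UI layout: one tab per capability, with a shared OpenAI key field at the top.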
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Deki Automata: UI Analysis and Generation")
    gr.Markdown("Provide your OpenAI API key below. It is only required for the 'Action' and 'Generate' tabs.")
    with gr.Row():
        openai_key_input = gr.Textbox(label="OpenAI API Key", placeholder="Enter your OpenAI API Key", type="password", scale=1)

    with gr.Tabs():
        with gr.TabItem("Action"):
            gr.Markdown("### Control a device with natural language.")
            with gr.Row():
                image_input_action = gr.Image(type="pil", label="Upload Screen Image")
                prompt_input_action = gr.Textbox(lines=2, placeholder="e.g., 'Open whatsapp and text my friend...'", label="Prompt")
            action_output = gr.Textbox(label="Response Command")
            action_button = gr.Button("Run Action", variant="primary")
            with gr.Row():
                example_action_btn1 = gr.Button("Load Example 1")
                example_action_btn2 = gr.Button("Load Example 2")

        with gr.TabItem("Analyze"):
            gr.Markdown("### Get a structured JSON description of the UI elements.")
            with gr.Row():
                image_input_analyze = gr.Image(type="pil", label="Upload Screen Image")
                with gr.Column():
                    output_style_analyze = gr.Radio(["Standard JSON", "Mini JSON"], label="Output Format", value="Standard JSON")
                    analyze_button = gr.Button("Analyze Image", variant="primary")
            analyze_output = gr.JSON(label="JSON Description")
            with gr.Row():
                example_analyze_btn1 = gr.Button("Load Example 1")
                example_analyze_btn2 = gr.Button("Load Example 2")

        with gr.TabItem("Analyze & Get YOLO"):
            gr.Markdown("### Get a JSON description and the image with detected elements.")
            with gr.Row():
                image_input_yolo = gr.Image(type="pil", label="Upload Screen Image")
                with gr.Column():
                    output_style_yolo = gr.Radio(["Standard JSON", "Mini JSON"], label="Output Format", value="Standard JSON")
                    yolo_button = gr.Button("Analyze and Visualize", variant="primary")
            with gr.Row():
                yolo_image_output = gr.Image(label="YOLO Annotated Image")
                description_output_yolo = gr.JSON(label="JSON Description")
            with gr.Row():
                example_yolo_btn1 = gr.Button("Load Example 1")
                example_yolo_btn2 = gr.Button("Load Example 2")

        with gr.TabItem("Generate"):
            gr.Markdown("### Generate code or text based on a screenshot.")
            with gr.Row():
                image_input_generate = gr.Image(type="pil", label="Upload Screen Image")
                prompt_input_generate = gr.Textbox(lines=2, placeholder="e.g., 'Generate the Android XML for this screen'", label="Prompt")
            generate_output = gr.Code(label="Generated Output")
            generate_button = gr.Button("Generate", variant="primary")
            with gr.Row():
                example_generate_btn1 = gr.Button("Load Example 1")
                example_generate_btn2 = gr.Button("Load Example 2")
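    # Wire the buttons to their handlers and example loaders.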
    action_button.click(fn=handle_action, inputs=[openai_key_input, image_input_action, prompt_input_action], outputs=action_output)
    analyze_button.click(fn=handle_analyze, inputs=[image_input_analyze, output_style_analyze], outputs=analyze_output)
    yolo_button.click(fn=handle_analyze_yolo, inputs=[image_input_yolo, output_style_yolo], outputs=[yolo_image_output, description_output_yolo])
    generate_button.click(fn=handle_generate, inputs=[openai_key_input, image_input_generate, prompt_input_generate], outputs=generate_output)

    example_action_btn1.click(fn=load_example_action_1, outputs=[image_input_action, prompt_input_action])
    example_action_btn2.click(fn=load_example_action_2, outputs=[image_input_action, prompt_input_action])
    example_analyze_btn1.click(fn=load_example_analyze_1, outputs=image_input_analyze)
    example_analyze_btn2.click(fn=load_example_analyze_2, outputs=image_input_analyze)
    example_yolo_btn1.click(fn=load_example_yolo_1, outputs=image_input_yolo)
    example_yolo_btn2.click(fn=load_example_yolo_2, outputs=image_input_yolo)
    example_generate_btn1.click(fn=load_example_generate_1, outputs=[image_input_generate, prompt_input_generate])
    example_generate_btn2.click(fn=load_example_generate_2, outputs=[image_input_generate, prompt_input_generate])
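# Load all models once at startup, then launch the app.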
load_models()
demo.launch()