# sekisan-app / app.py
from ultralytics import YOLO
import supervision as sv
import cv2
import gradio as gr
import os
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import requests
from PIL import Image
import glob
import pandas as pd
import time
import subprocess
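# Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips
# compiling the CUDA kernels at install time (a common workaround on hosted
# GPU Spaces).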
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
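# Load Florence-2 for OCR on cropped marks, plus the custom YOLO detector
# exported to ONNX.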
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
onnx_model = YOLO("models/best.onnx", task='detect')
def ends_with_number(s):
    # Guard against empty OCR output before indexing the last character.
    return bool(s) and s[-1].isdigit()
def ocr(image, prompt="<OCR>"):
    # cv2 loads images as BGR; Florence-2's processor expects RGB input.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_height, original_width = image.shape[:2]
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt,
        image_size=(original_width, original_height),
    )
    return parsed_answer
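# Example usage (hypothetical file path):
#   crop = cv2.imread("output/cut_image_0.png")
#   text = ocr(crop)["<OCR>"]  # recognized text for the "<OCR>" task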
def parse_detection(detections):
    parsed_rows = []
    for i in range(len(detections.xyxy)):
        x_min = float(detections.xyxy[i][0])
        y_min = float(detections.xyxy[i][1])
        x_max = float(detections.xyxy[i][2])
        y_max = float(detections.xyxy[i][3])
        width = int(x_max - x_min)
        height = int(y_max - y_min)
        row = {
            "top": int(y_min),
            "left": int(x_min),
            "width": width,
            "height": height,
            "class_id": "" if detections.class_id is None else int(detections.class_id[i]),
            "confidence": "" if detections.confidence is None else float(detections.confidence[i]),
            "tracker_id": "" if detections.tracker_id is None else int(detections.tracker_id[i]),
        }
        if hasattr(detections, "data"):
            for key, value in detections.data.items():
                row[key] = (
                    str(value[i])
                    if hasattr(value, "__getitem__") and value.ndim != 0
                    else str(value)
                )
        parsed_rows.append(row)
    return parsed_rows
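# A parsed row looks roughly like (values illustrative):
#   {"top": 12, "left": 34, "width": 56, "height": 20, "class_id": 0,
#    "confidence": 0.91, "tracker_id": "", "class_name": "mark"}
# "class_name" comes from detections.data on ultralytics results.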
def cut_and_save_image(image, parsed_detections, output_dir):
    output_path_list = []
    for i, det in enumerate(parsed_detections):
        # Only crop detections of class 'mark'
        if det['class_name'] == 'mark':
            top = det['top']
            left = det['left']
            width = det['width']
            height = det['height']
            # Crop the detection from the full image
            cut_image = image[top:top + height, left:left + width]
            # Upscale the crop 4x so small marks are easier to OCR
            scaled_image = sv.scale_image(image=cut_image, scale_factor=4)
            # Save the crop as PNG (JPEG quality flags do not apply here)
            output_path = f"{output_dir}/cut_image_{i}.png"
            cv2.imwrite(output_path, scaled_image)
            output_path_list.append(output_path)
    return output_path_list
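# The saved crops land in output/cut_image_<i>.png and are picked up later
# by analysis() via glob.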
def analysis(progress=gr.Progress()):
    progress(0, desc="Analyzing...")
    list_files = glob.glob("output/*.png")
    prompt = "<OCR>"
    results = {}
    for filepath in progress.tqdm(list_files):
        basename = os.path.basename(filepath)
        image = cv2.imread(filepath)
        start_time = time.time()
        parsed_answer = ocr(image, prompt)
        # If the recognized mark has no trailing count, default it to 1
        if not ends_with_number(parsed_answer[prompt]):
            parsed_answer[prompt] += "1"
        # Tally occurrences of each recognized mark
        results[parsed_answer[prompt]] = results.get(parsed_answer[prompt], 0) + 1
        print(basename, parsed_answer[prompt])
        print("Time taken:", time.time() - start_time)
    df = pd.DataFrame(results.items(), columns=['Mark', 'Total'])
    return df.reset_index(drop=False).rename(columns={'index': 'No.'})
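# analysis() returns a DataFrame with columns No. / Mark / Total, which is
# rendered directly in the "Results" Dataframe component below.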
def inference(image_path, conf_threshold, iou_threshold):
    """
    Run YOLO detection on an image, crop 'mark' detections for later OCR,
    and return an annotated image.

    Args:
        image_path: Path to the image
        conf_threshold: Confidence threshold
        iou_threshold: IoU threshold

    Returns:
        Annotated image and the list of cropped 'mark' image paths
    """
    image = cv2.imread(image_path)
    results = onnx_model(image, conf=conf_threshold, iou=iou_threshold)[0]
    detections = sv.Detections.from_ultralytics(results)
    parsed_detections = parse_detection(detections)
    output_dir = "output"
    # Create the output directory if needed; otherwise clear stale crops
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        for f in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, f))
    output_path_list = cut_and_save_image(image, parsed_detections, output_dir)
    # Draw boxes and class labels on a copy of the input image
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
    annotated_image = image.copy()
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    return annotated_image, output_path_list
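# Example: run detection standalone, outside the UI:
#   annotated, crops = inference("examples/train1.png", 0.6, 0.25)
#   cv2.imwrite("annotated.png", annotated)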
TITLE = "<h1 style='font-size: 2.5em; text-align: center;'>Identify objects in construction design</h1>"
DESCRIPTION = """<p style='font-size: 1.5em; line-height: 1.6em; text-align: left;'>Welcome to the object
identification application. This tool allows you to upload an image, and it will identify and annotate objects within
the image. Additionally, you can perform OCR analysis on the detected objects.</p> """
CSS = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
h1 {
text-align: center;
}
"""
EXAMPLES = [
['examples/train1.png', 0.6, 0.25],
['examples/train2.png', 0.9, 0.25],
['examples/train3.png', 0.6, 0.25]
]
with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    with gr.Tab(label="Identify objects"):
        with gr.Row():
            input_img = gr.Image(type="filepath", label="Upload Image")
            output_img = gr.Image(type="filepath", label="Output Image")
        with gr.Row():
            with gr.Column():
                conf_thres = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
            with gr.Column():
                iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
        with gr.Row():
            with gr.Column():
                submit_btn = gr.Button(value="Predict")
            with gr.Column():
                analysis_btn = gr.Button(value="Analysis")
        with gr.Row():
            output_df = gr.Dataframe(label="Results")
        with gr.Row():
            with gr.Accordion("Gallery", open=False):
                gallery = gr.Gallery(label="Detected Mark Object", columns=3)
    submit_btn.click(inference, [input_img, conf_thres, iou], [output_img, gallery])
    analysis_btn.click(analysis, [], [output_df])
    examples = gr.Examples(
        EXAMPLES,
        fn=inference,
        inputs=[input_img, conf_thres, iou],
        outputs=[output_img, gallery],
        cache_examples=False,
    )
demo.launch(debug=True)