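"""Gradio app that detects "mark" objects in construction drawings with a YOLOv8
model exported to ONNX, crops each mark, and reads it with Florence-2 OCR."""
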
import glob
import os
import subprocess
import time

import cv2
import gradio as gr
import pandas as pd
import supervision as sv
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from ultralytics import YOLO

# Install flash-attn at startup (a common Hugging Face Spaces workaround); the env
# flag skips the CUDA build so the wheel installs without a local CUDA toolchain.
# Merge os.environ so pip and PATH still resolve inside the subprocess.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Florence-2 (base, fine-tuned) handles OCR; the YOLO detector runs from an ONNX export.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft", trust_remote_code=True
).to(device).eval()
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
onnx_model = YOLO("models/best.onnx", task='detect')

def filter_detections(detections, target_class_name="mark"):
    """Keep only the detections whose class name matches target_class_name."""
    indices_to_keep = [
        i for i, class_name in enumerate(detections.data['class_name'])
        if class_name == target_class_name
    ]
    detections.xyxy = detections.xyxy[indices_to_keep]
    detections.confidence = detections.confidence[indices_to_keep]
    detections.class_id = detections.class_id[indices_to_keep]
    detections.data['class_name'] = detections.data['class_name'][indices_to_keep]
    return detections

def ends_with_number(s):
    """True if s is non-empty and ends with a digit, e.g. "A12" -> True, "A-" -> False."""
    return bool(s) and s[-1].isdigit()

def ocr(image, prompt="<OCR>"):
    """Run Florence-2 OCR on an image loaded with cv2 and return the parsed answer."""
    original_height, original_width = image.shape[:2]
    # cv2 loads images as BGR; convert to RGB before handing them to the processor.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    inputs = processor(text=prompt, images=rgb_image, return_tensors="pt").to(device)
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=prompt,
        image_size=(original_width, original_height),
    )
    return parsed_answer
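
# post_process_generation returns a dict keyed by the task token, e.g. {"<OCR>": "A12"},
# so downstream code indexes the result with the prompt string.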

def parse_detection(detections):
    """Convert a supervision Detections object into a list of row dicts."""
    parsed_rows = []
    for i in range(len(detections.xyxy)):
        x_min, y_min, x_max, y_max = map(float, detections.xyxy[i])
        row = {
            "top": int(y_min),
            "left": int(x_min),
            "width": int(x_max - x_min),
            "height": int(y_max - y_min),
            "class_id": "" if detections.class_id is None else int(detections.class_id[i]),
            "confidence": "" if detections.confidence is None else float(detections.confidence[i]),
            "tracker_id": "" if detections.tracker_id is None else int(detections.tracker_id[i]),
        }
        if hasattr(detections, "data"):
            for key, value in detections.data.items():
                # Indexable per-detection arrays contribute their i-th element;
                # scalars (including 0-d arrays) are stored as-is.
                row[key] = (
                    str(value[i])
                    if hasattr(value, "__getitem__") and getattr(value, "ndim", 1) != 0
                    else str(value)
                )
        parsed_rows.append(row)
    return parsed_rows
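
# A parsed row looks roughly like:
# {"top": 12, "left": 40, "width": 36, "height": 18,
#  "class_id": 0, "confidence": 0.91, "tracker_id": "", "class_name": "mark"}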

def cut_and_save_image(image, parsed_detections, output_dir):
    """Crop every 'mark' detection, upscale it, and save it as a PNG."""
    output_path_list = []
    for i, det in enumerate(parsed_detections):
        # Only crop detections of class 'mark'
        if det['class_name'] == 'mark':
            top, left = det['top'], det['left']
            width, height = det['width'], det['height']
            # Crop and upscale so small marks stay legible for OCR
            cut_image = image[top:top + height, left:left + width]
            scaled_image = sv.scale_image(image=cut_image, scale_factor=4)
            output_path = f"{output_dir}/cut_image_{i}.png"
            cv2.imwrite(output_path, scaled_image, [int(cv2.IMWRITE_PNG_COMPRESSION), 3])
            output_path_list.append(output_path)
    return output_path_list
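
# Crops land in output/cut_image_<i>.png; analysis() below globs that directory.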

def analysis(progress=gr.Progress()):
    """OCR every cropped mark saved by the last inference run and tally the results."""
    progress(0, desc="Analyzing...")
    list_files = glob.glob("output/*.png")
    prompt = "<OCR>"
    results = {}
    for filepath in progress.tqdm(list_files):
        basename = os.path.basename(filepath)
        image = cv2.imread(filepath)
        start_time = time.time()
        parsed_answer = ocr(image, prompt)
        # Mark labels are expected to end with a digit; default to "1" when OCR drops it.
        if not ends_with_number(parsed_answer[prompt]):
            parsed_answer[prompt] += "1"
        results[parsed_answer[prompt]] = results.get(parsed_answer[prompt], 0) + 1
        print(basename, parsed_answer[prompt])
        print("Time taken:", time.time() - start_time)
    df = pd.DataFrame(results.items(), columns=['Mark', 'Total'])
    return df.reset_index(drop=False).rename(columns={'index': 'No.'})
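
# The returned table has columns No. / Mark / Total, one row per distinct mark string.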

def inference(
    image_path,
    conf_threshold,
    iou_threshold,
):
    """
    YOLOv8 inference function.

    Args:
        image_path: Path to the image
        conf_threshold: Confidence threshold
        iou_threshold: IoU threshold

    Returns:
        The annotated image and the list of cropped-mark image paths
    """
    image = cv2.imread(image_path)
    results = onnx_model(image, conf=conf_threshold, iou=iou_threshold)[0]
    detections = sv.Detections.from_ultralytics(results)
    detections = filter_detections(detections)
    parsed_detections = parse_detection(detections)
    # Create the output directory if missing; otherwise clear crops from the previous run
    output_dir = "output"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    else:
        for f in os.listdir(output_dir):
            os.remove(os.path.join(output_dir, f))
    output_path_list = cut_and_save_image(image, parsed_detections, output_dir)
    # Draw boxes and class labels on a copy of the input image
    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator(text_position=sv.Position.TOP_LEFT, text_thickness=1, text_padding=2)
    annotated_image = image.copy()
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
    annotated_image = label_annotator.annotate(scene=annotated_image, detections=detections)
    # cv2 works in BGR; convert so Gradio displays the annotation colors correctly
    return cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB), output_path_list

TITLE = "<h1 style='font-size: 2.5em; text-align: center;'>Identify objects in construction design</h1>"
DESCRIPTION = """<p style='font-size: 1.5em; line-height: 1.6em; text-align: left;'>Welcome to the object
identification application. This tool allows you to upload an image, and it will identify and annotate the objects
it finds. You can also run OCR analysis on the detected objects.</p>
"""
CSS = """ | |
#output { | |
height: 500px; | |
overflow: auto; | |
border: 1px solid #ccc; | |
} | |
h1 { | |
text-align: center; | |
} | |
""" | |
EXAMPLES = [ | |
['examples/train1.png', 0.6, 0.25], | |
['examples/train2.png', 0.9, 0.25], | |
['examples/train3.png', 0.6, 0.25] | |
] | |

with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    with gr.Tab(label="Identify objects"):
        with gr.Row(equal_height=False):
            input_img = gr.Image(type="filepath", label="Upload Image")
            output_img = gr.Image(label="Output Image")
        with gr.Row():
            with gr.Column():
                conf_thres = gr.Slider(minimum=0.0, maximum=1.0, value=0.6, step=0.05, label="Confidence Threshold")
            with gr.Column():
                iou = gr.Slider(minimum=0.0, maximum=1.0, value=0.25, step=0.05, label="IOU Threshold")
        with gr.Row():
            with gr.Column():
                submit_btn = gr.Button(value="Predict")
            with gr.Column():
                analysis_btn = gr.Button(value="Analyze")
        with gr.Row():
            output_df = gr.Dataframe(label="Results")
        with gr.Row():
            with gr.Accordion("Gallery", open=False):
                gallery = gr.Gallery(label="Detected Mark Objects", columns=3)
        submit_btn.click(inference, [input_img, conf_thres, iou], [output_img, gallery])
        analysis_btn.click(analysis, [], [output_df])
        examples = gr.Examples(
            EXAMPLES,
            fn=inference,
            inputs=[input_img, conf_thres, iou],
            outputs=[output_img, gallery],
            cache_examples=False,
        )

demo.launch(debug=True)