import json
import os
import shutil
from pathlib import Path

import gradio as gr
import requests
from PIL import Image
from ultralytics import YOLO

# Directory and file configurations
MODELS_DIR = "models"
MODELS_INFO_FILE = "models_info.json"
TEMP_DIR = "temp"
OUTPUT_DIR = "outputs"

# Files for storing ratings, detections, and recommended datasets
RATINGS_FILE = "ratings.json"
DETECTIONS_FILE = "detections.json"
RECOMMENDED_DATASETS_FILE = "recommended_datasets.json"
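
# For reference, a minimal sketch of a models_info.json entry, inferred from the fields read
# further below. 'model_name' and 'download_url' are required by load_models(); the remaining
# fields are optional and only shown in the model info panel. All values here are illustrative
# placeholders, not real models or URLs.
#
# [
#     {
#         "model_name": "yolo11n_example",
#         "display_name": "YOLOv11 Nano (example)",
#         "download_url": "https://example.com/models/yolo11n_example.pt",
#         "architecture": "YOLOv11n",
#         "training_epochs": 100,
#         "batch_size": 16,
#         "optimizer": "AdamW",
#         "learning_rate": 0.001,
#         "data_augmentation_level": "Medium",
#         "mAP_score": 0.75,
#         "num_images": 5000,
#         "class_ids": {"0": "person", "1": "car"},
#         "datasets_used": ["example-roboflow-dataset"],
#         "class_image_counts": {"person": 3000, "car": 2000}
#     }
# ]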


def download_file(url, dest_path):
    """
    Download a file from a URL to the destination path.

    Args:
        url (str): The URL to download from.
        dest_path (str): The local path to save the file.

    Returns:
        bool: True if download succeeded, False otherwise.
    """
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        with open(dest_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Downloaded {url} to {dest_path}.")
        return True
    except Exception as e:
        print(f"Failed to download {url}. Error: {e}")
        return False


def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
    """
    Load YOLO models and their information from the specified directory and JSON file.
    Downloads models if they are not already present.

    Args:
        models_dir (str): Path to the models directory.
        info_file (str): Path to the JSON file containing model info.

    Returns:
        dict: A dictionary of models and their associated information.
    """
    with open(info_file, 'r') as f:
        models_info = json.load(f)

    models = {}
    for model_info in models_info:
        model_name = model_info['model_name']
        display_name = model_info.get('display_name', model_name)
        model_dir = os.path.join(models_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f"{model_name}.pt")
        download_url = model_info['download_url']

        if not os.path.isfile(model_path):
            print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
            success = download_file(download_url, model_path)
            if not success:
                print(f"Skipping model '{display_name}' due to download failure.")
                continue

        try:
            model = YOLO(model_path)
            models[model_name] = {
                'display_name': display_name,
                'model': model,
                'info': model_info
            }
            print(f"Loaded model '{display_name}' from '{model_path}'.")
        except Exception as e:
            print(f"Error loading model '{display_name}': {e}")

    return models


def get_model_info(model_info, ratings_info):
    """
    Retrieve formatted model information for display, including average rating.

    Args:
        model_info (dict): The model's information dictionary.
        ratings_info (dict): The ratings information for the model.

    Returns:
        str: A formatted string containing model details and average rating.
    """
    info = model_info
    class_ids = info.get('class_ids', {})
    class_image_counts = info.get('class_image_counts', {})
    datasets_used = info.get('datasets_used', [])

    class_ids_formatted = "\n".join([f"{cid}: {cname}" for cid, cname in class_ids.items()])
    class_image_counts_formatted = "\n".join([f"{cname}: {count}" for cname, count in class_image_counts.items()])
    datasets_used_formatted = "\n".join([f"- {dataset}" for dataset in datasets_used])

    # Calculate average rating
    total_rating = ratings_info.get('total', 0)
    count_rating = ratings_info.get('count', 0)
    average_rating = f"{total_rating / count_rating:.2f}" if count_rating > 0 else "No ratings yet"

    info_text = (
        f"**{info.get('display_name', 'Model Name')}**\n\n"
        f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
        f"**Training Epochs:** {info.get('training_epochs', 'N/A')}\n\n"
        f"**Batch Size:** {info.get('batch_size', 'N/A')}\n\n"
        f"**Optimizer:** {info.get('optimizer', 'N/A')}\n\n"
        f"**Learning Rate:** {info.get('learning_rate', 'N/A')}\n\n"
        f"**Data Augmentation Level:** {info.get('data_augmentation_level', 'N/A')}\n\n"
        f"**mAP@0.5:** {info.get('mAP_score', 'N/A')}\n\n"
        f"**Number of Images Trained On:** {info.get('num_images', 'N/A')}\n\n"
        f"**Class IDs:**\n{class_ids_formatted}\n\n"
        f"**Datasets Used:**\n{datasets_used_formatted}\n\n"
        f"**Class Image Counts:**\n{class_image_counts_formatted}\n\n"
        f"**Average Rating:** {average_rating} ⭐"
    )
    return info_text


def predict_image(model_name, image, confidence, models):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    Args:
        model_name (str): The name of the selected model.
        image (PIL.Image.Image): The uploaded image.
        confidence (float): The confidence threshold for detections.
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: A status message, the processed image, and the path to the output image.
    """
    model_entry = models.get(model_name, {})
    model = model_entry.get('model', None)
    if not model:
        return "Error: Model not found.", None, None

    try:
        os.makedirs(TEMP_DIR, exist_ok=True)
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        input_image_path = os.path.join(TEMP_DIR, f"{model_name}_input_image.jpg")
        image.save(input_image_path)

        results = model(input_image_path, save=True, save_txt=False, conf=confidence)

        # Ultralytics saves annotated images into runs/detect/predict*; pick the most recent run
        latest_run = sorted(Path("runs/detect").glob("predict*"), key=os.path.getmtime)[-1]
        output_image_path = os.path.join(latest_run, Path(input_image_path).name)
        if not os.path.isfile(output_image_path):
            # Fall back to saving the annotated result directly; Results.save() returns the file path
            output_image_path = results[0].save()

        final_output_path = os.path.join(OUTPUT_DIR, f"{model_name}_output_image.jpg")
        shutil.copy(output_image_path, final_output_path)
        output_image = Image.open(final_output_path)

        # Calculate number of detections
        detections = len(results[0].boxes)
        return f"✅ Prediction completed successfully. **Detections:** {detections}", output_image, final_output_path
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}", None, None


def load_or_initialize_json(file_path, default_data):
    """
    Load JSON data from a file or initialize it with default data if the file doesn't exist.

    Args:
        file_path (str): Path to the JSON file.
        default_data (dict or list): Default data to initialize if file doesn't exist.

    Returns:
        dict or list: The loaded or initialized data.
    """
    if os.path.isfile(file_path):
        with open(file_path, 'r') as f:
            return json.load(f)
    else:
        with open(file_path, 'w') as f:
            json.dump(default_data, f, indent=4)
        return default_data


def save_json(file_path, data):
    """
    Save data to a JSON file.

    Args:
        file_path (str): Path to the JSON file.
        data (dict or list): Data to save.
    """
    with open(file_path, 'w') as f:
        json.dump(data, f, indent=4)


def is_valid_roboflow_url(url):
    """
    Validate if the provided URL is a Roboflow URL.

    Args:
        url (str): The URL to validate.

    Returns:
        bool: True if valid, False otherwise.
    """
    return url.startswith("https://roboflow.com/") or url.startswith("http://roboflow.com/")
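
# Example behavior of the check above (sketch, URLs are illustrative):
#   is_valid_roboflow_url("https://roboflow.com/some-user/some-dataset")          -> True
#   is_valid_roboflow_url("https://universe.roboflow.com/some-user/some-dataset") -> False (host not matched)
#   is_valid_roboflow_url("https://example.com/dataset")                          -> False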


def get_top_model(detections_per_model, models):
    """
    Determine the top model based on the highest number of detections.

    Args:
        detections_per_model (dict): Dictionary with model names as keys and detection counts as values.
        models (dict): Dictionary of loaded models.

    Returns:
        str: The display name of the top model or a message if no detections exist.
    """
    if not detections_per_model:
        return "No detections yet."
    top_model_name = max(detections_per_model, key=detections_per_model.get)
    top_model_display = models[top_model_name]['display_name']
    top_detections = detections_per_model[top_model_name]
    return f"**Top Model:** {top_model_display} with **{top_detections}** detections."


def main():
    # Load models
    models = load_models()
    if not models:
        print("No models loaded. Please check your models_info.json and model URLs.")
        return

    # Load or initialize ratings
    ratings_data = load_or_initialize_json(RATINGS_FILE, {})
    # Initialize ratings for each model if not present
    for model_name in models:
        if model_name not in ratings_data:
            ratings_data[model_name] = {"total": 0, "count": 0}
    save_json(RATINGS_FILE, ratings_data)

    # Load or initialize detections counter
    detections_data = load_or_initialize_json(DETECTIONS_FILE, {"total_detections": 0, "detections_per_model": {}})

    # Load or initialize recommended datasets
    recommended_datasets = load_or_initialize_json(RECOMMENDED_DATASETS_FILE, [])

    with gr.Blocks() as demo:
        gr.Markdown("# 🧪 YOLOv11 Model Tester")
        gr.Markdown(
            """
            Upload images to test different YOLOv11 models. Select a model from the dropdown to see its details.
            """
        )

        # Display total detections counter and top model
        with gr.Row():
            detections_counter = gr.Markdown(
                f"**Total Detections Across All Users:** {detections_data.get('total_detections', 0)}"
            )
            top_model_display = gr.Markdown(
                get_top_model(detections_data.get('detections_per_model', {}), models)
            )

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=[models[m]['display_name'] for m in models],
                label="Select Model",
                value=None
            )
            model_info = gr.Markdown("**Model Information will appear here.**")

        display_to_name = {models[m]['display_name']: m for m in models}

        def update_model_info(selected_display_name):
            if not selected_display_name:
                return "Please select a model."
            model_name = display_to_name.get(selected_display_name)
            if not model_name:
                return "Model information not available."
            model_entry = models[model_name]['info']
            ratings_info = ratings_data.get(model_name, {"total": 0, "count": 0})
            return get_model_info(model_entry, ratings_info)

        model_dropdown.change(
            fn=update_model_info,
            inputs=model_dropdown,
            outputs=model_info
        )

        with gr.Row():
            confidence_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Confidence Threshold",
                info="Adjust the minimum confidence required for detections to be displayed."
            )

        with gr.Tab("🖼️ Image"):
            with gr.Column():
                image_input = gr.Image(
                    type='pil',
                    label="Upload Image for Prediction"
                )
                image_predict_btn = gr.Button("🔍 Predict on Image")
                image_status = gr.Markdown("**Status will appear here.**")
                image_output = gr.Image(label="Predicted Image")
                image_download_btn = gr.File(label="⬇️ Download Predicted Image")

            def process_image(selected_display_name, image, confidence):
                if not selected_display_name:
                    return "❌ Please select a model.", None, None, gr.update(), gr.update()
                model_name = display_to_name.get(selected_display_name)
                status, output_img, output_path = predict_image(model_name, image, confidence, models)

                # Extract the number of detections from the status message
                detections = 0
                if "**Detections:**" in status:
                    try:
                        detections = int(status.split("**Detections:**")[1].strip())
                    except ValueError:
                        pass

                # Update detections counter
                try:
                    detections_data['total_detections'] += detections
                    if model_name in detections_data['detections_per_model']:
                        detections_data['detections_per_model'][model_name] += detections
                    else:
                        detections_data['detections_per_model'][model_name] = detections
                    save_json(DETECTIONS_FILE, detections_data)
                except Exception as e:
                    print(f"Error updating detections counter: {e}")

                # Return refreshed counter and top-model text so the displays update
                counter_text = f"**Total Detections Across All Users:** {detections_data.get('total_detections', 0)}"
                top_text = get_top_model(detections_data.get('detections_per_model', {}), models)
                return status, output_img, output_path, counter_text, top_text

            image_predict_btn.click(
                fn=process_image,
                inputs=[model_dropdown, image_input, confidence_slider],
                outputs=[image_status, image_output, image_download_btn, detections_counter, top_model_display]
            )

        with gr.Tab("⭐ Rate Model"):
            with gr.Column():
                selected_model = gr.Dropdown(
                    choices=[models[m]['display_name'] for m in models],
                    label="Select Model to Rate",
                    value=None
                )
                rating = gr.Slider(
                    minimum=1,
                    maximum=5,
                    step=1,
                    label="Rate the Model (1-5 Stars)",
                    info="Select a star rating between 1 and 5."
                )
                submit_rating_btn = gr.Button("Submit Rating")
                rating_status = gr.Markdown("**Your rating will be submitted here.**")

            def submit_rating(selected_display_name, user_rating, currently_selected_display):
                if not selected_display_name:
                    return "❌ Please select a model to rate.", gr.update()
                if not user_rating:
                    return "❌ Please provide a rating.", gr.update()
                model_name = display_to_name.get(selected_display_name)
                if not model_name:
                    return "❌ Invalid model selected.", gr.update()

                # Update ratings data
                ratings_info = ratings_data.get(model_name, {"total": 0, "count": 0})
                ratings_info['total'] += user_rating
                ratings_info['count'] += 1
                ratings_data[model_name] = ratings_info
                save_json(RATINGS_FILE, ratings_data)

                # Update the model info display if the rated model is currently selected
                if currently_selected_display == selected_display_name:
                    info_update = get_model_info(models[model_name]['info'], ratings_info)
                else:
                    info_update = gr.update()

                average = ratings_info['total'] / ratings_info['count']
                return f"✅ Thank you for rating! Current Average Rating: {average:.2f} ⭐", info_update

            submit_rating_btn.click(
                fn=submit_rating,
                inputs=[selected_model, rating, model_dropdown],
                outputs=[rating_status, model_info]
            )

        with gr.Tab("💡 Recommend Dataset"):
            with gr.Column():
                dataset_name = gr.Textbox(
                    label="Dataset Name",
                    placeholder="Enter the name of the dataset"
                )
                dataset_url = gr.Textbox(
                    label="Dataset URL",
                    placeholder="Enter the Roboflow dataset URL"
                )
                recommend_btn = gr.Button("Recommend Dataset")
                recommend_status = gr.Markdown("**Your recommendation status will appear here.**")

            def recommend_dataset(name, url):
                if not name or not url:
                    return "❌ Please provide both the dataset name and URL."
                if not is_valid_roboflow_url(url):
                    return "❌ Invalid URL. Please provide a valid Roboflow dataset URL."

                # Check for duplicates
                for dataset in recommended_datasets:
                    if dataset['name'].lower() == name.lower() or dataset['url'] == url:
                        return "❌ This dataset has already been recommended."

                # Add to recommended datasets
                recommended_datasets.append({"name": name, "url": url})
                save_json(RECOMMENDED_DATASETS_FILE, recommended_datasets)
                return f"✅ Thank you for recommending the dataset **{name}**!"

            recommend_btn.click(
                fn=recommend_dataset,
                inputs=[dataset_name, dataset_url],
                outputs=recommend_status
            )

        with gr.Tab("📋 Recommended Datasets"):
            with gr.Column():
                def display_recommended_datasets():
                    if not recommended_datasets:
                        return "No datasets have been recommended yet."
                    return "\n".join([f"- [{dataset['name']}]({dataset['url']})" for dataset in recommended_datasets])

                # Render the heading and the current list of recommendations at startup
                recommended_display = gr.Markdown(
                    "### Recommended Roboflow Datasets\n" + display_recommended_datasets()
                )

        with gr.Tab("🏆 Top Model"):
            with gr.Column():
                top_model_md = gr.Markdown(get_top_model(detections_data.get('detections_per_model', {}), models))

        gr.Markdown(
            """
            ---
            **Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
            """
        )

    demo.launch()


if __name__ == "__main__":
    main()
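
# Local usage, as a rough sketch (dependency names inferred from the imports above; exact
# versions are assumptions): install gradio, ultralytics, pillow, and requests, place a
# models_info.json matching the schema sketched near the top of this file next to the script,
# then run it with Python. ratings.json, detections.json, and recommended_datasets.json are
# created automatically on first run.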