# rolo-models / app.py
# Page header from the hosting site: commit 24fabae ("Update app.py") by wuhp, 18.5 kB.
import json
import os
import re
import shutil
import tempfile
import uuid
from pathlib import Path
from urllib.parse import urlparse

import gradio as gr
import requests
from PIL import Image
from ultralytics import YOLO
# ---------------------------------------------------------------------------
# Directory and file configuration (relative paths, resolved against the CWD)
# ---------------------------------------------------------------------------
MODELS_DIR: str = "models"  # downloaded weights, one sub-directory per model
MODELS_INFO_FILE: str = "models_info.json"  # JSON list describing each model
TEMP_DIR: str = "temp"  # scratch space for temporary files
OUTPUT_DIR: str = "outputs"  # annotated prediction images

# Persistent app state: user ratings, detection counters, dataset suggestions.
RATINGS_FILE: str = "ratings.json"
DETECTIONS_FILE: str = "detections.json"
RECOMMENDED_DATASETS_FILE: str = "recommended_datasets.json"
def download_file(url, dest_path, timeout=30):
    """
    Download a file from a URL to the destination path.

    The response body is streamed into a sibling ``<dest_path>.part`` file and
    only moved onto ``dest_path`` once it has been written completely. An
    interrupted download therefore never leaves a truncated file at
    ``dest_path``. This matters because ``load_models`` treats any existing
    file as "already downloaded".

    Args:
        url (str): The URL to download from.
        dest_path (str): The local path to save the file.
        timeout (float): Connect/read timeout in seconds handed to
            ``requests``. Without it a stalled server would block forever.
            Defaults to 30.

    Returns:
        bool: True if download succeeded, False otherwise.
    """
    part_path = f"{dest_path}.part"
    try:
        parent_dir = os.path.dirname(dest_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Using the response as a context manager releases the connection
        # even when streaming is aborted by an exception.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(part_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(part_path, dest_path)
        print(f"Downloaded {url} to {dest_path}.")
        return True
    except Exception as e:
        # Best-effort: the caller skips the model on failure, so report and
        # clean up the partial file instead of raising.
        print(f"Failed to download {url}. Error: {e}")
        try:
            os.remove(part_path)
        except OSError:
            pass
        return False
def load_models(models_dir=MODELS_DIR, info_file=MODELS_INFO_FILE):
    """
    Load YOLO models and their information from the specified directory and JSON file.

    Each entry of the JSON file is expected to provide ``model_name`` and,
    unless the weights already exist locally, a ``download_url``. The weights
    are looked up at ``<models_dir>/<model_name>/<model_name>.pt`` and are
    downloaded there if missing.

    Entries that lack a name, cannot be downloaded, or fail to load are skipped
    with a printed message, so one bad entry does not stop the other models
    from loading.

    Args:
        models_dir (str): Path to the models directory.
        info_file (str): Path to the JSON file containing model info.

    Returns:
        dict: Maps ``model_name`` to a dict with keys ``display_name``,
        ``model`` (the loaded YOLO object) and ``info`` (the raw JSON entry).

    Raises:
        OSError: If ``info_file`` cannot be opened.
        json.JSONDecodeError: If ``info_file`` is not valid JSON.
    """
    with open(info_file, 'r', encoding='utf-8') as f:
        models_info = json.load(f)

    models = {}
    for model_info in models_info:
        model_name = model_info.get('model_name')
        if not model_name:
            print(f"Skipping model entry without a 'model_name': {model_info}")
            continue
        display_name = model_info.get('display_name', model_name)
        model_dir = os.path.join(models_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f"{model_name}.pt")

        if not os.path.isfile(model_path):
            # The URL is only required when the weights are not already on
            # disk; reading it unconditionally raised KeyError for local-only
            # models and aborted startup.
            download_url = model_info.get('download_url')
            if not download_url:
                print(f"Skipping model '{display_name}': not found locally and no 'download_url' given.")
                continue
            print(f"Model '{display_name}' not found locally. Downloading from {download_url}...")
            if not download_file(download_url, model_path):
                print(f"Skipping model '{display_name}' due to download failure.")
                continue

        try:
            model = YOLO(model_path)
        except Exception as e:
            print(f"Error loading model '{display_name}': {e}")
            continue

        models[model_name] = {
            'display_name': display_name,
            'model': model,
            'info': model_info
        }
        print(f"Loaded model '{display_name}' from '{model_path}'.")
    return models
def get_model_info(model_info, ratings_info):
    """
    Retrieve formatted model information for display, including average rating.

    Args:
        model_info (dict): The model's information dictionary (an entry of
            models_info.json). Missing fields are shown as ``N/A``.
        ratings_info (dict): ``{"total": <sum of ratings>, "count": <n>}`` for
            the model. ``None`` or missing keys are treated as "no ratings".

    Returns:
        str: A Markdown string containing model details and the average rating
        rounded to two decimals.
    """
    info = model_info or {}
    ratings_info = ratings_info or {}

    class_ids = info.get('class_ids', {})
    class_image_counts = info.get('class_image_counts', {})
    datasets_used = info.get('datasets_used', [])

    class_ids_formatted = "\n".join(f"{cid}: {cname}" for cid, cname in class_ids.items())
    class_image_counts_formatted = "\n".join(
        f"{cname}: {count}" for cname, count in class_image_counts.items()
    )
    datasets_used_formatted = "\n".join(f"- {dataset}" for dataset in datasets_used)

    # Round like the rating-submission message does (".2f"); only show the
    # star next to an actual number.
    total_rating = ratings_info.get('total', 0)
    count_rating = ratings_info.get('count', 0)
    if count_rating > 0:
        average_rating = f"{total_rating / count_rating:.2f} ⭐"
    else:
        average_rating = "No ratings yet"

    info_text = (
        f"**{info.get('display_name', 'Model Name')}**\n\n"
        f"**Architecture:** {info.get('architecture', 'N/A')}\n\n"
        f"**Training Epochs:** {info.get('training_epochs', 'N/A')}\n\n"
        f"**Batch Size:** {info.get('batch_size', 'N/A')}\n\n"
        f"**Optimizer:** {info.get('optimizer', 'N/A')}\n\n"
        f"**Learning Rate:** {info.get('learning_rate', 'N/A')}\n\n"
        f"**Data Augmentation Level:** {info.get('data_augmentation_level', 'N/A')}\n\n"
        f"**mAP@0.5:** {info.get('mAP_score', 'N/A')}\n\n"
        f"**Number of Images Trained On:** {info.get('num_images', 'N/A')}\n\n"
        f"**Class IDs:**\n{class_ids_formatted}\n\n"
        f"**Datasets Used:**\n{datasets_used_formatted}\n\n"
        f"**Class Image Counts:**\n{class_image_counts_formatted}\n\n"
        f"**Average Rating:** {average_rating}"
    )
    return info_text
def predict_image(model_name, image, confidence, models):
    """
    Perform prediction on an uploaded image using the selected YOLO model.

    The annotated image is rendered in memory with ``Results.plot()`` and
    written to a uniquely named JPEG in ``OUTPUT_DIR``. Concurrent users
    therefore never overwrite each other's download. The global
    ``runs/detect`` directory is not used, because guessing its "latest run"
    was racy and crashed when it was empty.

    Args:
        model_name (str): The name of the selected model.
        image (PIL.Image.Image): The uploaded image (any mode; converted to RGB).
        confidence (float): The confidence threshold for detections.
        models (dict): The dictionary containing models and their info.

    Returns:
        tuple: ``(status_markdown, annotated PIL image or None, saved output
        path or None)``. On success the status contains
        ``**Detections:** <n>``.
    """
    model_entry = models.get(model_name, {})
    model = model_entry.get('model', None)
    if model is None:
        return "Error: Model not found.", None, None
    if image is None:
        return "❌ Please upload an image.", None, None

    try:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        # RGB conversion avoids failures on RGBA/palette uploads (JPEG has no alpha).
        results = model(image.convert("RGB"), conf=confidence)
        result = results[0]

        # Results.plot() returns the annotated image as a BGR array; reverse
        # the channel axis to get RGB for PIL.
        output_image = Image.fromarray(result.plot()[..., ::-1])
        final_output_path = os.path.join(
            OUTPUT_DIR, f"{model_name}_output_image_{uuid.uuid4().hex}.jpg"
        )
        output_image.save(final_output_path)

        # Non-detection tasks have no boxes; count those as zero detections.
        boxes = result.boxes
        detections = len(boxes) if boxes is not None else 0
        return f"✅ Prediction completed successfully. **Detections:** {detections}", output_image, final_output_path
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the handler.
        return f"❌ Error during prediction: {str(e)}", None, None
def load_or_initialize_json(file_path, default_data):
    """
    Load JSON data from a file or initialize it with default data if the file doesn't exist.

    If the file exists but cannot be parsed, or its top-level value has a
    different type than ``default_data`` (e.g. a list where a dict is
    expected), it is moved aside to ``<file_path>.bak`` and re-initialized.
    One damaged state file therefore cannot prevent the app from starting,
    and its previous contents are kept for inspection.

    Args:
        file_path (str): Path to the JSON file.
        default_data (dict or list): Default data to initialize with when the
            file is missing or unusable.

    Returns:
        dict or list: The loaded data, or ``default_data`` itself when the
        file was (re-)initialized.
    """
    if os.path.isfile(file_path):
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            print(f"Could not parse '{file_path}' ({e}); re-initializing it.")
        else:
            if isinstance(data, type(default_data)):
                return data
            print(f"Unexpected data type in '{file_path}'; re-initializing it.")
        backup_path = f"{file_path}.bak"
        os.replace(file_path, backup_path)
        print(f"Previous contents of '{file_path}' kept in '{backup_path}'.")

    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(default_data, f, indent=4)
    return default_data
def save_json(file_path, data):
    """
    Save data to a JSON file atomically.

    The data is serialized into a temporary file in the same directory, which
    then replaces ``file_path`` via ``os.replace``. A crash or a
    serialization error midway therefore never leaves a truncated JSON file
    behind. A truncated file would otherwise fail to load on the next start.

    Args:
        file_path (str): Path to the JSON file.
        data (dict or list): Data to save.

    Raises:
        TypeError: If ``data`` is not JSON-serializable (the target file is
            left untouched).
        OSError: If the file cannot be written.
    """
    directory = os.path.dirname(os.path.abspath(file_path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp-", suffix=".json")
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Do not leave orphaned temp files behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def is_valid_roboflow_url(url):
    """
    Validate if the provided URL is a Roboflow URL.

    Accepts ``http``/``https`` URLs whose host is ``roboflow.com`` or one of
    its sub-domains (e.g. ``universe.roboflow.com``). The host is compared
    exactly rather than by string prefix, so look-alikes such as
    ``roboflow.com.example.net`` are rejected. Surrounding whitespace is
    ignored. Embedded whitespace is rejected because it would break the
    Markdown link the URL is later rendered into.

    Args:
        url (str): The URL to validate.

    Returns:
        bool: True if valid, False otherwise (including for non-string input).
    """
    if not isinstance(url, str):
        return False
    url = url.strip()
    if not url or any(ch.isspace() for ch in url):
        return False
    try:
        parsed = urlparse(url)
        host = parsed.hostname  # lower-cased, without port/credentials
    except ValueError:
        return False
    if parsed.scheme not in ("http", "https") or not host:
        return False
    return host == "roboflow.com" or host.endswith(".roboflow.com")
def get_top_model(detections_per_model, models):
    """
    Determine the top model based on the highest number of detections.

    Args:
        detections_per_model (dict): Dictionary with model names as keys and detection counts as values.
        models (dict): Dictionary of loaded models.

    Returns:
        str: A Markdown line naming the top model, or ``"No detections yet."``
        when there are no counts or all counts are zero. A model recorded in
        the counters but no longer loaded is shown by its raw name instead of
        raising KeyError. Ties go to the first model in dict order.
    """
    if not detections_per_model:
        return "No detections yet."
    top_model_name, top_detections = max(
        detections_per_model.items(), key=lambda item: item[1]
    )
    if top_detections <= 0:
        return "No detections yet."
    top_model_display = models.get(top_model_name, {}).get('display_name', top_model_name)
    return f"**Top Model:** {top_model_display} with **{top_detections}** detections."
def main():
# Load models
models = load_models()
if not models:
print("No models loaded. Please check your models_info.json and model URLs.")
return
# Load or initialize ratings
ratings_data = load_or_initialize_json(RATINGS_FILE, {})
# Initialize ratings for each model if not present
for model_name in models:
if model_name not in ratings_data:
ratings_data[model_name] = {"total": 0, "count": 0}
save_json(RATINGS_FILE, ratings_data)
# Load or initialize detections counter
detections_data = load_or_initialize_json(DETECTIONS_FILE, {"total_detections": 0, "detections_per_model": {}})
# Load or initialize recommended datasets
recommended_datasets = load_or_initialize_json(RECOMMENDED_DATASETS_FILE, [])
with gr.Blocks() as demo:
gr.Markdown("# πŸ§ͺ YOLOv11 Model Tester")
gr.Markdown(
"""
Upload images to test different YOLOv11 models. Select a model from the dropdown to see its details.
"""
)
# Display total detections counter and top model
with gr.Row():
detections_counter = gr.Markdown(
f"**Total Detections Across All Users:** {detections_data.get('total_detections', 0)}"
)
top_model_display = gr.Markdown(
get_top_model(detections_data.get('detections_per_model', {}), models)
)
with gr.Row():
model_dropdown = gr.Dropdown(
choices=[models[m]['display_name'] for m in models],
label="Select Model",
value=None
)
model_info = gr.Markdown("**Model Information will appear here.**")
display_to_name = {models[m]['display_name']: m for m in models}
def update_model_info(selected_display_name):
if not selected_display_name:
return "Please select a model."
model_name = display_to_name.get(selected_display_name)
if not model_name:
return "Model information not available."
model_entry = models[model_name]['info']
ratings_info = ratings_data.get(model_name, {"total": 0, "count": 0})
return get_model_info(model_entry, ratings_info)
model_dropdown.change(
fn=update_model_info,
inputs=model_dropdown,
outputs=model_info
)
with gr.Row():
confidence_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
step=0.01,
value=0.25,
label="Confidence Threshold",
info="Adjust the minimum confidence required for detections to be displayed."
)
with gr.Tab("πŸ–ΌοΈ Image"):
with gr.Column():
image_input = gr.Image(
type='pil',
label="Upload Image for Prediction"
)
image_predict_btn = gr.Button("πŸ” Predict on Image")
image_status = gr.Markdown("**Status will appear here.**")
image_output = gr.Image(label="Predicted Image")
image_download_btn = gr.File(label="⬇️ Download Predicted Image")
def process_image(selected_display_name, image, confidence):
if not selected_display_name:
return "❌ Please select a model.", None, None
model_name = display_to_name.get(selected_display_name)
status, output_img, output_path = predict_image(model_name, image, confidence, models)
# Extract number of detections from the status message
detections = 0
if "Detections:" in status:
try:
detections = int(status.split("Detections:")[1].strip())
except:
pass
# Update detections counter
try:
detections_data['total_detections'] += detections
if model_name in detections_data['detections_per_model']:
detections_data['detections_per_model'][model_name] += detections
else:
detections_data['detections_per_model'][model_name] = detections
save_json(DETECTIONS_FILE, detections_data)
except Exception as e:
print(f"Error updating detections counter: {e}")
# Update detections and top model display
detections_counter.value = f"**Total Detections Across All Users:** {detections_data.get('total_detections', 0)}"
top_model_display.value = get_top_model(detections_data.get('detections_per_model', {}), models)
return status, output_img, output_path
image_predict_btn.click(
fn=process_image,
inputs=[model_dropdown, image_input, confidence_slider],
outputs=[image_status, image_output, image_download_btn]
)
with gr.Tab("⭐ Rate Model"):
with gr.Column():
selected_model = gr.Dropdown(
choices=[models[m]['display_name'] for m in models],
label="Select Model to Rate",
value=None
)
rating = gr.Slider(
minimum=1,
maximum=5,
step=1,
label="Rate the Model (1-5 Stars)",
info="Select a star rating between 1 and 5."
)
submit_rating_btn = gr.Button("Submit Rating")
rating_status = gr.Markdown("**Your rating will be submitted here.**")
def submit_rating(selected_display_name, user_rating):
if not selected_display_name:
return "❌ Please select a model to rate."
if not user_rating:
return "❌ Please provide a rating."
model_name = display_to_name.get(selected_display_name)
if not model_name:
return "❌ Invalid model selected."
# Update ratings data
ratings_info = ratings_data.get(model_name, {"total": 0, "count": 0})
ratings_info['total'] += user_rating
ratings_info['count'] += 1
ratings_data[model_name] = ratings_info
save_json(RATINGS_FILE, ratings_data)
# Update model info display if the rated model is currently selected
if model_dropdown.value == selected_display_name:
updated_info = get_model_info(models[model_name]['info'], ratings_info)
model_info.value = updated_info
average = (ratings_info['total'] / ratings_info['count'])
return f"βœ… Thank you for rating! Current Average Rating: {average:.2f} ⭐"
submit_rating_btn.click(
fn=submit_rating,
inputs=[selected_model, rating],
outputs=rating_status
)
with gr.Tab("πŸ’‘ Recommend Dataset"):
with gr.Column():
dataset_name = gr.Textbox(
label="Dataset Name",
placeholder="Enter the name of the dataset"
)
dataset_url = gr.Textbox(
label="Dataset URL",
placeholder="Enter the Roboflow dataset URL"
)
recommend_btn = gr.Button("Recommend Dataset")
recommend_status = gr.Markdown("**Your recommendation status will appear here.**")
def recommend_dataset(name, url):
if not name or not url:
return "❌ Please provide both the dataset name and URL."
if not is_valid_roboflow_url(url):
return "❌ Invalid URL. Please provide a valid Roboflow dataset URL."
# Check for duplicates
for dataset in recommended_datasets:
if dataset['name'].lower() == name.lower() or dataset['url'] == url:
return "❌ This dataset has already been recommended."
# Add to recommended datasets
recommended_datasets.append({"name": name, "url": url})
save_json(RECOMMENDED_DATASETS_FILE, recommended_datasets)
return f"βœ… Thank you for recommending the dataset **{name}**!"
recommend_btn.click(
fn=recommend_dataset,
inputs=[dataset_name, dataset_url],
outputs=recommend_status
)
with gr.Tab("πŸ“„ Recommended Datasets"):
with gr.Column():
recommended_display = gr.Markdown("### Recommended Roboflow Datasets\n")
def display_recommended_datasets():
if not recommended_datasets:
return "No datasets have been recommended yet."
dataset_md = "\n".join([f"- [{dataset['name']}]({dataset['url']})" for dataset in recommended_datasets])
return dataset_md
# Display the recommended datasets
recommended_display.value = display_recommended_datasets()
with gr.Tab("πŸ† Top Model"):
with gr.Column():
top_model_md = gr.Markdown(get_top_model(detections_data.get('detections_per_model', {}), models))
gr.Markdown(
"""
---
**Note:** Models are downloaded from GitHub upon first use. Ensure that you have a stable internet connection and sufficient storage space.
"""
)
demo.launch()
if __name__ == "__main__":
main()