import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib

matplotlib.use("Agg")  # headless backend so figures render inside the Gradio server
import matplotlib.pyplot as plt

# Load the trained model
model_name = "aabyzov/guitar-robot-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Pattern descriptions
pattern_info = {
    'rock': {
        'description': 'Energetic rock strumming with strong downstrokes',
        'bpm': '120-140',
        'dynamics': 'Forte (loud)',
        'technique': 'Power chords with palm muting'
    },
    'folk': {
        'description': 'Gentle folk pattern with bass-note emphasis',
        'bpm': '80-100',
        'dynamics': 'Mezzo-forte (medium)',
        'technique': 'Fingerstyle or light pick'
    },
    'ballad': {
        'description': 'Slow, emotional strumming for ballads',
        'bpm': '60-80',
        'dynamics': 'Piano (soft)',
        'technique': 'Gentle brushing with occasional accents'
    }
}


def predict_pattern(image):
    """Predict the guitar strumming pattern from an image."""
    if image is None:
        return "Please upload an image", None, None

    # Preprocess the image into model inputs
    inputs = processor(images=image, return_tensors="pt")

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_id = logits.argmax(-1).item()

    # Convert logits to class probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    # Look up the pattern name and its confidence
    pattern = model.config.id2label[predicted_id]
    confidence = probs[predicted_id].item()

    # Create detailed output
    info = pattern_info[pattern]
    result_text = f"**Detected Pattern:** {pattern.upper()}\n"
    result_text += f"**Confidence:** {confidence:.1%}\n\n"
    result_text += f"**Description:** {info['description']}\n"
    result_text += f"**Recommended BPM:** {info['bpm']}\n"
    result_text += f"**Dynamics:** {info['dynamics']}\n"
    result_text += f"**Technique:** {info['technique']}"

    # Create the probability chart from the model's own label order
    # (gr.BarPlot expects a pandas DataFrame, and this avoids assuming
    # a fixed rock/folk/ballad index mapping)
    prob_data = pd.DataFrame({
        'Pattern': [model.config.id2label[i].capitalize() for i in range(len(probs))],
        'Probability': [probs[i].item() for i in range(len(probs))]
    })

    # Generate robot action preview
    action = generate_action_preview(pattern)

    return result_text, prob_data, action


def generate_action_preview(pattern):
    """Generate a simple visualization of the robot's strumming motion."""
    fig, ax = plt.subplots(figsize=(8, 4))

    # Generate a waveform whose amplitude and frequency match the pattern
    t = np.linspace(0, 4, 1000)
    if pattern == 'rock':
        # Fast, strong strumming
        wave = 0.8 * np.sin(4 * np.pi * t) + 0.2 * np.sin(8 * np.pi * t)
    elif pattern == 'folk':
        # Moderate, smooth strumming
        wave = 0.5 * np.sin(2 * np.pi * t) + 0.1 * np.sin(6 * np.pi * t)
    else:  # ballad
        # Slow, gentle strumming
        wave = 0.3 * np.sin(1 * np.pi * t)

    ax.plot(t, wave, 'b-', linewidth=2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Strumming Motion')
    ax.set_title(f'{pattern.capitalize()} Pattern - Robot Wrist Motion')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-1, 1)
    plt.tight_layout()

    return fig


# Create the Gradio interface
with gr.Blocks(title="Guitar Robot Pattern Classifier") as demo:
    gr.Markdown("""
    # 🎸 Guitar Robot Pattern Classifier

    This model classifies guitar strumming patterns for the SO-100 robot arm.
    Upload an image to detect the strumming pattern and get robot control recommendations.

    **Model:** [aabyzov/guitar-robot-classifier](https://huggingface.co/aabyzov/guitar-robot-classifier)
    **Dataset:** [aabyzov/guitar-robot-realistic-v1](https://huggingface.co/datasets/aabyzov/guitar-robot-realistic-v1)
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="input-image"
            )
            predict_btn = gr.Button("Analyze Pattern", variant="primary")

            gr.Examples(
                examples=[
                    ["examples/rock_example.jpg"],
                    ["examples/folk_example.jpg"],
                    ["examples/ballad_example.jpg"]
                ],
                inputs=input_image,
                label="Example Images"
            )

        with gr.Column():
            output_text = gr.Markdown(label="Analysis Results")
            prob_plot = gr.BarPlot(
                label="Pattern Probabilities",
                x="Pattern",
                y="Probability",
                vertical=False,
                height=200
            )
            action_plot = gr.Plot(label="Robot Motion Preview")

    predict_btn.click(
        fn=predict_pattern,
        inputs=input_image,
        outputs=[output_text, prob_plot, action_plot]
    )

    gr.Markdown("""
    ## How it works:
    1. The model analyzes the image to detect guitar and robot positions
    2. It classifies the appropriate strumming pattern (rock, folk, or ballad)
    3. Robot control parameters are generated based on the pattern

    ## Integration:
    ```python
    from transformers import AutoModelForImageClassification, AutoImageProcessor

    model = AutoModelForImageClassification.from_pretrained("aabyzov/guitar-robot-classifier")
    processor = AutoImageProcessor.from_pretrained("aabyzov/guitar-robot-classifier")
    ```

    Built for LeRobot Hackathon 2024 🤖
    """)

if __name__ == "__main__":
    demo.launch()
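
# Example: using the classifier programmatically, without launching the UI.
# A minimal sketch; "test_frame.jpg" below is a hypothetical local image path:
#
#   from PIL import Image
#   text, prob_df, motion_fig = predict_pattern(Image.open("test_frame.jpg"))
#   print(text)                               # markdown analysis report
#   motion_fig.savefig("motion_preview.png")  # robot wrist-motion plot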