import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# Load the trained model and its image processor from the Hub
model_name = "aabyzov/guitar-robot-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Pattern descriptions used to annotate predictions
pattern_info = {
    'rock': {
        'description': 'Energetic rock strumming with strong downstrokes',
        'bpm': '120-140',
        'dynamics': 'Forte (loud)',
        'technique': 'Power chords with palm muting'
    },
    'folk': {
        'description': 'Gentle folk pattern with bass note emphasis',
        'bpm': '80-100',
        'dynamics': 'Mezzo-forte (medium)',
        'technique': 'Fingerstyle or light pick'
    },
    'ballad': {
        'description': 'Slow, emotional strumming for ballads',
        'bpm': '60-80',
        'dynamics': 'Piano (soft)',
        'technique': 'Gentle brushing with occasional accents'
    }
}

def predict_pattern(image):
    """Predict the guitar strumming pattern from an image."""
    if image is None:
        return "Please upload an image", None, None

    # Preprocess the image into model inputs
    inputs = processor(images=image, return_tensors="pt")

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_id = logits.argmax(-1).item()

    # Convert logits to class probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    # Look up the predicted pattern name and its confidence
    pattern = model.config.id2label[predicted_id]
    confidence = probs[predicted_id].item()

    # Create detailed output
    info = pattern_info[pattern]
    result_text = f"**Detected Pattern:** {pattern.upper()}\n"
    result_text += f"**Confidence:** {confidence:.1%}\n\n"
    result_text += f"**Description:** {info['description']}\n"
    result_text += f"**Recommended BPM:** {info['bpm']}\n"
    result_text += f"**Dynamics:** {info['dynamics']}\n"
    result_text += f"**Technique:** {info['technique']}"

    # Create probability chart data; gr.BarPlot expects a DataFrame,
    # and id2label keeps the labels aligned with the logit order
    prob_data = pd.DataFrame({
        'Pattern': [model.config.id2label[i].capitalize() for i in range(len(probs))],
        'Probability': [p.item() for p in probs]
    })

    # Generate robot action preview
    action = generate_action_preview(pattern)
    return result_text, prob_data, action

def generate_action_preview(pattern):
    """Generate a simple visualization of the robot strumming motion."""
    fig, ax = plt.subplots(figsize=(8, 4))

    # Generate a waveform whose amplitude and frequency match the pattern
    t = np.linspace(0, 4, 1000)
    if pattern == 'rock':
        # Fast, strong strumming
        wave = 0.8 * np.sin(4 * np.pi * t) + 0.2 * np.sin(8 * np.pi * t)
    elif pattern == 'folk':
        # Moderate, smooth strumming
        wave = 0.5 * np.sin(2 * np.pi * t) + 0.1 * np.sin(6 * np.pi * t)
    else:  # ballad
        # Slow, gentle strumming
        wave = 0.3 * np.sin(np.pi * t)

    ax.plot(t, wave, 'b-', linewidth=2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Strumming Motion')
    ax.set_title(f'{pattern.capitalize()} Pattern - Robot Wrist Motion')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-1, 1)
    plt.tight_layout()
    return fig

# Create Gradio interface
with gr.Blocks(title="Guitar Robot Pattern Classifier") as demo:
    gr.Markdown("""
    # 🎸 Guitar Robot Pattern Classifier

    This model classifies guitar strumming patterns for the SO-100 robot arm.
    Upload an image to detect the strumming pattern and get robot control recommendations.

    **Model:** [aabyzov/guitar-robot-classifier](https://huggingface.co/aabyzov/guitar-robot-classifier)
    **Dataset:** [aabyzov/guitar-robot-realistic-v1](https://huggingface.co/datasets/aabyzov/guitar-robot-realistic-v1)
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="input-image"
            )
            predict_btn = gr.Button("Analyze Pattern", variant="primary")
            gr.Examples(
                examples=[
                    ["examples/rock_example.jpg"],
                    ["examples/folk_example.jpg"],
                    ["examples/ballad_example.jpg"]
                ],
                inputs=input_image,
                label="Example Images"
            )
        with gr.Column():
            output_text = gr.Markdown(label="Analysis Results")
            prob_plot = gr.BarPlot(
                label="Pattern Probabilities",
                x="Pattern",
                y="Probability",
                vertical=False,
                height=200
            )
            action_plot = gr.Plot(label="Robot Motion Preview")

    predict_btn.click(
        fn=predict_pattern,
        inputs=input_image,
        outputs=[output_text, prob_plot, action_plot]
    )
gr.Markdown(""" | |
## How it works: | |
1. The model analyzes the image to detect guitar and robot positions | |
2. It classifies the appropriate strumming pattern (rock, folk, or ballad) | |
3. Robot control parameters are generated based on the pattern | |
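
    As an illustration only (values copied from this app's `pattern_info` table; the parameter names are invented for this sketch, not a LeRobot API):

    ```python
    controls = {
        'rock':   {'bpm': (120, 140), 'dynamics': 'forte',       'technique': 'power chords with palm muting'},
        'folk':   {'bpm': (80, 100),  'dynamics': 'mezzo-forte', 'technique': 'fingerstyle or light pick'},
        'ballad': {'bpm': (60, 80),   'dynamics': 'piano',       'technique': 'gentle brushing'},
    }
    params = controls['rock']  # keyed by the classifier's predicted label
    ```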

    ## Integration:
    ```python
    from transformers import AutoModelForImageClassification, AutoImageProcessor

    model = AutoModelForImageClassification.from_pretrained("aabyzov/guitar-robot-classifier")
    processor = AutoImageProcessor.from_pretrained("aabyzov/guitar-robot-classifier")
    ```
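
    A minimal inference sketch building on the snippet above (the image path is a placeholder; labels come from `model.config.id2label`):

    ```python
    import torch
    from PIL import Image

    image = Image.open("your_image.jpg")  # placeholder path
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    print(model.config.id2label[logits.argmax(-1).item()])
    ```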

    Built for LeRobot Hackathon 2024 🤖
    """)

if __name__ == "__main__":
    demo.launch()