import gradio as gr
import torch
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; avoids GUI issues on a server
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
# Load the trained model
model_name = "aabyzov/guitar-robot-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
# Pattern descriptions
pattern_info = {
    'rock': {
        'description': 'Energetic rock strumming with strong downstrokes',
        'bpm': '120-140',
        'dynamics': 'Forte (loud)',
        'technique': 'Power chords with palm muting'
    },
    'folk': {
        'description': 'Gentle folk pattern with bass note emphasis',
        'bpm': '80-100',
        'dynamics': 'Mezzo-forte (medium)',
        'technique': 'Fingerstyle or light pick'
    },
    'ballad': {
        'description': 'Slow, emotional strumming for ballads',
        'bpm': '60-80',
        'dynamics': 'Piano (soft)',
        'technique': 'Gentle brushing with occasional accents'
    }
}
def predict_pattern(image):
    """Predict the guitar strumming pattern from an image."""
    if image is None:
        return "Please upload an image", None, None

    # Preprocess the image into model inputs
    inputs = processor(images=image, return_tensors="pt")

    # Run inference without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_id = logits.argmax(-1).item()

    # Convert logits to class probabilities
    probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    # Look up the predicted pattern name and its confidence
    pattern = model.config.id2label[predicted_id]
    confidence = probs[predicted_id].item()

    # Build the detailed markdown output
    info = pattern_info[pattern]
    result_text = f"**Detected Pattern:** {pattern.upper()}\n"
    result_text += f"**Confidence:** {confidence:.1%}\n\n"
    result_text += f"**Description:** {info['description']}\n"
    result_text += f"**Recommended BPM:** {info['bpm']}\n"
    result_text += f"**Dynamics:** {info['dynamics']}\n"
    result_text += f"**Technique:** {info['technique']}"

    # gr.BarPlot expects a DataFrame; reading labels from id2label keeps
    # the bar order consistent with the model instead of hardcoding it
    prob_data = pd.DataFrame({
        'Pattern': [model.config.id2label[i].capitalize() for i in range(len(probs))],
        'Probability': probs.tolist()
    })

    # Generate the robot action preview
    action = generate_action_preview(pattern)

    return result_text, prob_data, action
def generate_action_preview(pattern):
    """Generate a simple visualization of the robot's strumming motion."""
    fig, ax = plt.subplots(figsize=(8, 4))

    # Generate a waveform shaped by the pattern
    t = np.linspace(0, 4, 1000)
    if pattern == 'rock':
        # Fast, strong strumming
        wave = 0.8 * np.sin(4 * np.pi * t) + 0.2 * np.sin(8 * np.pi * t)
    elif pattern == 'folk':
        # Moderate, smooth strumming
        wave = 0.5 * np.sin(2 * np.pi * t) + 0.1 * np.sin(6 * np.pi * t)
    else:  # ballad
        # Slow, gentle strumming
        wave = 0.3 * np.sin(np.pi * t)

    ax.plot(t, wave, 'b-', linewidth=2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Strumming Motion')
    ax.set_title(f'{pattern.capitalize()} Pattern - Robot Wrist Motion')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-1, 1)
    plt.tight_layout()
    return fig
# Create the Gradio interface
with gr.Blocks(title="Guitar Robot Pattern Classifier") as demo:
    gr.Markdown("""
    # 🎸 Guitar Robot Pattern Classifier

    This model classifies guitar strumming patterns for the SO-100 robot arm.
    Upload an image to detect the strumming pattern and get robot control recommendations.

    **Model:** [aabyzov/guitar-robot-classifier](https://huggingface.co/aabyzov/guitar-robot-classifier)

    **Dataset:** [aabyzov/guitar-robot-realistic-v1](https://huggingface.co/datasets/aabyzov/guitar-robot-realistic-v1)
    """)

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="input-image"
            )
            predict_btn = gr.Button("Analyze Pattern", variant="primary")

            gr.Examples(
                examples=[
                    ["examples/rock_example.jpg"],
                    ["examples/folk_example.jpg"],
                    ["examples/ballad_example.jpg"]
                ],
                inputs=input_image,
                label="Example Images"
            )

        with gr.Column():
            output_text = gr.Markdown(label="Analysis Results")
            prob_plot = gr.BarPlot(
                label="Pattern Probabilities",
                x="Pattern",
                y="Probability",
                vertical=False,
                height=200
            )
            action_plot = gr.Plot(label="Robot Motion Preview")

    predict_btn.click(
        fn=predict_pattern,
        inputs=input_image,
        outputs=[output_text, prob_plot, action_plot]
    )
gr.Markdown("""
## How it works:
1. The model analyzes the image to detect guitar and robot positions
2. It classifies the appropriate strumming pattern (rock, folk, or ballad)
3. Robot control parameters are generated based on the pattern
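
    For example, a hypothetical pattern-to-parameters mapping (the field names
    below are illustrative assumptions, not the actual SO-100 API) could be:

    ```python
    # Illustrative sketch only: these names are assumptions, not a real robot API.
    # BPM values are the midpoints of the recommended ranges above.
    PATTERN_PARAMS = {
        "rock":   {"bpm": 130, "stroke_amplitude": 0.8},  # forte, strong downstrokes
        "folk":   {"bpm": 90,  "stroke_amplitude": 0.5},  # mezzo-forte, bass emphasis
        "ballad": {"bpm": 70,  "stroke_amplitude": 0.3},  # piano, gentle brushing
    }
    ```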

    ## Integration
    ```python
    from transformers import AutoModelForImageClassification, AutoImageProcessor

    model = AutoModelForImageClassification.from_pretrained("aabyzov/guitar-robot-classifier")
    processor = AutoImageProcessor.from_pretrained("aabyzov/guitar-robot-classifier")
    ```
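
    To run a prediction, the same inference steps used by this app apply
    (the image path below is a placeholder):

    ```python
    import torch
    from PIL import Image

    img = Image.open("your_image.jpg")  # placeholder path
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    pattern = model.config.id2label[logits.argmax(-1).item()]
    print(pattern)  # "rock", "folk", or "ballad"
    ```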

    Built for LeRobot Hackathon 2024 🤖
    """)
if __name__ == "__main__":
    demo.launch()