Spaces:
Runtime error
Runtime error
File size: 5,817 Bytes
c9a2bfc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import numpy as np
# Load the trained model.
# Pulls the fine-tuned image classifier and its matching preprocessor from
# the Hugging Face Hub (or the local HF cache) at import time, so the first
# request does not pay the download cost.
model_name = "aabyzov/guitar-robot-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
# Human-readable playing metadata for each class label the model can emit.
# Keyed by the lowercase label name; each entry carries the same four fields
# (description, bpm, dynamics, technique) used to build the result markdown.
pattern_info = {
    "rock": {
        "description": "Energetic rock strumming with strong downstrokes",
        "technique": "Power chords with palm muting",
        "bpm": "120-140",
        "dynamics": "Forte (loud)",
    },
    "folk": {
        "description": "Gentle folk pattern with bass note emphasis",
        "technique": "Fingerstyle or light pick",
        "bpm": "80-100",
        "dynamics": "Mezzo-forte (medium)",
    },
    "ballad": {
        "description": "Slow, emotional strumming for ballads",
        "technique": "Gentle brushing with occasional accents",
        "bpm": "60-80",
        "dynamics": "Piano (soft)",
    },
}
def predict_pattern(image):
    """Classify the guitar strumming pattern shown in *image*.

    Args:
        image: PIL image from the Gradio upload widget, or None when the
            user clicked the button without uploading anything.

    Returns:
        Tuple of (markdown_text, prob_data, figure):
            markdown_text: formatted analysis summary string.
            prob_data: dict with 'Pattern' and 'Probability' lists feeding
                the bar plot.
            figure: matplotlib figure previewing the robot motion.
        When no image was supplied, returns a prompt string and two Nones.
    """
    if image is None:
        return "Please upload an image", None, None

    # Preprocess the image into model-ready tensors.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_id = logits.argmax(-1).item()
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()

    pattern = model.config.id2label[predicted_id]
    confidence = probs[predicted_id].item()

    # Look up pattern metadata with .get so an unexpected label degrades
    # gracefully instead of raising KeyError mid-request.
    info = pattern_info.get(pattern, {})
    result_text = f"**Detected Pattern:** {pattern.upper()}\n"
    result_text += f"**Confidence:** {confidence:.1%}\n\n"
    result_text += f"**Description:** {info.get('description', 'n/a')}\n"
    result_text += f"**Recommended BPM:** {info.get('bpm', 'n/a')}\n"
    result_text += f"**Dynamics:** {info.get('dynamics', 'n/a')}\n"
    result_text += f"**Technique:** {info.get('technique', 'n/a')}"

    # Build the bar-plot data from the model's own label mapping rather than
    # a hard-coded ['Rock', 'Folk', 'Ballad'] list paired with probs[0..2]:
    # the chart stays correct even if the checkpoint's label order differs.
    id2label = model.config.id2label
    prob_data = {
        'Pattern': [id2label[i].capitalize() for i in range(len(id2label))],
        'Probability': [probs[i].item() for i in range(len(id2label))],
    }

    # Generate the robot action preview for the winning pattern.
    action = generate_action_preview(pattern)
    return result_text, prob_data, action
def generate_action_preview(pattern):
    """Render a matplotlib figure of the robot wrist motion for *pattern*.

    Args:
        pattern: One of 'rock', 'folk', or 'ballad'; any other value falls
            through to the gentle ballad waveform.

    Returns:
        matplotlib.figure.Figure with a 4-second synthetic motion trace.
    """
    # Use the object-oriented Figure API instead of pyplot: plt.subplots()
    # registers every figure in pyplot's global state, so repeated
    # predictions in a long-running Gradio server would leak figures
    # (matplotlib's ">20 open figures" warning). Figure() stays unregistered.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()

    # 4 seconds of motion, 1000 samples.
    t = np.linspace(0, 4, 1000)
    if pattern == 'rock':
        # Fast, strong strumming: 2 Hz fundamental plus a 4 Hz accent.
        wave = 0.8 * np.sin(4 * np.pi * t) + 0.2 * np.sin(8 * np.pi * t)
    elif pattern == 'folk':
        # Moderate, smooth strumming: 1 Hz fundamental with a light 3 Hz overtone.
        wave = 0.5 * np.sin(2 * np.pi * t) + 0.1 * np.sin(6 * np.pi * t)
    else:  # ballad (and any unrecognized label)
        # Slow, gentle strumming at 0.5 Hz.
        wave = 0.3 * np.sin(1 * np.pi * t)

    ax.plot(t, wave, 'b-', linewidth=2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Strumming Motion')
    ax.set_title(f'{pattern.capitalize()} Pattern - Robot Wrist Motion')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-1, 1)
    fig.tight_layout()
    return fig
# Create Gradio interface.
# Declarative UI: left column holds the image input, trigger button, and
# example gallery; right column holds the markdown result, the probability
# bar chart, and the robot-motion preview plot.
with gr.Blocks(title="Guitar Robot Pattern Classifier") as demo:
    gr.Markdown("""
    # 🎸 Guitar Robot Pattern Classifier
    This model classifies guitar strumming patterns for the SO-100 robot arm.
    Upload an image to detect the strumming pattern and get robot control recommendations.
    **Model:** [aabyzov/guitar-robot-classifier](https://huggingface.co/aabyzov/guitar-robot-classifier)
    **Dataset:** [aabyzov/guitar-robot-realistic-v1](https://huggingface.co/datasets/aabyzov/guitar-robot-realistic-v1)
    """)
    with gr.Row():
        with gr.Column():
            # type="pil" so predict_pattern receives a PIL image directly.
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="input-image"
            )
            predict_btn = gr.Button("Analyze Pattern", variant="primary")
            # NOTE(review): example paths are relative to the app's working
            # directory — assumes an examples/ folder ships with the Space.
            gr.Examples(
                examples=[
                    ["examples/rock_example.jpg"],
                    ["examples/folk_example.jpg"],
                    ["examples/ballad_example.jpg"]
                ],
                inputs=input_image,
                label="Example Images"
            )
        with gr.Column():
            output_text = gr.Markdown(label="Analysis Results")
            # Horizontal bar chart fed by the prob_data dict returned from
            # predict_pattern ('Pattern' / 'Probability' keys).
            prob_plot = gr.BarPlot(
                label="Pattern Probabilities",
                x="Pattern",
                y="Probability",
                vertical=False,
                height=200
            )
            action_plot = gr.Plot(label="Robot Motion Preview")
    # Wire the button to the classifier; outputs map 1:1 to the three
    # values predict_pattern returns.
    predict_btn.click(
        fn=predict_pattern,
        inputs=input_image,
        outputs=[output_text, prob_plot, action_plot]
    )
    gr.Markdown("""
    ## How it works:
    1. The model analyzes the image to detect guitar and robot positions
    2. It classifies the appropriate strumming pattern (rock, folk, or ballad)
    3. Robot control parameters are generated based on the pattern
    ## Integration:
    ```python
    from transformers import AutoModelForImageClassification, AutoImageProcessor
    model = AutoModelForImageClassification.from_pretrained("aabyzov/guitar-robot-classifier")
    processor = AutoImageProcessor.from_pretrained("aabyzov/guitar-robot-classifier")
    ```
    Built for LeRobot Hackathon 2024 🤖
    """)
# Launch the app only when run as a script (not when imported).
# Fix: the original line ended with a stray " |" (an extraction artifact)
# that made `demo.launch() |` a SyntaxError.
if __name__ == "__main__":
    demo.launch()