Anton Abyzov committed on
Commit
c9a2bfc
·
1 Parent(s): 7399b63

Add guitar robot pattern classifier app

Browse files

- Gradio interface for testing the trained model
- Shows pattern classification (rock/folk/ballad)
- Displays confidence scores and robot motion preview
- Includes usage instructions and model links

.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # HuggingFace Spaces (keep their own git repos)
2
+ guitar-robot-trainer/.git/
3
+ autotrain-advanced/.git/
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/guitar-robot-trainer.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="WEB_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$">
5
+ <excludeFolder url="file://$MODULE_DIR$/.tmp" />
6
+ <excludeFolder url="file://$MODULE_DIR$/temp" />
7
+ <excludeFolder url="file://$MODULE_DIR$/tmp" />
8
+ </content>
9
+ <orderEntry type="inheritedJdk" />
10
+ <orderEntry type="sourceFolder" forTests="false" />
11
+ </component>
12
+ </module>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="KubernetesApiPersistence">{}</component>
4
+ <component name="KubernetesApiProvider">{
5
+ &quot;isMigrated&quot;: true
6
+ }</component>
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/guitar-robot-trainer.iml" filepath="$PROJECT_DIR$/.idea/guitar-robot-trainer.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
4
+ from PIL import Image
5
+ import numpy as np
6
+
7
# Checkpoint identifier shared by the classifier and its image processor.
model_name = "aabyzov/guitar-robot-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Human-readable playing notes for each strumming pattern, keyed by the
# lowercase pattern name used when formatting prediction results.
pattern_info = {
    "rock": dict(
        description="Energetic rock strumming with strong downstrokes",
        bpm="120-140",
        dynamics="Forte (loud)",
        technique="Power chords with palm muting",
    ),
    "folk": dict(
        description="Gentle folk pattern with bass note emphasis",
        bpm="80-100",
        dynamics="Mezzo-forte (medium)",
        technique="Fingerstyle or light pick",
    ),
    "ballad": dict(
        description="Slow, emotional strumming for ballads",
        bpm="60-80",
        dynamics="Piano (soft)",
        technique="Gentle brushing with occasional accents",
    ),
}
33
+
34
def predict_pattern(image):
    """Predict the guitar strumming pattern shown in an image.

    Args:
        image: The uploaded image (the Gradio input feeding this callback is
            declared with ``type="pil"``), or ``None`` when nothing was
            uploaded.

    Returns:
        A ``(result_text, prob_data, action)`` tuple for the three output
        components: a Markdown summary, a pandas ``DataFrame`` with
        ``Pattern`` and ``Probability`` columns for the bar plot, and the
        figure returned by ``generate_action_preview``. When ``image`` is
        ``None`` the tuple is ``("Please upload an image", None, None)``.
    """
    if image is None:
        return "Please upload an image", None, None

    # Imported here because only this callback needs pandas; gr.BarPlot is
    # documented to take a DataFrame rather than a plain dict of columns.
    import pandas as pd

    inputs = processor(images=image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image is passed in, so row 0 holds its class scores.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    predicted_id = int(probs.argmax(-1).item())

    id2label = model.config.id2label
    pattern = str(id2label[predicted_id])
    confidence = probs[predicted_id].item()

    # Look the label up case-insensitively so a differently-cased label does
    # not raise KeyError; the preview helper compares lowercase names too.
    pattern_key = pattern.lower()
    details = pattern_info.get(pattern_key)

    # Markdown collapses a lone "\n", so use a hard line break (two trailing
    # spaces) inside each group and a blank line between the groups.
    hard_break = "  \n"
    header = hard_break.join([
        f"**Detected Pattern:** {pattern.upper()}",
        f"**Confidence:** {confidence:.1%}",
    ])
    if details is None:
        # No playing notes are defined for this label; show the header only.
        result_text = header
    else:
        notes = hard_break.join([
            f"**Description:** {details['description']}",
            f"**Recommended BPM:** {details['bpm']}",
            f"**Dynamics:** {details['dynamics']}",
            f"**Technique:** {details['technique']}",
        ])
        result_text = f"{header}\n\n{notes}"

    # Take the chart labels from the model's own id -> label mapping so each
    # bar is named after the class its probability belongs to, for any
    # number of classes.
    labels = [
        str(id2label.get(class_id, class_id)).capitalize()
        for class_id in range(probs.shape[0])
    ]
    prob_data = pd.DataFrame({
        'Pattern': labels,
        'Probability': probs.tolist(),
    })

    action = generate_action_preview(pattern_key)

    return result_text, prob_data, action
73
+
74
def generate_action_preview(pattern):
    """Build a plot of the simulated wrist motion for a strumming pattern.

    Args:
        pattern: Pattern name. ``'rock'`` and ``'folk'`` select their own
            waveforms; any other value is drawn with the ballad waveform.

    Returns:
        A Matplotlib ``Figure`` containing the waveform over 0-4 on the
        time axis, with the y-axis fixed to ``[-1, 1]``.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates registered until it is explicitly closed, so a
    # callback that runs once per request would accumulate open figures.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()

    # Generate waveform based on pattern
    t = np.linspace(0, 4, 1000)

    if pattern == 'rock':
        # Fast, strong strumming
        wave = 0.8 * np.sin(4 * np.pi * t) + 0.2 * np.sin(8 * np.pi * t)
    elif pattern == 'folk':
        # Moderate, smooth strumming
        wave = 0.5 * np.sin(2 * np.pi * t) + 0.1 * np.sin(6 * np.pi * t)
    else:  # ballad (also the fallback for unrecognised names)
        # Slow, gentle strumming
        wave = 0.3 * np.sin(1 * np.pi * t)

    ax.plot(t, wave, 'b-', linewidth=2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Strumming Motion')
    ax.set_title(f'{pattern.capitalize()} Pattern - Robot Wrist Motion')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-1, 1)

    fig.tight_layout()

    return fig
104
+
105
+ # Create Gradio interface
106
+ with gr.Blocks(title="Guitar Robot Pattern Classifier") as demo:
107
+ gr.Markdown("""
108
+ # 🎸 Guitar Robot Pattern Classifier
109
+
110
+ This model classifies guitar strumming patterns for the SO-100 robot arm.
111
+ Upload an image to detect the strumming pattern and get robot control recommendations.
112
+
113
+ **Model:** [aabyzov/guitar-robot-classifier](https://huggingface.co/aabyzov/guitar-robot-classifier)
114
+ **Dataset:** [aabyzov/guitar-robot-realistic-v1](https://huggingface.co/datasets/aabyzov/guitar-robot-realistic-v1)
115
+ """)
116
+
117
+ with gr.Row():
118
+ with gr.Column():
119
+ input_image = gr.Image(
120
+ label="Upload Image",
121
+ type="pil",
122
+ elem_id="input-image"
123
+ )
124
+
125
+ predict_btn = gr.Button("Analyze Pattern", variant="primary")
126
+
127
+ gr.Examples(
128
+ examples=[
129
+ ["examples/rock_example.jpg"],
130
+ ["examples/folk_example.jpg"],
131
+ ["examples/ballad_example.jpg"]
132
+ ],
133
+ inputs=input_image,
134
+ label="Example Images"
135
+ )
136
+
137
+ with gr.Column():
138
+ output_text = gr.Markdown(label="Analysis Results")
139
+
140
+ prob_plot = gr.BarPlot(
141
+ label="Pattern Probabilities",
142
+ x="Pattern",
143
+ y="Probability",
144
+ vertical=False,
145
+ height=200
146
+ )
147
+
148
+ action_plot = gr.Plot(label="Robot Motion Preview")
149
+
150
+ predict_btn.click(
151
+ fn=predict_pattern,
152
+ inputs=input_image,
153
+ outputs=[output_text, prob_plot, action_plot]
154
+ )
155
+
156
+ gr.Markdown("""
157
+ ## How it works:
158
+ 1. The model analyzes the image to detect guitar and robot positions
159
+ 2. It classifies the appropriate strumming pattern (rock, folk, or ballad)
160
+ 3. Robot control parameters are generated based on the pattern
161
+
162
+ ## Integration:
163
+ ```python
164
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
165
+
166
+ model = AutoModelForImageClassification.from_pretrained("aabyzov/guitar-robot-classifier")
167
+ processor = AutoImageProcessor.from_pretrained("aabyzov/guitar-robot-classifier")
168
+ ```
169
+
170
+ Built for LeRobot Hackathon 2024 🤖
171
+ """)
172
+
173
+ if __name__ == "__main__":
174
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio==5.34.0
2
+ transformers>=4.35.0
3
+ torch>=2.0.0
4
+ pillow
5
+ numpy
6
+ matplotlib