File size: 5,817 Bytes
c9a2bfc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import numpy as np

# Load the trained model
# The classifier and its image processor are created once, at import time,
# from the same checkpoint name; predict_pattern reuses these module-level
# objects on every call instead of reloading them per request.
# NOTE(review): from_pretrained presumably downloads and caches the checkpoint
# on first run - confirm behaviour for offline deployments.
model_name = "aabyzov/guitar-robot-classifier"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Human-readable guidance for each strumming pattern.
# Outer keys are the lowercase pattern names; every entry carries the same
# four string fields (description, bpm, dynamics, technique) that
# predict_pattern renders into its markdown summary.
pattern_info = dict(
    rock=dict(
        description='Energetic rock strumming with strong downstrokes',
        bpm='120-140',
        dynamics='Forte (loud)',
        technique='Power chords with palm muting',
    ),
    folk=dict(
        description='Gentle folk pattern with bass note emphasis',
        bpm='80-100',
        dynamics='Mezzo-forte (medium)',
        technique='Fingerstyle or light pick',
    ),
    ballad=dict(
        description='Slow, emotional strumming for ballads',
        bpm='60-80',
        dynamics='Piano (soft)',
        technique='Gentle brushing with occasional accents',
    ),
)

def predict_pattern(image):
    """Predict guitar strumming pattern from image.

    Args:
        image: The image handed over by the Gradio input component (the
            interface requests ``type="pil"``), or ``None`` when nothing
            has been uploaded.

    Returns:
        A 3-tuple ``(result_text, prob_data, action)``:

        * ``result_text`` - markdown summary with the detected pattern, its
          confidence and, when the label is known to ``pattern_info``, the
          description / BPM / dynamics / technique lines.
        * ``prob_data`` - ``pandas.DataFrame`` with one row per model class
          and the columns ``"Pattern"`` and ``"Probability"`` (the column
          names the bar plot is configured with).
        * ``action`` - the figure returned by ``generate_action_preview``.

        When ``image`` is ``None`` the tuple is
        ``("Please upload an image", None, None)``.
    """
    if image is None:
        return "Please upload an image", None, None

    # Imported locally so this function brings its own dependency into scope;
    # the bar plot consumes a DataFrame, not a plain dict.
    import pandas as pd

    # Process image
    inputs = processor(images=image, return_tensors="pt")

    # Get prediction and per-class probabilities (single image -> drop the
    # batch dimension explicitly so a 1-class model is not squeezed to 0-d).
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)

    predicted_id = int(probs.argmax(-1).item())
    confidence = probs[predicted_id].item()

    # Take class names from the model config rather than assuming a fixed
    # rock/folk/ballad order and a fixed class count.
    id2label = model.config.id2label
    labels = [str(id2label[idx]) for idx in range(probs.shape[0])]
    pattern = labels[predicted_id]

    # pattern_info is keyed by lowercase names; normalise so a differently
    # cased label still resolves, and tolerate labels with no entry at all.
    pattern_key = pattern.lower()
    info = pattern_info.get(pattern_key)

    # Create detailed output
    lines = [
        f"**Detected Pattern:** {pattern.upper()}",
        f"**Confidence:** {confidence:.1%}",
    ]
    if info is not None:
        lines += [
            "",
            f"**Description:** {info['description']}",
            f"**Recommended BPM:** {info['bpm']}",
            f"**Dynamics:** {info['dynamics']}",
            f"**Technique:** {info['technique']}",
        ]
    result_text = "\n".join(lines)

    # Create probability chart data
    prob_data = pd.DataFrame({
        'Pattern': [label.capitalize() for label in labels],
        'Probability': probs.tolist(),
    })

    # Generate robot action preview (expects the lowercase pattern name)
    action = generate_action_preview(pattern_key)

    return result_text, prob_data, action

def generate_action_preview(pattern):
    """Generate a simple visualization of robot action.

    Draws a 4-second synthetic waveform whose amplitude and frequency depend
    on the pattern name.

    Args:
        pattern: Pattern name. ``'rock'`` and ``'folk'`` select their own
            waveforms; any other value uses the slow ballad waveform.

    Returns:
        A ``matplotlib.figure.Figure`` containing the plotted waveform.
    """
    # Build the figure directly instead of through pyplot: pyplot keeps every
    # figure it creates in a global registry until it is explicitly closed,
    # so one figure per request would accumulate in a long-running app.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()

    # Generate waveform based on pattern
    t = np.linspace(0, 4, 1000)

    if pattern == 'rock':
        # Fast, strong strumming
        wave = 0.8 * np.sin(4 * np.pi * t) + 0.2 * np.sin(8 * np.pi * t)
    elif pattern == 'folk':
        # Moderate, smooth strumming
        wave = 0.5 * np.sin(2 * np.pi * t) + 0.1 * np.sin(6 * np.pi * t)
    else:  # ballad (also the fallback for unrecognised names)
        # Slow, gentle strumming
        wave = 0.3 * np.sin(1 * np.pi * t)

    ax.plot(t, wave, 'b-', linewidth=2)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Strumming Motion')
    ax.set_title(f'{pattern.capitalize()} Pattern - Robot Wrist Motion')
    ax.grid(True, alpha=0.3)
    # Fixed y-range so the three patterns are visually comparable.
    ax.set_ylim(-1, 1)

    fig.tight_layout()

    return fig

# Create Gradio interface
# Layout: an intro markdown header, then one row with two columns (inputs on
# the left, results on the right), then a closing "how it works" section.
with gr.Blocks(title="Guitar Robot Pattern Classifier") as demo:
    gr.Markdown("""
    # 🎸 Guitar Robot Pattern Classifier
    
    This model classifies guitar strumming patterns for the SO-100 robot arm.
    Upload an image to detect the strumming pattern and get robot control recommendations.
    
    **Model:** [aabyzov/guitar-robot-classifier](https://huggingface.co/aabyzov/guitar-robot-classifier)
    **Dataset:** [aabyzov/guitar-robot-realistic-v1](https://huggingface.co/datasets/aabyzov/guitar-robot-realistic-v1)
    """)
    
    with gr.Row():
        # Left column: image upload, trigger button and sample images.
        with gr.Column():
            # type="pil" asks Gradio to hand the upload to the callback as a
            # PIL image.
            input_image = gr.Image(
                label="Upload Image",
                type="pil",
                elem_id="input-image"
            )
            
            predict_btn = gr.Button("Analyze Pattern", variant="primary")
            
            # Sample images that fill input_image when selected.
            # NOTE(review): paths are relative - assumes an examples/
            # directory with these files ships next to this script; confirm.
            gr.Examples(
                examples=[
                    ["examples/rock_example.jpg"],
                    ["examples/folk_example.jpg"],
                    ["examples/ballad_example.jpg"]
                ],
                inputs=input_image,
                label="Example Images"
            )
        
        # Right column: the three outputs of predict_pattern, in order.
        with gr.Column():
            output_text = gr.Markdown(label="Analysis Results")
            
            # x / y name the "Pattern" and "Probability" fields produced by
            # predict_pattern's second return value.
            prob_plot = gr.BarPlot(
                label="Pattern Probabilities",
                x="Pattern",
                y="Probability",
                vertical=False,
                height=200
            )
            
            action_plot = gr.Plot(label="Robot Motion Preview")
    
    # Clicking the button runs predict_pattern on the current image and routes
    # its (text, probabilities, figure) tuple to the three output components.
    predict_btn.click(
        fn=predict_pattern,
        inputs=input_image,
        outputs=[output_text, prob_plot, action_plot]
    )
    
    gr.Markdown("""
    ## How it works:
    1. The model analyzes the image to detect guitar and robot positions
    2. It classifies the appropriate strumming pattern (rock, folk, or ballad)
    3. Robot control parameters are generated based on the pattern
    
    ## Integration:
    ```python
    from transformers import AutoModelForImageClassification, AutoImageProcessor
    
    model = AutoModelForImageClassification.from_pretrained("aabyzov/guitar-robot-classifier")
    processor = AutoImageProcessor.from_pretrained("aabyzov/guitar-robot-classifier")
    ```
    
    Built for LeRobot Hackathon 2024 🤖
    """)

# Launch the Gradio app only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == "__main__":
    demo.launch()