Bomme committed
Commit 6183d1a · Parent: d78a998

init from public repo

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assests/*.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/nri-GreenTreeFrogEvergladesNP.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/yell-YELLAMRO20160506SM3.mp3 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,342 @@
+import re
+import tempfile
+from collections import Counter
+from pathlib import Path
+from typing import Literal
+
+import gradio as gr
+import torch
+
+from NatureLM.config import Config
+from NatureLM.models.NatureLM import NatureLM
+from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms
+
+CONFIG: Config = None
+MODEL: NatureLM = None
+
+
+def prompt_lm(audios: list[str], messages: list[dict[str, str]]):
+    cuda_enabled = torch.cuda.is_available()
+    samples = prepare_sample_waveforms(audios, cuda_enabled)
+    prompt_text = MODEL.llama_tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    ).removeprefix(MODEL.llama_tokenizer.bos_token)
+
+    prompt_text = re.sub(
+        r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
+        "",
+        prompt_text,
+    )  # exclude the system header from the prompt
+    prompt_text = re.sub("\\n", r"\\n", prompt_text)  # FIXME this is a hack to fix the issue #34
+
+    print(f"{prompt_text=}")
+    with torch.cuda.amp.autocast(dtype=torch.float16):
+        llm_answer = MODEL.generate(samples, CONFIG.generate, prompts=[prompt_text])
+    return llm_answer[0]
+
+
+def _multimodal_textbox_factory():
+    return gr.MultimodalTextbox(
+        value=None,
+        interactive=True,
+        file_count="multiple",
+        placeholder="Enter message or upload file...",
+        show_label=False,
+        submit_btn="Add input",
+        file_types=["audio"],
+    )
+
+
+def user_message(content):
+    return {"role": "user", "content": content}
+
+
+def add_message(history, message):
+    for x in message["files"]:
+        history.append(user_message({"path": x}))
+    if message["text"]:
+        history.append(user_message(message["text"]))
+    return history, _multimodal_textbox_factory()
+
+
+def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
+    messages = []
+    files = []
+    for msg in msgs:
+        print(msg, messages, files)
+        match msg:
+            case {"content": (path,)}:
+                messages.append({"role": msg["role"], "content": "<Audio><AudioHere></Audio> "})
+                files.append(path)
+            case _:
+                messages.append(msg)
+    joined_messages = []
+    # join consecutive messages from the same role
+    for msg in messages:
+        if joined_messages and joined_messages[-1]["role"] == msg["role"]:
+            joined_messages[-1]["content"] += msg["content"]
+        else:
+            joined_messages.append(msg)
+
+    return {"messages": joined_messages, "files": files}
+
+
+def bot_response(history: list):
+    print(type(history))
+    combined_inputs = combine_model_inputs(history)
+    response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
+    history.append({"role": "assistant", "content": response})
+
+    return history
+
+
+def _chat_tab(examples):
+    chatbot = gr.Chatbot(
+        label="Model inputs",
+        elem_id="chatbot",
+        bubble_full_width=False,
+        type="messages",
+        render_markdown=False,
+        # editable="user",  # disable because of https://github.com/gradio-app/gradio/issues/10320
+        resizeable=True,
+    )
+
+    chat_input = _multimodal_textbox_factory()
+    send_all = gr.Button("Send all", elem_id="send-all")
+    clear_button = gr.ClearButton(components=[chatbot, chat_input], visible=False)
+
+    chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
+    bot_msg = send_all.click(
+        bot_response,
+        [chatbot],
+        [chatbot],
+        api_name="bot_response",
+    )
+
+    bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
+    clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])
+
+    gr.Examples(
+        list(examples.values()),
+        chatbot,
+        chatbot,
+        example_labels=list(examples.keys()),
+        examples_per_page=20,
+    )
+
+
+def summarize_batch_results(results):
+    summary = Counter(results)
+    summary_str = "\n".join(f"{k}: {v}" for k, v in summary.most_common())
+    return summary_str
+
+
+def run_batch_inference(files, task, progress=gr.Progress()) -> str:
+    outputs = []
+    prompt = "<Audio><AudioHere></Audio> " + task
+
+    for file in progress.tqdm(files):
+        outputs.append(prompt_lm([file], [{"role": "user", "content": prompt}]))
+
+    batch_summary: str = summarize_batch_results(outputs)
+    report = f"Batch summary:\n{batch_summary}\n\n"
+    return report
+
+
+def multi_extension_glob_mask(mask_base, *extensions):
+    mask_ext = ["[{}]".format("".join(set(c))) for c in zip(*extensions)]
+    if not mask_ext or len(set(len(e) for e in extensions)) > 1:
+        mask_ext.append("*")
+    return mask_base + "".join(mask_ext)
+
+
+def _batch_tab(file_selection: Literal["upload", "explorer"] = "upload"):
+    if file_selection == "explorer":
+        files = gr.FileExplorer(
+            glob=multi_extension_glob_mask("**.", "mp3", "flac", "wav"),
+            label="Select audio files",
+            file_count="multiple",
+        )
+    elif file_selection == "upload":
+        files = gr.Files(label="Uploaded files", file_types=["audio"], height=300)
+    task = gr.Textbox(label="Task", placeholder="Enter task...", show_label=True)
+
+    process_btn = gr.Button("Process")
+    output = gr.TextArea()
+
+    process_btn.click(
+        run_batch_inference,
+        [files, task],
+        [output],
+    )
+
+
+def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
+    def get_line(row, start, end, annotation):
+        return f"{row}\tSpectrogram 1\t1\t{start}\t{end}\t0\t8000\t{annotation}"
+
+    raven_output = ["Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tAnnotation"]
+    current_offset = 0
+    last_label = ""
+    row = 1
+
+    # The "Selection" column is just the row number.
+    # The "view" column will always say "Spectrogram 1".
+    # Channel can always be "1".
+    # For the frequency bounds we can just use 0 and 1/2 the sample rate
+    for offset, label in sorted(outputs.items()):
+        if label != last_label and last_label:
+            raven_output.append(get_line(row, current_offset, offset, last_label))
+            current_offset = offset
+            row += 1
+        if not last_label:
+            current_offset = offset
+        if label != "None":
+            last_label = label
+        else:
+            last_label = ""
+    if last_label:
+        raven_output.append(get_line(row, current_offset, current_offset + chunk_len, last_label))
+
+    return "\n".join(raven_output)
+
+
+def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
+    cuda_enabled = torch.cuda.is_available()
+    outputs = {}
+    offset = 0
+
+    prompt = f"<Audio><AudioHere></Audio> {task}"
+    prompt = MODEL.prompt_template.format(prompt)
+
+    for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
+        prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
+        with torch.cuda.amp.autocast(dtype=torch.float16):
+            llm_answers = MODEL.generate(batch, CONFIG.generate, prompts=prompt_strs)
+        for answer in llm_answers:
+            outputs[offset] = answer
+            offset += hop_len
+
+    report = f"Number of chunks: {len(outputs)}\n\n"
+    for offset in sorted(outputs.keys()):
+        report += f"{offset:02d}s:\t{outputs[offset]}\n"
+
+    raven_output = to_raven_format(outputs, chunk_len=chunk_len)
+    with tempfile.NamedTemporaryFile(mode="w", prefix="raven-", suffix=".txt", delete=False) as f:
+        f.write(raven_output)
+        raven_file = f.name
+
+    return report, raven_file
+
+
+def _long_recording_tab():
+    audio_input = gr.Audio(label="Upload audio file", type="filepath")
+    task = gr.Dropdown(
+        [
+            "What are the common names for the species in the audio, if any?",
+            "Caption the audio.",
+            "Caption the audio, using the scientific name for any animal species.",
+            "Caption the audio, using the common name for any animal species.",
+            "What is the scientific name for the focal species in the audio?",
+            "What is the common name for the focal species in the audio?",
+            "What is the family of the focal species in the audio?",
+            "What is the genus of the focal species in the audio?",
+            "What is the taxonomic name of the focal species in the audio?",
+            "What call types are heard from the focal species in the audio?",
+            "What is the life stage of the focal species in the audio?",
+        ],
+        label="Tasks",
+        allow_custom_value=True,
+    )
+    with gr.Accordion("Advanced options", open=False):
+        hop_len = gr.Slider(1, 10, 5, label="Hop length (seconds)", step=1)
+        chunk_len = gr.Slider(1, 10, 10, label="Chunk length (seconds)", step=1)
+    process_btn = gr.Button("Process")
+    output = gr.TextArea()
+    download_raven = gr.DownloadButton("Download Raven file")
+
+    process_btn.click(
+        _run_long_recording_inference,
+        [audio_input, task, chunk_len, hop_len],
+        [output, download_raven],
+    )
+
+
+def main(
+    assets_dir: Path,
+    cfg_path: str | Path,
+    options: list[str] = [],
+    device: str = "cuda",
+):
+    cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
+    model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
+    model.to(device)
+    model.eval()
+
+    global MODEL, CONFIG
+    MODEL = model
+    CONFIG = cfg
+
+    laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
+    frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
+    robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"
+
+    examples = {
+        "Caption the audio (Lazuli Bunting)": [
+            [
+                user_message({"path": str(laz_audio)}),
+                user_message("Caption the audio."),
+            ]
+        ],
+        "Caption the audio (Green Tree Frog)": [
+            [
+                user_message({"path": str(frog_audio)}),
+                user_message("Caption the audio, using the common name for any animal species."),
+            ]
+        ],
+        "Caption the audio (American Robin)": [
+            [
+                user_message({"path": str(robin_audio)}),
+                user_message("Caption the audio."),
+            ]
+        ],
+    }
+
+    with gr.Blocks(title="NatureLM-audio", theme=gr.themes.Default(primary_hue="slate")) as app:
+        with gr.Tabs():
+            with gr.Tab("Chat"):
+                _chat_tab(examples)
+            with gr.Tab("Batch"):
+                _batch_tab()
+            with gr.Tab("Long Recording"):
+                _long_recording_tab()
+
+    app.launch(
+        favicon_path=str(assets_dir / "esp_favicon.png"),
+    )
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="NatureLM-audio Gradio app")
+    parser.add_argument(
+        "--assets-dir",
+        type=Path,
+        default=Path(__file__).parent / "assets",
+        help="Directory containing the assets (favicon, examples, etc.)",
+    )
+    parser.add_argument(
+        "--cfg-path",
+        type=str,
+        default=Path(__file__).parent / "configs/inference.yml",
+        help="Path to the config file",
+    )
+    parser.add_argument(
+        "--options",
+        nargs="*",
+        default=[],
+        help="Additional options to pass to the config file",
+    )
+    args = parser.parse_args()
+
+    main(args.assets_dir, args.cfg_path, args.options)
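
Note (not part of the commit): the __main__ block above exposes the same entry point on the command line, with --assets-dir, --cfg-path, and --options flags defaulting to the paths added in this commit. A minimal programmatic-launch sketch, assuming that layout (assets/ and configs/inference.yml next to app.py) and an available CUDA device:

    # Sketch only: launch the Gradio app from Python with this commit's defaults.
    # Assumes assets/ and configs/inference.yml sit next to app.py and that a CUDA
    # device is available for NatureLM.from_pretrained(...).to("cuda").
    from pathlib import Path

    from app import main

    main(
        assets_dir=Path("assets"),               # example audio clips + esp_favicon.png
        cfg_path=Path("configs/inference.yml"),  # model and generation settings (see below)
        options=[],                              # extra overrides forwarded to Config.from_sources
        device="cuda",                           # the model is moved here before inference
    )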
assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a67960286021e58ffab2d3e4b67b7e20d08b530018c64c6afefe4aae5ff28be
+size 316920
assets/esp_favicon.png ADDED
assets/nri-GreenTreeFrogEvergladesNP.mp3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3004b02bd1793db81f5e6ddfe2f805dbd587af3c0d03edbedec2ad23e92660dd
+size 162234
assets/yell-YELLAMRO20160506SM3.mp3 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a2700bbe2233505ccf592e9e06a4b196a0feb4d2d7a4773ed5f2f110696a001
+size 598352
configs/inference.yml ADDED
@@ -0,0 +1,61 @@
+model:
+  llama_path: "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+  freeze_beats: True
+
+  use_audio_Qformer: True
+  max_pooling: False
+  downsample_factor: 8
+  freeze_audio_QFormer: False
+  window_level_Qformer: True
+  num_audio_query_token: 1
+  second_per_window: 0.333333
+  second_stride: 0.333333
+
+  audio_llama_proj_model: ""
+  freeze_audio_llama_proj: False
+
+  lora: True
+  lora_rank: 32
+  lora_alpha: 32
+  lora_dropout: 0.1
+
+  prompt_template: "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+  max_txt_len: 160
+  end_sym: <|end_of_text|>
+
+  beats_cfg:
+    input_patch_size: 16
+    embed_dim: 512
+    conv_bias: False
+    encoder_layers: 12
+    encoder_embed_dim: 768
+    encoder_ffn_embed_dim: 3072
+    encoder_attention_heads: 12
+    activation_fn: "gelu"
+    layer_wise_gradient_decay_ratio: 0.6
+    layer_norm_first: False
+    deep_norm: True
+    dropout: 0.0
+    attention_dropout: 0.0
+    activation_dropout: 0.0
+    encoder_layerdrop: 0.05
+    dropout_input: 0.0
+    conv_pos: 128
+    conv_pos_groups: 16
+    relative_position_embedding: True
+    num_buckets: 320
+    max_distance: 800
+    gru_rel_pos: True
+    finetuned_model: True
+    predictor_dropout: 0.0
+    predictor_class: 527
+
+generate:
+  max_new_tokens: 300
+  num_beams: 2
+  do_sample: False
+  min_length: 1
+  temperature: 0.1
+  repetition_penalty: 1.0
+  length_penalty: 1.0
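
For reference (not part of the commit): the long-recording path in app.py wraps each chunk's task via MODEL.prompt_template.format(...), which appears to correspond to the prompt_template field above. An illustrative snippet, under that assumption:

    # Illustration only: formatting one chunk's prompt with the template above.
    # Assumes MODEL.prompt_template is populated from this config's prompt_template.
    prompt_template = (
        "<|start_header_id|>user<|end_header_id|>\n\n{}"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    task = "What is the common name for the focal species in the audio?"
    prompt = prompt_template.format(f"<Audio><AudioHere></Audio> {task}")
    print(prompt)  # a single Llama 3.1 user turn containing the audio placeholder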
requirements.txt ADDED
@@ -0,0 +1 @@
+git+https://github.com/Bomme/NatureLM-audio.git