Spaces: Build error
Bomme committed · Commit 6183d1a
1 Parent(s): d78a998
init from public repo
Browse files
- .gitattributes +4 -0
- app.py +342 -0
- assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3 +3 -0
- assets/esp_favicon.png +0 -0
- assets/nri-GreenTreeFrogEvergladesNP.mp3 +3 -0
- assets/yell-YELLAMRO20160506SM3.mp3 +3 -0
- configs/inference.yml +61 -0
- requirements.txt +1 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assests/*.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/nri-GreenTreeFrogEvergladesNP.mp3 filter=lfs diff=lfs merge=lfs -text
+assets/yell-YELLAMRO20160506SM3.mp3 filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,342 @@
import re
import tempfile
from collections import Counter
from pathlib import Path
from typing import Literal

import gradio as gr
import torch

from NatureLM.config import Config
from NatureLM.models.NatureLM import NatureLM
from NatureLM.utils import generate_sample_batches, prepare_sample_waveforms

CONFIG: Config = None
MODEL: NatureLM = None


def prompt_lm(audios: list[str], messages: list[dict[str, str]]):
    cuda_enabled = torch.cuda.is_available()
    samples = prepare_sample_waveforms(audios, cuda_enabled)
    prompt_text = MODEL.llama_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    ).removeprefix(MODEL.llama_tokenizer.bos_token)

    prompt_text = re.sub(
        r"<\|start_header_id\|>system<\|end_header_id\|>\n\nCutting Knowledge Date: [^\n]+\nToday Date: [^\n]+\n\n<\|eot_id\|>",
        "",
        prompt_text,
    )  # exclude the system header from the prompt
    prompt_text = re.sub("\\n", r"\\n", prompt_text)  # FIXME this is a hack to fix the issue #34

    print(f"{prompt_text=}")
    with torch.cuda.amp.autocast(dtype=torch.float16):
        llm_answer = MODEL.generate(samples, CONFIG.generate, prompts=[prompt_text])
    return llm_answer[0]
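# prompt_lm expects Llama-3-style chat messages in which each file in `audios`
# is referenced by an "<Audio><AudioHere></Audio> " placeholder in the message
# content, e.g.
#   prompt_lm(["clip.mp3"],
#             [{"role": "user", "content": "<Audio><AudioHere></Audio> Caption the audio."}])
# MODEL and CONFIG are module globals that main() must populate before any call.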


def _multimodal_textbox_factory():
    return gr.MultimodalTextbox(
        value=None,
        interactive=True,
        file_count="multiple",
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn="Add input",
        file_types=["audio"],
    )


def user_message(content):
    return {"role": "user", "content": content}


def add_message(history, message):
    for x in message["files"]:
        history.append(user_message({"path": x}))
    if message["text"]:
        history.append(user_message(message["text"]))
    return history, _multimodal_textbox_factory()


def combine_model_inputs(msgs: list[dict[str, str]]) -> dict[str, list[str]]:
    messages = []
    files = []
    for msg in msgs:
        print(msg, messages, files)
        match msg:
            case {"content": (path,)}:
                messages.append({"role": msg["role"], "content": "<Audio><AudioHere></Audio> "})
                files.append(path)
            case _:
                messages.append(msg)
    joined_messages = []
    # join consecutive messages from the same role
    for msg in messages:
        if joined_messages and joined_messages[-1]["role"] == msg["role"]:
            joined_messages[-1]["content"] += msg["content"]
        else:
            joined_messages.append(msg)

    return {"messages": joined_messages, "files": files}
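# combine_model_inputs splits uploaded files out of the chat history. With file
# content stored as a one-element tuple (as the Chatbot history does), e.g.
#   combine_model_inputs([
#       {"role": "user", "content": ("clip.mp3",)},
#       {"role": "user", "content": "Caption the audio."},
#   ])
# returns
#   {"messages": [{"role": "user", "content": "<Audio><AudioHere></Audio> Caption the audio."}],
#    "files": ["clip.mp3"]}
# because file entries match `case {"content": (path,)}` and consecutive
# same-role messages are then joined into one.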


def bot_response(history: list):
    print(type(history))
    combined_inputs = combine_model_inputs(history)
    response = prompt_lm(combined_inputs["files"], combined_inputs["messages"])
    history.append({"role": "assistant", "content": response})

    return history


def _chat_tab(examples):
    chatbot = gr.Chatbot(
        label="Model inputs",
        elem_id="chatbot",
        bubble_full_width=False,
        type="messages",
        render_markdown=False,
        # editable="user", # disable because of https://github.com/gradio-app/gradio/issues/10320
        resizeable=True,
    )

    chat_input = _multimodal_textbox_factory()
    send_all = gr.Button("Send all", elem_id="send-all")
    clear_button = gr.ClearButton(components=[chatbot, chat_input], visible=False)

    chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
    bot_msg = send_all.click(
        bot_response,
        [chatbot],
        [chatbot],
        api_name="bot_response",
    )

    bot_msg.then(lambda: gr.ClearButton(visible=True), None, [clear_button])
    clear_button.click(lambda: gr.ClearButton(visible=False), None, [clear_button])

    gr.Examples(
        list(examples.values()),
        chatbot,
        chatbot,
        example_labels=list(examples.keys()),
        examples_per_page=20,
    )


def summarize_batch_results(results):
    summary = Counter(results)
    summary_str = "\n".join(f"{k}: {v}" for k, v in summary.most_common())
    return summary_str


def run_batch_inference(files, task, progress=gr.Progress()) -> str:
    outputs = []
    prompt = "<Audio><AudioHere></Audio> " + task

    for file in progress.tqdm(files):
        outputs.append(prompt_lm([file], [{"role": "user", "content": prompt}]))

    batch_summary: str = summarize_batch_results(outputs)
    report = f"Batch summary:\n{batch_summary}\n\n"
    return report


def multi_extension_glob_mask(mask_base, *extensions):
    mask_ext = ["[{}]".format("".join(set(c))) for c in zip(*extensions)]
    if not mask_ext or len(set(len(e) for e in extensions)) > 1:
        mask_ext.append("*")
    return mask_base + "".join(mask_ext)
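# Worked example: multi_extension_glob_mask("**.", "mp3", "flac", "wav") zips
# the extensions character-wise into character classes and, because the
# extensions differ in length, appends a trailing "*":
#   "**.[mfw][pla][3av]*"  (order inside each [] may vary, since set() is unordered)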


def _batch_tab(file_selection: Literal["upload", "explorer"] = "upload"):
    if file_selection == "explorer":
        files = gr.FileExplorer(
            glob=multi_extension_glob_mask("**.", "mp3", "flac", "wav"),
            label="Select audio files",
            file_count="multiple",
        )
    elif file_selection == "upload":
        files = gr.Files(label="Uploaded files", file_types=["audio"], height=300)
    task = gr.Textbox(label="Task", placeholder="Enter task...", show_label=True)

    process_btn = gr.Button("Process")
    output = gr.TextArea()

    process_btn.click(
        run_batch_inference,
        [files, task],
        [output],
    )


def to_raven_format(outputs: dict[int, str], chunk_len: int = 10) -> str:
    def get_line(row, start, end, annotation):
        return f"{row}\tSpectrogram 1\t1\t{start}\t{end}\t0\t8000\t{annotation}"

    raven_output = ["Selection\tView\tChannel\tBegin Time (s)\tEnd Time (s)\tLow Freq (Hz)\tHigh Freq (Hz)\tAnnotation"]
    current_offset = 0
    last_label = ""
    row = 1

    # The "Selection" column is just the row number.
    # The "view" column will always say "Spectrogram 1".
    # Channel can always be "1".
    # For the frequency bounds we can just use 0 and 1/2 the sample rate
    for offset, label in sorted(outputs.items()):
        if label != last_label and last_label:
            raven_output.append(get_line(row, current_offset, offset, last_label))
            current_offset = offset
            row += 1
        if not last_label:
            current_offset = offset
        if label != "None":
            last_label = label
        else:
            last_label = ""
    if last_label:
        raven_output.append(get_line(row, current_offset, current_offset + chunk_len, last_label))

    return "\n".join(raven_output)
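# Worked example: to_raven_format({0: "None", 5: "robin", 10: "robin", 15: "None"})
# merges the consecutive "robin" chunks into one selection and emits nothing for
# "None" chunks, producing the tab-separated header row followed by
#   "1\tSpectrogram 1\t1\t5\t15\t0\t8000\trobin"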


def _run_long_recording_inference(file, task, chunk_len: int = 10, hop_len: int = 5, progress=gr.Progress()):
    cuda_enabled = torch.cuda.is_available()
    outputs = {}
    offset = 0

    prompt = f"<Audio><AudioHere></Audio> {task}"
    prompt = MODEL.prompt_template.format(prompt)

    for batch in progress.tqdm(generate_sample_batches(file, cuda_enabled, chunk_len=chunk_len, hop_len=hop_len)):
        prompt_strs = [prompt] * len(batch["audio_chunk_sizes"])
        with torch.cuda.amp.autocast(dtype=torch.float16):
            llm_answers = MODEL.generate(batch, CONFIG.generate, prompts=prompt_strs)
        for answer in llm_answers:
            outputs[offset] = answer
            offset += hop_len

    report = f"Number of chunks: {len(outputs)}\n\n"
    for offset in sorted(outputs.keys()):
        report += f"{offset:02d}s:\t{outputs[offset]}\n"

    raven_output = to_raven_format(outputs, chunk_len=chunk_len)
    with tempfile.NamedTemporaryFile(mode="w", prefix="raven-", suffix=".txt", delete=False) as f:
        f.write(raven_output)
        raven_file = f.name

    return report, raven_file
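# Note: offsets advance by hop_len per answer, so with the defaults
# (chunk_len=10, hop_len=5) consecutive 10 s chunks overlap by 5 s, and each
# answer is keyed by its chunk's start time; to_raven_format then merges runs
# of identical labels into single selections.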


def _long_recording_tab():
    audio_input = gr.Audio(label="Upload audio file", type="filepath")
    task = gr.Dropdown(
        [
            "What are the common names for the species in the audio, if any?",
            "Caption the audio.",
            "Caption the audio, using the scientific name for any animal species.",
            "Caption the audio, using the common name for any animal species.",
            "What is the scientific name for the focal species in the audio?",
            "What is the common name for the focal species in the audio?",
            "What is the family of the focal species in the audio?",
            "What is the genus of the focal species in the audio?",
            "What is the taxonomic name of the focal species in the audio?",
            "What call types are heard from the focal species in the audio?",
            "What is the life stage of the focal species in the audio?",
        ],
        label="Tasks",
        allow_custom_value=True,
    )
    with gr.Accordion("Advanced options", open=False):
        hop_len = gr.Slider(1, 10, 5, label="Hop length (seconds)", step=1)
        chunk_len = gr.Slider(1, 10, 10, label="Chunk length (seconds)", step=1)
    process_btn = gr.Button("Process")
    output = gr.TextArea()
    download_raven = gr.DownloadButton("Download Raven file")

    process_btn.click(
        _run_long_recording_inference,
        [audio_input, task, chunk_len, hop_len],
        [output, download_raven],
    )


def main(
    assets_dir: Path,
    cfg_path: str | Path,
    options: list[str] = [],
    device: str = "cuda",
):
    cfg = Config.from_sources(yaml_file=cfg_path, cli_args=options)
    model = NatureLM.from_pretrained("EarthSpeciesProject/NatureLM-audio")
    model.to(device)
    model.eval()

    global MODEL, CONFIG
    MODEL = model
    CONFIG = cfg

    laz_audio = assets_dir / "Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3"
    frog_audio = assets_dir / "nri-GreenTreeFrogEvergladesNP.mp3"
    robin_audio = assets_dir / "yell-YELLAMRO20160506SM3.mp3"

    examples = {
        "Caption the audio (Lazuli Bunting)": [
            [
                user_message({"path": str(laz_audio)}),
                user_message("Caption the audio."),
            ]
        ],
        "Caption the audio (Green Tree Frog)": [
            [
                user_message({"path": str(frog_audio)}),
                user_message("Caption the audio, using the common name for any animal species."),
            ]
        ],
        "Caption the audio (American Robin)": [
            [
                user_message({"path": str(robin_audio)}),
                user_message("Caption the audio."),
            ]
        ],
    }

    with gr.Blocks(title="NatureLM-audio", theme=gr.themes.Default(primary_hue="slate")) as app:
        with gr.Tabs():
            with gr.Tab("Chat"):
                _chat_tab(examples)
            with gr.Tab("Batch"):
                _batch_tab()
            with gr.Tab("Long Recording"):
                _long_recording_tab()

    app.launch(
        favicon_path=str(assets_dir / "esp_favicon.png"),
    )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="NatureLM-audio Gradio app")
    parser.add_argument(
        "--assets-dir",
        type=Path,
        default=Path(__file__).parent / "assets",
        help="Directory containing the assets (favicon, examples, etc.)",
    )
    parser.add_argument(
        "--cfg-path",
        type=str,
        default=Path(__file__).parent / "configs/inference.yml",
        help="Path to the config file",
    )
    parser.add_argument(
        "--options",
        nargs="*",
        default=[],
        help="Additional options to pass to the config file",
    )
    args = parser.parse_args()

    main(args.assets_dir, args.cfg_path, args.options)
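Usage note: the argparse defaults mean the Space launches with no arguments (`python app.py`). A minimal programmatic sketch of the same launch, assuming the NatureLM-audio package and model weights are available and the script is run from the repository root:

    from pathlib import Path

    from app import main

    # Equivalent to `python app.py` with the default arguments.
    main(
        assets_dir=Path("assets"),
        cfg_path="configs/inference.yml",
        options=[],
        device="cuda",  # generation runs under CUDA autocast, so a GPU is expected
    )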
assets/Lazuli_Bunting_yell-YELLLAZB20160625SM303143.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a67960286021e58ffab2d3e4b67b7e20d08b530018c64c6afefe4aae5ff28be
size 316920
assets/esp_favicon.png
ADDED
assets/nri-GreenTreeFrogEvergladesNP.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3004b02bd1793db81f5e6ddfe2f805dbd587af3c0d03edbedec2ad23e92660dd
size 162234
assets/yell-YELLAMRO20160506SM3.mp3
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a2700bbe2233505ccf592e9e06a4b196a0feb4d2d7a4773ed5f2f110696a001
size 598352
configs/inference.yml
ADDED
@@ -0,0 +1,61 @@
model:
  llama_path: "meta-llama/Meta-Llama-3.1-8B-Instruct"

  freeze_beats: True

  use_audio_Qformer: True
  max_pooling: False
  downsample_factor: 8
  freeze_audio_QFormer: False
  window_level_Qformer: True
  num_audio_query_token: 1
  second_per_window: 0.333333
  second_stride: 0.333333

  audio_llama_proj_model: ""
  freeze_audio_llama_proj: False

  lora: True
  lora_rank: 32
  lora_alpha: 32
  lora_dropout: 0.1

  prompt_template: "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
  max_txt_len: 160
  end_sym: <|end_of_text|>

  beats_cfg:
    input_patch_size: 16
    embed_dim: 512
    conv_bias: False
    encoder_layers: 12
    encoder_embed_dim: 768
    encoder_ffn_embed_dim: 3072
    encoder_attention_heads: 12
    activation_fn: "gelu"
    layer_wise_gradient_decay_ratio: 0.6
    layer_norm_first: False
    deep_norm: True
    dropout: 0.0
    attention_dropout: 0.0
    activation_dropout: 0.0
    encoder_layerdrop: 0.05
    dropout_input: 0.0
    conv_pos: 128
    conv_pos_groups: 16
    relative_position_embedding: True
    num_buckets: 320
    max_distance: 800
    gru_rel_pos: True
    finetuned_model: True
    predictor_dropout: 0.0
    predictor_class: 527

generate:
  max_new_tokens: 300
  num_beams: 2
  do_sample: False
  min_length: 1
  temperature: 0.1
  repetition_penalty: 1.0
  length_penalty: 1.0
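The app's `--options` flag is forwarded to `Config.from_sources(yaml_file=..., cli_args=...)`, so values in this file can in principle be overridden at launch without editing the YAML. A minimal sketch; the dotted `key=value` override syntax is an assumption about `NatureLM.config.Config`, not something this commit confirms:

    from NatureLM.config import Config

    # Hypothetical overrides; the exact cli_args grammar is assumed, not verified.
    cfg = Config.from_sources(
        yaml_file="configs/inference.yml",
        cli_args=["generate.num_beams=4", "generate.max_new_tokens=150"],
    )
    print(cfg.generate)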
requirements.txt
ADDED
@@ -0,0 +1 @@
git+https://github.com/Bomme/NatureLM-audio.git