Update app.py
app.py CHANGED
@@ -17,7 +17,7 @@ model_id = "microsoft/Phi-3-mini-4k-instruct"
 client = InferenceClient(model_id)
 save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")

-MAX_TOTAL_NB_ITEMS = 100
+MAX_TOTAL_NB_ITEMS = 100
 MAX_NB_ITEMS_PER_GENERATION_CALL = 10
 NUM_ROWS = 100
 NUM_VARIANTS = 10
@@ -25,7 +25,7 @@ NAMESPACE = "infinite-dataset-hub"
 URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"

 GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
-    "A Machine Learning
+    "A Machine Learning Practitioner is looking for a dataset that matches '{search_query}'. "
     f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
     "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
     "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
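As an aside, the numbered "1. DatasetName (tag1, tag2, tag3)" format requested by this prompt is what the app's search handlers later split apart. A minimal, self-contained sketch of that parsing; the sample line is illustrative, not from the app:

# Illustrative parse of one line in the requested format; mirrors the
# split logic used in this file's search handlers. The sample line is made up.
line = "1. QuantumTextClassify (classification, NLP, quantum)"
if line.strip() and line.strip().split(".", 1)[0].isnumeric():
    body = line.strip().split(".", 1)[1].strip(" )")
    try:
        dataset_name, tags = body.split(" (", 1)
    except ValueError:
        dataset_name, tags = body.split(" ", 1)
    dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
    print(dataset_name)  # QuantumTextClassify
    print(tags)          # classification, NLP, quantum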
@@ -39,8 +39,6 @@ GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
     "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
 )

-
-
 GENERATE_MORE_ROWS = "Can you give me 10 additional samples in CSV format as well? Use the same CSV header '{csv_header}'."
 GENERATE_VARIANTS_WITH_RARITY_AND_LABEL = "Focus on generating samples for the label '{label}' and ideally generate {rarity} samples."
 GENERATE_VARIANTS_WITH_RARITY = "Focus on generating {rarity} samples."
@@ -71,8 +69,9 @@ landing_page_datasets_generated_text = """
 9. HealthVitalSigns (anomaly detection, biometrics, prediction)
 10. GameStockPredict (classification, finance, sports contingency)
 """
+
 default_output = landing_page_datasets_generated_text.strip().split("\n")
-assert default_output
+assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL

 DATASET_CARD_CONTENT = """
 ---
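For illustration, what the strengthened assertion checks: default_output is the landing-page text split into lines, and its length must equal the per-call batch size. A tiny self-contained sketch with stand-in content:

# Stand-in landing-page text with exactly 10 numbered lines, matching
# MAX_NB_ITEMS_PER_GENERATION_CALL; the real text lists 10 dataset names.
MAX_NB_ITEMS_PER_GENERATION_CALL = 10
landing_page_datasets_generated_text = "\n".join(
    f"{i}. Dataset{i} (tag1, tag2)" for i in range(1, 11)
)
default_output = landing_page_datasets_generated_text.strip().split("\n")
assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL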
@@ -172,31 +171,28 @@ a {
 """

 with gr.Blocks(css=css) as demo:
-    # Initialize state
     generated_texts_state = gr.State((landing_page_datasets_generated_text,))
-
+
     with gr.Column() as search_page:
         with gr.Row():
             with gr.Column(scale=10):
                 gr.Markdown(
-                    "# 🤗
+                    "# 🤗 Infinite Dataset Hub ♾️\n\n"
                     "An endless catalog of datasets, created just for you by an AI model.\n\n"
                 )
                 with gr.Row():
                     search_bar = gr.Textbox(
-                        max_lines=1,
-                        placeholder="Search datasets, get infinite results",
-                        show_label=False,
-                        container=False,
+                        max_lines=1,
+                        placeholder="Search datasets, get infinite results",
+                        show_label=False,
+                        container=False,
                         scale=9
                     )
                     search_button = gr.Button("🔍", variant="primary", scale=1)
-
-                # Initialize button groups and buttons
+
                 button_groups: list[gr.Group] = []
                 buttons: list[gr.Button] = []
-
-                # You'll need to define default_output before this loop
+
                 for i in range(MAX_TOTAL_NB_ITEMS):
                     if i < len(default_output):
                         line = default_output[i]
@@ -209,35 +205,32 @@ with gr.Blocks(css=css) as demo:
                         group_classes = "buttonsGroup insivibleButtonGroup"
                         dataset_name_classes = "topButton linear-background"
                         tags_classes = "bottomButton linear-background"
-
+
                     with gr.Group(elem_classes=group_classes) as button_group:
                         button_groups.append(button_group)
                         buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
                         buttons.append(gr.Button(tags, elem_classes=tags_classes))

-                load_more_datasets = gr.Button("Load more datasets")
+                load_more_datasets = gr.Button("Load more datasets")
                 gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
-
-
-
-
-
-
-
-
-
-
-
-                        visible=False
-                    )
+
+            with gr.Column(scale=4, min_width="200px"):
+                with gr.Accordion("Settings", open=False, elem_classes="settings"):
+                    gr.Markdown("Save datasets to your account")
+                    gr.LoginButton()
+                    select_namespace_dropdown = gr.Dropdown(
+                        choices=[NAMESPACE],
+                        value=NAMESPACE,
+                        label="Select user or organization",
+                        visible=False
+                    )
                     gr.Markdown("Save datasets as public or private datasets")
                     visibility_radio = gr.Radio(
-                        ["public", "private"],
-                        value="public",
-                        container=False,
+                        ["public", "private"],
+                        value="public",
+                        container=False,
                         interactive=False
                     )
-
     with gr.Column(visible=False) as dataset_page:
         gr.Markdown(
             "# 🤗 Infinite Dataset Hub ♾️\n\n"
@@ -254,527 +247,6 @@ with gr.Blocks(css=css) as demo:
         dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
         back_button = gr.Button("< Back", size="sm")

-
-
-    ###################################
-    #
-    # Utils
-    #
-    ###################################
-
-    T = TypeVar("T")
-
-    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
-        """Batch iterator into chunks of size n."""
-        it = iter(it)
-        while batch := list(islice(it, n)):
-            yield batch
-
-    def stream_response(msg: str, generated_texts: tuple[str] = (), max_tokens=500) -> Iterator[str]:
-        """Stream response from chat completion API."""
-        messages = [
-            {"role": "user", "content": msg}
-        ] + [
-            item
-            for generated_text in generated_texts
-            for item in [
-                {"role": "assistant", "content": generated_text},
-                {"role": "user", "content": "Can you generate more?"},
-            ]
-        ]
-
-        for _ in range(3):  # Retry logic
-            try:
-                for message in client.chat_completion(
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    stream=True,
-                    top_p=0.8,
-                    seed=42,
-                ):
-                    yield message.choices[0].delta.content
-                break
-            except requests.exceptions.ConnectionError as e:
-                logger.warning(f"Connection error: {e}\nRetrying in 1sec")
-                time.sleep(1)
-
-    def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
-        """Generate dataset names line by line based on search query."""
-        search_query = (search_query or "")[:1000].strip()
-        generated_text = ""
-        current_line = ""
-
-        for token in stream_response(
-            GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query),
-            generated_texts=generated_texts,
-        ):
-            current_line += token
-            if current_line.endswith("\n"):
-                yield current_line
-                generated_text += current_line
-                current_line = ""
-
-        if current_line:
-            yield current_line
-            generated_text += current_line
-
-        logger.debug(f"Generated text:\n{generated_text}")
-
-    def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
-        """Generate dataset content based on search query, name and tags."""
-        search_query = (search_query or "")[:1000].strip()
-        generated_text = ""
-
-        for token in stream_response(
-            GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
-                search_query=search_query,
-                dataset_name=dataset_name,
-                tags=tags,
-            ),
-            max_tokens=1500
-        ):
-            generated_text += token
-            yield generated_text
-
-        logger.debug(f"Generated content:\n{generated_text}")
-
-    def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
-        """Helper function to write generator output to queue."""
-        try:
-            for result in func(**kwargs):
-                queue.put(result)
-        except Exception as e:
-            logger.error(f"Error in generator: {e}")
-        queue.put(None)
-
-    def iflatmap_unordered(
-        func: Callable[..., Iterable[T]],
-        *,
-        kwargs_iterable: Iterable[dict],
-    ) -> Iterable[T]:
-        """Execute generator function with multiple kwargs in parallel."""
-        queue = Queue()
-        with ThreadPool() as pool:
-            async_results = [
-                pool.apply_async(_write_generator_to_queue, (queue, func, kwargs))
-                for kwargs in kwargs_iterable
-            ]
-            try:
-                while True:
-                    try:
-                        result = queue.get(timeout=0.05)
-                        if result is not None:
-                            yield result
-                    except Empty:
-                        if all(result.ready() for result in async_results) and queue.empty():
-                            break
-            finally:
-                for result in async_results:
-                    try:
-                        result.get(timeout=0.05)
-                    except Exception as e:
-                        logger.error(f"Async result error: {e}")
-
-    def generate_partial_dataset(
-        title: str,
-        content: str,
-        search_query: str,
-        variant: str,
-        csv_header: str,
-        output: list[Dict[str, str]],
-        indices_to_generate: list[int],
-        max_tokens=1500
-    ) -> Iterator[int]:
-        """Generate partial dataset with specific variants."""
-        try:
-            dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-            dataset_name, tags = dataset_name.strip(), tags.strip()
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
-                        dataset_name=dataset_name,
-                        tags=tags,
-                        search_query=search_query,
-                    )
-                },
-                {"role": "assistant", "content": f"{title}\n\n{content}"},
-                {"role": "user", "content": f"{GENERATE_MORE_ROWS.format(csv_header=csv_header)} {variant}"},
-            ]
-
-            for response in _generate_dataset_rows(messages, max_tokens, indices_to_generate, output, csv_header):
-                yield response
-
-        except Exception as e:
-            logger.error(f"Error generating partial dataset: {e}")
-            yield 0
-
-    def _generate_dataset_rows(messages: list, max_tokens: int, indices: list, output: list, csv_header: str) -> Iterator[int]:
-        """Helper function to generate dataset rows."""
-        for _ in range(3):  # Retry logic
-            try:
-                return _process_generation(
-                    messages, max_tokens, indices, output, csv_header
-                )
-            except requests.exceptions.ConnectionError as e:
-                logger.warning(f"Connection error: {e}\nRetrying in 1sec")
-                time.sleep(1)
-        return iter([])
-
-    def generate_variants(preview_df: pd.DataFrame) -> list[str]:
-        """Generate variants based on preview dataframe."""
-        label_candidate_columns = [
-            column for column in preview_df.columns
-            if "label" in column.lower()
-        ]
-
-        if label_candidate_columns:
-            labels = preview_df[label_candidate_columns[0]].unique()
-            if len(labels) > 1:
-                return [
-                    GENERATE_VARIANTS_WITH_RARITY_AND_LABEL.format(
-                        rarity=rarity,
-                        label=label
-                    )
-                    for rarity in RARITIES
-                    for label in labels
-                ]
-
-        return [
-            GENERATE_VARIANTS_WITH_RARITY.format(rarity=rarity)
-            for rarity in LONG_RARITIES
-        ]
-
-    ###################################
-    #
-    # Buttons
-    #
-    ###################################
-
-
-    def _search_datasets(search_query):
-        yield {generated_texts_state: []}
-        yield {
-            button_group: gr.Group(elem_classes="buttonsGroup insivibleButtonGroup")
-            for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
-        }
-        yield {
-            k: v
-            for dataset_name_button, tags_button in batched(buttons, 2)
-            for k, v in {
-                dataset_name_button: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
-                tags_button: gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
-            }.items()
-        }
-        current_item_idx = 0
-        generated_text = ""
-        for line in gen_datasets_line_by_line(search_query):
-            if "I'm sorry" in line or "against Microsoft's use case policy" in line:
-                raise gr.Error("Error: inappropriate content")
-            if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
-                return
-            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
-                try:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
-                except ValueError:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)
-                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
-                generated_text += line
-                yield {
-                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
-                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
-                    generated_texts_state: (generated_text,),
-                }
-                current_item_idx += 1
-
-
-
-    @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
-    def search_dataset_from_search_button(search_query):
-        yield from _search_datasets(search_query)
-
-
-    @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
-    def search_dataset_from_search_bar(search_query):
-        yield from _search_datasets(search_query)
-
-
-    @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
-    def search_more_datasets(search_query, generated_texts):
-        current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
-        yield {
-            button_group: gr.Group(elem_classes="buttonsGroup")
-            for button_group in button_groups[len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL:(len(generated_texts) + 1) * MAX_NB_ITEMS_PER_GENERATION_CALL]
-        }
-        generated_text = ""
-        for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
-            if "I'm sorry" in line or "against Microsoft's use case policy" in line:
-                raise gr.Error("Error: inappropriate content")
-            if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
-                return
-            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
-                try:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
-                except ValueError:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
-                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
-                generated_text += line
-                yield {
-                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
-                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
-                    generated_texts_state: (*generated_texts, generated_text),
-                }
-                current_item_idx += 1
-
-    def _show_dataset(search_query, dataset_name, tags):
-        yield {
-            search_page: gr.Column(visible=False),
-            dataset_page: gr.Column(visible=True),
-            dataset_title: f"# {dataset_name}\n\n tags: {tags}",
-            dataset_share_textbox: gr.Textbox(visible=False),
-            dataset_dataframe: gr.DataFrame(visible=False),
-            generate_full_dataset_button: gr.Button(interactive=True),
-            save_dataset_button: gr.Button(visible=False),
-            open_dataset_message: gr.Markdown(visible=False)
-        }
-        for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
-            yield {dataset_content: generated_text}
-
-
-    show_dataset_inputs = [search_bar, *buttons]
-    show_dataset_outputs = [search_page, dataset_page, dataset_title, dataset_content, generate_full_dataset_button, dataset_dataframe, save_dataset_button, open_dataset_message, dataset_share_textbox]
-    scroll_to_top_js = """
-    function (...args) {
-        console.log(args);
-        if ('parentIFrame' in window) {
-            window.parentIFrame.scrollTo({top: 0, behavior:'smooth'});
-        } else {
-            window.scrollTo({ top: 0 });
-        }
-        return args;
-    }
-    """
-
-    def show_dataset_from_button(search_query, *buttons_values, i):
-        dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
-        yield from _show_dataset(search_query, dataset_name, tags)
-
-    for i, (dataset_name_button, tags_button) in enumerate(batched(buttons, 2)):
-        dataset_name_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
-        tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
-
-
-    @back_button.click(outputs=[search_page, dataset_page], js=scroll_to_top_js)
-    def show_search_page():
-        return gr.Column(visible=True), gr.Column(visible=False)
-
-
-    @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
-    def generate_full_dataset(title, content, search_query, namespace, visability):
-        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-        csv_header, preview_df = parse_preview_df(content)
-        # Remove dummy "id" columns
-        for column_name, values in preview_df.to_dict(orient="series").items():
-            try:
-                if [int(v) for v in values] == list(range(len(preview_df))):
-                    preview_df = preview_df.drop(columns=column_name)
-                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
-                    preview_df = preview_df.drop(columns=column_name)
-            except Exception:
-                pass
-        columns = list(preview_df)
-        output: list[Optional[dict]] = [None] * NUM_ROWS
-        output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
-        yield {
-            dataset_dataframe: gr.DataFrame(pd.DataFrame([{"idx": i, **x} for i, x in enumerate(output) if x]), visible=True),
-            generate_full_dataset_button: gr.Button(interactive=False),
-            save_dataset_button: gr.Button(f"💾 Save Dataset {namespace}/{dataset_name}" + (" (private)" if visability != "public" else ""), visible=True, interactive=False)
-        }
-        kwargs_iterable = [
-            {
-                "title": title,
-                "content": content,
-                "search_query": search_query,
-                "variant": variant,
-                "csv_header": csv_header,
-                "output": output,
-                "indices_to_generate": list(range(len(preview_df) + i, NUM_ROWS, NUM_VARIANTS)),
-            }
-            for i, variant in enumerate(islice(generate_variants(preview_df), NUM_VARIANTS))
-        ]
-        for _ in iflatmap_unordered(generate_partial_dataset, kwargs_iterable=kwargs_iterable):
-            yield {dataset_dataframe: pd.DataFrame([{"idx": i, **{column_name: x.get(column_name) for column_name in columns}} for i, x in enumerate(output) if x])}
-        yield {save_dataset_button: gr.Button(interactive=True)}
-        print(f"Generated {dataset_name}!")
-
-
-    @save_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, dataset_dataframe, select_namespace_dropdown, visibility_radio], outputs=[save_dataset_button, open_dataset_message])
-    def save_dataset(title: str, content: str, search_query: str, df: pd.DataFrame, namespace: str, visability: str, oauth_token: Optional[gr.OAuthToken]):
-        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-        token = oauth_token.token if oauth_token else save_dataset_hf_token
-        repo_id = f"{namespace}/{dataset_name}"
-        dataset_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}"
-        gr.Info("Saving dataset...")
-        yield {save_dataset_button: gr.Button(interactive=False)}
-        create_repo(repo_id=repo_id, repo_type="dataset", private=visability!="public", exist_ok=True, token=token)
-        df.to_csv(f"hf://datasets/{repo_id}/data.csv", storage_options={"token": token}, index=False)
-        DatasetCard(DATASET_CARD_CONTENT.format(title=title, content=content, url=URL, dataset_url=dataset_url, model_id=model_id, search_query=search_query)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
-        gr.Info(f"✅ Dataset saved at {repo_id}")
-        additional_message = "PS: You can also save datasets under your account in the Settings ;)"
-        yield {open_dataset_message: gr.Markdown(f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\nDataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n{additional_message}", visible=True)}
-        print(f"Saved {dataset_name}!")
-
-
-    @dataset_share_button.click(inputs=[dataset_title, search_bar], outputs=[dataset_share_textbox])
-    def show_dataset_url(title, search_query):
-        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-        return gr.Textbox(
-            f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}",
-            visible=True,
-        )
-
-    @demo.load(outputs=[dataset_title, dataset_content, dataset_dataframe, search_bar])
-    def load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]):
-        if oauth_token:
-            user_info = whoami(oauth_token.token)
-            yield {
-                select_namespace_dropdown: gr.Dropdown(
-                    choices=[user_info["name"]] + [org_info["name"] for org_info in user_info["orgs"]],
-                    value=user_info["name"],
-                    visible=True,
-                ),
-                visibility_radio: gr.Radio(interactive=True),
-            }
-        query_params = dict(request.query_params)
-        if "dataset" in query_params:
-            yield from _show_dataset(
-                search_query=query_params.get("q", query_params["dataset"]),
-                dataset_name=query_params["dataset"],
-                tags=query_params.get("tags", "")
-            )
-        elif "q" in query_params:
-            yield {search_bar: query_params["q"]}
-            yield from _search_datasets(query_params["q"])
-        else:
-            # Default behavior
-            yield {file_uploader: gr.File(visible=True)}
-
-    @demo.upload(
-        inputs=[search_bar, dataset_title, dataset_content, dataset_dataframe, select_namespace_dropdown, visibility_radio],
-        outputs=[save_dataset_button, open_dataset_message]
-    )
-    def upload_dataset(search_query, dataset_name, dataset_content, df, namespace, visibility):
-        # Parse dataset name and tags
-        dataset_name, tags = dataset_name.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-
-        # Create local directory structure
-        base_dir = os.path.join(os.getcwd(), "datasets")
-        dataset_dir = os.path.join(base_dir, dataset_name)
-        os.makedirs(dataset_dir, exist_ok=True)
-
-        # Parse and clean preview dataframe
-        csv_header, preview_df = parse_preview_df(dataset_content)
-
-        # Remove dummy "id" columns
-        for column_name, values in preview_df.to_dict(orient="series").items():
-            try:
-                if [int(v) for v in values] == list(range(len(preview_df))):
-                    preview_df = preview_df.drop(columns=column_name)
-                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
-                    preview_df = preview_df.drop(columns=column_name)
-            except Exception:
-                pass
-
-        columns = list(preview_df)
-        output: list[Optional[dict]] = [None] * NUM_ROWS
-        output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
-
-        # Update UI to show upload progress
-        yield {
-            save_dataset_button: gr.Button(
-                f"💾 Save Dataset {namespace}/{dataset_name}" +
-                (" (private)" if visibility != "public" else ""),
-                interactive=False
-            ),
-            open_dataset_message: gr.Markdown(f"Uploading dataset {dataset_name}...")
-        }
-
-        try:
-            # Get authentication token
-            token = oauth_token.token if oauth_token else save_dataset_hf_token
-            repo_id = f"{namespace}/{dataset_name}"
-
-            # Save files locally first
-            local_csv_path = os.path.join(dataset_dir, "data.csv")
-            local_readme_path = os.path.join(dataset_dir, "README.md")
-
-            # Save CSV locally
-            df.to_csv(local_csv_path, index=False)
-
-            # Create dataset card content
-            dataset_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}"
-            dataset_card = DatasetCard(
-                DATASET_CARD_CONTENT.format(
-                    title=title,
-                    content=content,
-                    url=URL,
-                    dataset_url=dataset_url,
-                    model_id=model_id,
-                    search_query=search_query
-                )
-            )
-
-            # Save README locally
-            with open(local_readme_path, 'w', encoding='utf-8') as f:
-                f.write(str(dataset_card))
-
-            # Create and upload to Hub
-            create_repo(
-                repo_id=repo_id,
-                repo_type="dataset",
-                private=visibility != "public",
-                exist_ok=True,
-                token=token
-            )
-
-            # Upload files to Hub
-            df.to_csv(
-                f"hf://datasets/{repo_id}/data.csv",
-                storage_options={"token": token},
-                index=False
-            )
-            dataset_card.push_to_hub(
-                repo_id=repo_id,
-                repo_type="dataset",
-                token=token
-            )
-
-            # Show success message
-            gr.Info(f"✅ Dataset saved at {repo_id}")
-            additional_message = "PS: You can also save datasets under your account in the Settings ;)"
-            yield {
-                open_dataset_message: gr.Markdown(
-                    f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\n"
-                    f"Dataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n"
-                    f"{additional_message}",
-                    visible=True
-                )
-            }
-            print(f"Saved {dataset_name}!")
-
-        except Exception as e:
-            print(f"Error saving dataset: {e}")
-            yield {
-                open_dataset_message: gr.Markdown(
-                    f"❌ Error saving dataset: {str(e)}",
-                    visible=True
-                )
-            }
+    # Define the remaining functions and event handlers...

-
-demo.launch()
+demo.launch()