acecalisto3 committed (verified)
Commit cc30771 · Parent(s): d938019

Update app.py

Files changed (1):
  1. app.py +30 -558
app.py CHANGED
@@ -17,7 +17,7 @@ model_id = "microsoft/Phi-3-mini-4k-instruct"
 client = InferenceClient(model_id)
 save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
 
-MAX_TOTAL_NB_ITEMS = 100 # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
+MAX_TOTAL_NB_ITEMS = 100
 MAX_NB_ITEMS_PER_GENERATION_CALL = 10
 NUM_ROWS = 100
 NUM_VARIANTS = 10
@@ -25,7 +25,7 @@ NAMESPACE = "infinite-dataset-hub"
 URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"
 
 GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
-    "A Machine Learning Practioner is looking for a dataset that matches '{search_query}'. "
+    "A Machine Learning Practitioner is looking for a dataset that matches '{search_query}'. "
     f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
     "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
     "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
@@ -39,8 +39,6 @@ GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
39
  "Reply using a short description of the dataset with title **Dataset Description:** followed by the CSV content in a code block and with title **CSV Content Preview:**."
40
  )
41
 
42
-
43
-
44
  GENERATE_MORE_ROWS = "Can you give me 10 additional samples in CSV format as well? Use the same CSV header '{csv_header}'."
45
  GENERATE_VARIANTS_WITH_RARITY_AND_LABEL = "Focus on generating samples for the label '{label}' and ideally generate {rarity} samples."
46
  GENERATE_VARIANTS_WITH_RARITY = "Focus on generating {rarity} samples."
@@ -71,8 +69,9 @@ landing_page_datasets_generated_text = """
 9. HealthVitalSigns (anomaly detection, biometrics, prediction)
 10. GameStockPredict (classification, finance, sports contingency)
 """
+
 default_output = landing_page_datasets_generated_text.strip().split("\n")
-assert default_output, "Output should not be empty."
+assert len(default_output) == MAX_NB_ITEMS_PER_GENERATION_CALL
 
 DATASET_CARD_CONTENT = """
 ---
@@ -172,31 +171,28 @@ a {
 """
 
 with gr.Blocks(css=css) as demo:
-    # Initialize state
     generated_texts_state = gr.State((landing_page_datasets_generated_text,))
-
+
     with gr.Column() as search_page:
         with gr.Row():
             with gr.Column(scale=10):
                 gr.Markdown(
-                    "# 🤗 inPHIni-D-set ♾️\n\n"
+                    "# 🤗 Infinite Dataset Hub ♾️\n\n"
                     "An endless catalog of datasets, created just for you by an AI model.\n\n"
                 )
                 with gr.Row():
                     search_bar = gr.Textbox(
-                        max_lines=1,
-                        placeholder="Search datasets, get infinite results",
-                        show_label=False,
-                        container=False,
+                        max_lines=1,
+                        placeholder="Search datasets, get infinite results",
+                        show_label=False,
+                        container=False,
                         scale=9
                     )
                     search_button = gr.Button("🔍", variant="primary", scale=1)
-
-                # Initialize button groups and buttons
+
                 button_groups: list[gr.Group] = []
                 buttons: list[gr.Button] = []
-
-                # You'll need to define default_output before this loop
+
                 for i in range(MAX_TOTAL_NB_ITEMS):
                     if i < len(default_output):
                         line = default_output[i]
@@ -209,35 +205,32 @@ with gr.Blocks(css=css) as demo:
                         group_classes = "buttonsGroup insivibleButtonGroup"
                         dataset_name_classes = "topButton linear-background"
                         tags_classes = "bottomButton linear-background"
-
+
                     with gr.Group(elem_classes=group_classes) as button_group:
                         button_groups.append(button_group)
                         buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
                         buttons.append(gr.Button(tags, elem_classes=tags_classes))
 
-                load_more_datasets = gr.Button("Load more datasets") # TODO: disable when reaching end of page
+                load_more_datasets = gr.Button("Load more datasets")
                 gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
-
-            with gr.Column(scale=4, min_width="200px"):
-                with gr.Accordion("Settings", open=False, elem_classes="settings"):
-                    gr.Markdown("Save datasets to your account")
-                    login_button = gr.LoginButton()
-                    login_button.activate() # This line fixes the warning
-
-                    select_namespace_dropdown = gr.Dropdown(
-                        choices=[NAMESPACE],
-                        value=NAMESPACE,
-                        label="Select user or organization",
-                        visible=False
-                    )
+
+            with gr.Column(scale=4, min_width="200px"):
+                with gr.Accordion("Settings", open=False, elem_classes="settings"):
+                    gr.Markdown("Save datasets to your account")
+                    gr.LoginButton()
+                    select_namespace_dropdown = gr.Dropdown(
+                        choices=[NAMESPACE],
+                        value=NAMESPACE,
+                        label="Select user or organization",
+                        visible=False
+                    )
                     gr.Markdown("Save datasets as public or private datasets")
                     visibility_radio = gr.Radio(
-                        ["public", "private"],
-                        value="public",
-                        container=False,
+                        ["public", "private"],
+                        value="public",
+                        container=False,
                         interactive=False
                     )
-
     with gr.Column(visible=False) as dataset_page:
         gr.Markdown(
             "# 🤗 Infinite Dataset Hub ♾️\n\n"
@@ -254,527 +247,6 @@ with gr.Blocks(css=css) as demo:
         dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
         back_button = gr.Button("< Back", size="sm")
 
-
-
-    ###################################
-    #
-    # Utils
-    #
-    ###################################
-
-    T = TypeVar("T")
-
-    def batched(it: Iterable[T], n: int) -> Iterator[list[T]]:
-        """Batch iterator into chunks of size n."""
-        it = iter(it)
-        while batch := list(islice(it, n)):
-            yield batch
-
-    def stream_response(msg: str, generated_texts: tuple[str] = (), max_tokens=500) -> Iterator[str]:
-        """Stream response from chat completion API."""
-        messages = [
-            {"role": "user", "content": msg}
-        ] + [
-            item
-            for generated_text in generated_texts
-            for item in [
-                {"role": "assistant", "content": generated_text},
-                {"role": "user", "content": "Can you generate more?"},
-            ]
-        ]
-
-        for _ in range(3): # Retry logic
-            try:
-                for message in client.chat_completion(
-                    messages=messages,
-                    max_tokens=max_tokens,
-                    stream=True,
-                    top_p=0.8,
-                    seed=42,
-                ):
-                    yield message.choices[0].delta.content
-                break
-            except requests.exceptions.ConnectionError as e:
-                logger.warning(f"Connection error: {e}\nRetrying in 1sec")
-                time.sleep(1)
-
-    def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
-        """Generate dataset names line by line based on search query."""
-        search_query = (search_query or "")[:1000].strip()
-        generated_text = ""
-        current_line = ""
-
-        for token in stream_response(
-            GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY.format(search_query=search_query),
-            generated_texts=generated_texts,
-        ):
-            current_line += token
-            if current_line.endswith("\n"):
-                yield current_line
-                generated_text += current_line
-                current_line = ""
-
-        if current_line:
-            yield current_line
-            generated_text += current_line
-
-        logger.debug(f"Generated text:\n{generated_text}")
-
-    def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
-        """Generate dataset content based on search query, name and tags."""
-        search_query = (search_query or "")[:1000].strip()
-        generated_text = ""
-
-        for token in stream_response(
-            GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
-                search_query=search_query,
-                dataset_name=dataset_name,
-                tags=tags,
-            ),
-            max_tokens=1500
-        ):
-            generated_text += token
-            yield generated_text
-
-        logger.debug(f"Generated content:\n{generated_text}")
-
-    def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
-        """Helper function to write generator output to queue."""
-        try:
-            for result in func(**kwargs):
-                queue.put(result)
-        except Exception as e:
-            logger.error(f"Error in generator: {e}")
-            queue.put(None)
-
-    def iflatmap_unordered(
-        func: Callable[..., Iterable[T]],
-        *,
-        kwargs_iterable: Iterable[dict],
-    ) -> Iterable[T]:
-        """Execute generator function with multiple kwargs in parallel."""
-        queue = Queue()
-        with ThreadPool() as pool:
-            async_results = [
-                pool.apply_async(_write_generator_to_queue, (queue, func, kwargs))
-                for kwargs in kwargs_iterable
-            ]
-            try:
-                while True:
-                    try:
-                        result = queue.get(timeout=0.05)
-                        if result is not None:
-                            yield result
-                    except Empty:
-                        if all(result.ready() for result in async_results) and queue.empty():
-                            break
-            finally:
-                for result in async_results:
-                    try:
-                        result.get(timeout=0.05)
-                    except Exception as e:
-                        logger.error(f"Async result error: {e}")
-
-    def generate_partial_dataset(
-        title: str,
-        content: str,
-        search_query: str,
-        variant: str,
-        csv_header: str,
-        output: list[Dict[str, str]],
-        indices_to_generate: list[int],
-        max_tokens=1500
-    ) -> Iterator[int]:
-        """Generate partial dataset with specific variants."""
-        try:
-            dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-            dataset_name, tags = dataset_name.strip(), tags.strip()
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS.format(
-                        dataset_name=dataset_name,
-                        tags=tags,
-                        search_query=search_query,
-                    )
-                },
-                {"role": "assistant", "content": f"{title}\n\n{content}"},
-                {"role": "user", "content": f"{GENERATE_MORE_ROWS.format(csv_header=csv_header)} {variant}"},
-            ]
-
-            for response in _generate_dataset_rows(messages, max_tokens, indices_to_generate, output, csv_header):
-                yield response
-
-        except Exception as e:
-            logger.error(f"Error generating partial dataset: {e}")
-            yield 0
-
-    def _generate_dataset_rows(messages: list, max_tokens: int, indices: list, output: list, csv_header: str) -> Iterator[int]:
-        """Helper function to generate dataset rows."""
-        for _ in range(3): # Retry logic
-            try:
-                return _process_generation(
-                    messages, max_tokens, indices, output, csv_header
-                )
-            except requests.exceptions.ConnectionError as e:
-                logger.warning(f"Connection error: {e}\nRetrying in 1sec")
-                time.sleep(1)
-        return iter([])
-
-    def generate_variants(preview_df: pd.DataFrame) -> list[str]:
-        """Generate variants based on preview dataframe."""
-        label_candidate_columns = [
-            column for column in preview_df.columns
-            if "label" in column.lower()
-        ]
-
-        if label_candidate_columns:
-            labels = preview_df[label_candidate_columns[0]].unique()
-            if len(labels) > 1:
-                return [
-                    GENERATE_VARIANTS_WITH_RARITY_AND_LABEL.format(
-                        rarity=rarity,
-                        label=label
-                    )
-                    for rarity in RARITIES
-                    for label in labels
-                ]
-
-        return [
-            GENERATE_VARIANTS_WITH_RARITY.format(rarity=rarity)
-            for rarity in LONG_RARITIES
-        ]
-
-    ###################################
-    #
-    # Buttons
-    #
-    ###################################
-
-
-    def _search_datasets(search_query):
-        yield {generated_texts_state: []}
-        yield {
-            button_group: gr.Group(elem_classes="buttonsGroup insivibleButtonGroup")
-            for button_group in button_groups[MAX_NB_ITEMS_PER_GENERATION_CALL:]
-        }
-        yield {
-            k: v
-            for dataset_name_button, tags_button in batched(buttons, 2)
-            for k, v in {
-                dataset_name_button: gr.Button("⬜⬜⬜⬜⬜⬜", elem_classes="topButton linear-background"),
-                tags_button: gr.Button("░░░░, ░░░░, ░░░░", elem_classes="bottomButton linear-background")
-            }.items()
-        }
-        current_item_idx = 0
-        generated_text = ""
-        for line in gen_datasets_line_by_line(search_query):
-            if "I'm sorry" in line or "against Microsoft's use case policy" in line:
-                raise gr.Error("Error: inappropriate content")
-            if current_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
-                return
-            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
-                try:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
-                except ValueError:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)
-                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
-                generated_text += line
-                yield {
-                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
-                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
-                    generated_texts_state: (generated_text,),
-                }
-                current_item_idx += 1
-
-
-
-    @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
-    def search_dataset_from_search_button(search_query):
-        yield from _search_datasets(search_query)
-
-
-    @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
-    def search_dataset_from_search_bar(search_query):
-        yield from _search_datasets(search_query)
-
-
-    @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
-    def search_more_datasets(search_query, generated_texts):
-        current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
-        yield {
-            button_group: gr.Group(elem_classes="buttonsGroup")
-            for button_group in button_groups[len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL:(len(generated_texts) + 1) * MAX_NB_ITEMS_PER_GENERATION_CALL]
-        }
-        generated_text = ""
-        for line in gen_datasets_line_by_line(search_query, generated_texts=generated_texts):
-            if "I'm sorry" in line or "against Microsoft's use case policy" in line:
-                raise gr.Error("Error: inappropriate content")
-            if current_item_idx - initial_item_idx >= MAX_NB_ITEMS_PER_GENERATION_CALL:
-                return
-            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
-                try:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
-                except ValueError:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
-                dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
-                generated_text += line
-                yield {
-                    buttons[2 * current_item_idx]: gr.Button(dataset_name, elem_classes="topButton"),
-                    buttons[2 * current_item_idx + 1]: gr.Button(tags, elem_classes="bottomButton"),
-                    generated_texts_state: (*generated_texts, generated_text),
-                }
-                current_item_idx += 1
-
-    def _show_dataset(search_query, dataset_name, tags):
-        yield {
-            search_page: gr.Column(visible=False),
-            dataset_page: gr.Column(visible=True),
-            dataset_title: f"# {dataset_name}\n\n tags: {tags}",
-            dataset_share_textbox: gr.Textbox(visible=False),
-            dataset_dataframe: gr.DataFrame(visible=False),
-            generate_full_dataset_button: gr.Button(interactive=True),
-            save_dataset_button: gr.Button(visible=False),
-            open_dataset_message: gr.Markdown(visible=False)
-        }
-        for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
-            yield {dataset_content: generated_text}
-
-
-    show_dataset_inputs = [search_bar, *buttons]
-    show_dataset_outputs = [search_page, dataset_page, dataset_title, dataset_content, generate_full_dataset_button, dataset_dataframe, save_dataset_button, open_dataset_message, dataset_share_textbox]
-    scroll_to_top_js = """
-    function (...args) {
-        console.log(args);
-        if ('parentIFrame' in window) {
-            window.parentIFrame.scrollTo({top: 0, behavior:'smooth'});
-        } else {
-            window.scrollTo({ top: 0 });
-        }
-        return args;
-    }
-    """
-
-    def show_dataset_from_button(search_query, *buttons_values, i):
-        dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
-        yield from _show_dataset(search_query, dataset_name, tags)
-
-    for i, (dataset_name_button, tags_button) in enumerate(batched(buttons, 2)):
-        dataset_name_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
-        tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
-
-
-    @back_button.click(outputs=[search_page, dataset_page], js=scroll_to_top_js)
-    def show_search_page():
-        return gr.Column(visible=True), gr.Column(visible=False)
-
-
-    @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
-    def generate_full_dataset(title, content, search_query, namespace, visability):
-        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-        csv_header, preview_df = parse_preview_df(content)
-        # Remove dummy "id" columns
-        for column_name, values in preview_df.to_dict(orient="series").items():
-            try:
-                if [int(v) for v in values] == list(range(len(preview_df))):
-                    preview_df = preview_df.drop(columns=column_name)
-                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
-                    preview_df = preview_df.drop(columns=column_name)
-            except Exception:
-                pass
-        columns = list(preview_df)
-        output: list[Optional[dict]] = [None] * NUM_ROWS
-        output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
-        yield {
-            dataset_dataframe: gr.DataFrame(pd.DataFrame([{"idx": i, **x} for i, x in enumerate(output) if x]), visible=True),
-            generate_full_dataset_button: gr.Button(interactive=False),
-            save_dataset_button: gr.Button(f"💾 Save Dataset {namespace}/{dataset_name}" + (" (private)" if visability != "public" else ""), visible=True, interactive=False)
-        }
-        kwargs_iterable = [
-            {
-                "title": title,
-                "content": content,
-                "search_query": search_query,
-                "variant": variant,
-                "csv_header": csv_header,
-                "output": output,
-                "indices_to_generate": list(range(len(preview_df) + i, NUM_ROWS, NUM_VARIANTS)),
-            }
-            for i, variant in enumerate(islice(generate_variants(preview_df), NUM_VARIANTS))
-        ]
-        for _ in iflatmap_unordered(generate_partial_dataset, kwargs_iterable=kwargs_iterable):
-            yield {dataset_dataframe: pd.DataFrame([{"idx": i, **{column_name: x.get(column_name) for column_name in columns}} for i, x in enumerate(output) if x])}
-        yield {save_dataset_button: gr.Button(interactive=True)}
-        print(f"Generated {dataset_name}!")
-
-
-    @save_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, dataset_dataframe, select_namespace_dropdown, visibility_radio], outputs=[save_dataset_button, open_dataset_message])
-    def save_dataset(title: str, content: str, search_query: str, df: pd.DataFrame, namespace: str, visability: str, oauth_token: Optional[gr.OAuthToken]):
-        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-        token = oauth_token.token if oauth_token else save_dataset_hf_token
-        repo_id = f"{namespace}/{dataset_name}"
-        dataset_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}"
-        gr.Info("Saving dataset...")
-        yield {save_dataset_button: gr.Button(interactive=False)}
-        create_repo(repo_id=repo_id, repo_type="dataset", private=visability!="public", exist_ok=True, token=token)
-        df.to_csv(f"hf://datasets/{repo_id}/data.csv", storage_options={"token": token}, index=False)
-        DatasetCard(DATASET_CARD_CONTENT.format(title=title, content=content, url=URL, dataset_url=dataset_url, model_id=model_id, search_query=search_query)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
-        gr.Info(f"✅ Dataset saved at {repo_id}")
-        additional_message = "PS: You can also save datasets under your account in the Settings ;)"
-        yield {open_dataset_message: gr.Markdown(f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\nDataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n{additional_message}", visible=True)}
-        print(f"Saved {dataset_name}!")
-
-
-    @dataset_share_button.click(inputs=[dataset_title, search_bar], outputs=[dataset_share_textbox])
-    def show_dataset_url(title, search_query):
-        dataset_name, tags = title.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-        return gr.Textbox(
-            f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}",
-            visible=True,
-        )
-
-    @demo.load(outputs=[dataset_title, dataset_content, dataset_dataframe, search_bar])
-    def load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]):
-        if oauth_token:
-            user_info = whoami(oauth_token.token)
-            yield {
-                select_namespace_dropdown: gr.Dropdown(
-                    choices=[user_info["name"]] + [org_info["name"] for org_info in user_info["orgs"]],
-                    value=user_info["name"],
-                    visible=True,
-                ),
-                visibility_radio: gr.Radio(interactive=True),
-            }
-        query_params = dict(request.query_params)
-        if "dataset" in query_params:
-            yield from _show_dataset(
-                search_query=query_params.get("q", query_params["dataset"]),
-                dataset_name=query_params["dataset"],
-                tags=query_params.get("tags", "")
-            )
-        elif "q" in query_params:
-            yield {search_bar: query_params["q"]}
-            yield from _search_datasets(query_params["q"])
-        else:
-            # Default behavior
-            yield {file_uploader: gr.File(visible=True)}
-
-    @demo.upload(
-        inputs=[search_bar, dataset_title, dataset_content, dataset_dataframe, select_namespace_dropdown, visibility_radio],
-        outputs=[save_dataset_button, open_dataset_message]
-    )
-    def upload_dataset(search_query, dataset_name, dataset_content, df, namespace, visibility):
-        # Parse dataset name and tags
-        dataset_name, tags = dataset_name.strip("# ").split("\ntags:", 1)
-        dataset_name, tags = dataset_name.strip(), tags.strip()
-
-        # Create local directory structure
-        base_dir = os.path.join(os.getcwd(), "datasets")
-        dataset_dir = os.path.join(base_dir, dataset_name)
-        os.makedirs(dataset_dir, exist_ok=True)
-
-        # Parse and clean preview dataframe
-        csv_header, preview_df = parse_preview_df(dataset_content)
-
-        # Remove dummy "id" columns
-        for column_name, values in preview_df.to_dict(orient="series").items():
-            try:
-                if [int(v) for v in values] == list(range(len(preview_df))):
-                    preview_df = preview_df.drop(columns=column_name)
-                if [int(v) for v in values] == list(range(1, len(preview_df) + 1)):
-                    preview_df = preview_df.drop(columns=column_name)
-            except Exception:
-                pass
-
-        columns = list(preview_df)
-        output: list[Optional[dict]] = [None] * NUM_ROWS
-        output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
-
-        # Update UI to show upload progress
-        yield {
-            save_dataset_button: gr.Button(
-                f"💾 Save Dataset {namespace}/{dataset_name}" +
-                (" (private)" if visibility != "public" else ""),
-                interactive=False
-            ),
-            open_dataset_message: gr.Markdown(f"Uploading dataset {dataset_name}...")
-        }
-
-        try:
-            # Get authentication token
-            token = oauth_token.token if oauth_token else save_dataset_hf_token
-            repo_id = f"{namespace}/{dataset_name}"
-
-            # Save files locally first
-            local_csv_path = os.path.join(dataset_dir, "data.csv")
-            local_readme_path = os.path.join(dataset_dir, "README.md")
-
-            # Save CSV locally
-            df.to_csv(local_csv_path, index=False)
-
-            # Create dataset card content
-            dataset_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}"
-            dataset_card = DatasetCard(
-                DATASET_CARD_CONTENT.format(
-                    title=title,
-                    content=content,
-                    url=URL,
-                    dataset_url=dataset_url,
-                    model_id=model_id,
-                    search_query=search_query
-                )
-            )
-
-            # Save README locally
-            with open(local_readme_path, 'w', encoding='utf-8') as f:
-                f.write(str(dataset_card))
-
-            # Create and upload to Hub
-            create_repo(
-                repo_id=repo_id,
-                repo_type="dataset",
-                private=visibility != "public",
-                exist_ok=True,
-                token=token
-            )
-
-            # Upload files to Hub
-            df.to_csv(
-                f"hf://datasets/{repo_id}/data.csv",
-                storage_options={"token": token},
-                index=False
-            )
-            dataset_card.push_to_hub(
-                repo_id=repo_id,
-                repo_type="dataset",
-                token=token
-            )
-
-            # Show success message
-            gr.Info(f"✅ Dataset saved at {repo_id}")
-            additional_message = "PS: You can also save datasets under your account in the Settings ;)"
-            yield {
-                open_dataset_message: gr.Markdown(
-                    f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\n"
-                    f"Dataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n"
-                    f"{additional_message}",
-                    visible=True
-                )
-            }
-            print(f"Saved {dataset_name}!")
-
-        except Exception as e:
-            print(f"Error saving dataset: {e}")
-            yield {
-                open_dataset_message: gr.Markdown(
-                    f"❌ Error saving dataset: {str(e)}",
-                    visible=True
-                )
-            }
+    # Define the remaining functions and event handlers...
 
-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
 