acecalisto3 committed
Commit f48943e · verified · 1 Parent(s): 0759604

Update app.py

Files changed (1)
app.py +60 -213
app.py CHANGED
@@ -6,37 +6,30 @@ from itertools import islice
 from functools import partial
 from multiprocessing.pool import ThreadPool
 from queue import Queue, Empty
-from typing import Callable, Iterable, Iterator, Optional, TypeVar, List, Dict
-import datetime
+from typing import Callable, Iterable, Iterator, Optional, TypeVar
 
 import gradio as gr
 import pandas as pd
 import requests.exceptions
 from huggingface_hub import InferenceClient, create_repo, whoami, DatasetCard
 
+
 model_id = "microsoft/Phi-3-mini-4k-instruct"
 client = InferenceClient(model_id)
 save_dataset_hf_token = os.environ.get("SAVE_DATASET_HF_TOKEN")
 
-AUTORUN_INTERVAL = 2  # Seconds between dataset generations
-MAX_AUTORUN_DATASETS = 1000  # Safety limit for infinite mode
 MAX_TOTAL_NB_ITEMS = 100  # almost infinite, don't judge me (actually it's because gradio needs a fixed number of components)
 MAX_NB_ITEMS_PER_GENERATION_CALL = 10
 NUM_ROWS = 100
-MAX_QUEUE_SIZE = 100  # Maximum number of concurrent users
 NUM_VARIANTS = 10
 NAMESPACE = "infinite-dataset-hub"
 URL = "https://huggingface.co/spaces/infinite-dataset-hub/infinite-dataset-hub"
-# Add these after existing state variables
-autorun_active = gr.State(False)
-accumulated_datasets = gr.State(pd.DataFrame())
-current_processing = gr.State(set())
 
 GENERATE_DATASET_NAMES_FOR_SEARCH_QUERY = (
-    "A Machine Learning Practioner is looking for a dataset that matches '{search_query}'. "
-    f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
-    "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
-    "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
+    "A Machine Learning Practioner is looking for a dataset that matches '{search_query}'. "
+    f"Generate a list of {MAX_NB_ITEMS_PER_GENERATION_CALL} names of quality datasets that don't exist but sound plausible and would "
+    "be helpful. Feel free to reuse words from the query '{search_query}' to name the datasets. "
+    "Every dataset should be about '{search_query}' and have descriptive tags/keywords including the ML task name associated with the dataset (classification, regression, anomaly detection, etc.). Use the following format:\n1. DatasetName1 (tag1, tag2, tag3)\n1. DatasetName2 (tag1, tag2, tag3)"
 )
 
 GENERATE_DATASET_CONTENT_FOR_SEARCH_QUERY_AND_NAME_AND_TAGS = (
@@ -56,7 +49,7 @@ LONG_RARITIES = [
     "expected",
     "common",
     "regular",
-    "unexpected but useful",  # <-- Added missing comma here.
+    "unexpected but useful"
     "original but useful",
     "specific but not far-fetched",
     "uncommon but still plausible",
@@ -86,34 +79,27 @@ tags:
 - infinite-dataset-hub
 - synthetic
 ---
+
 {title}
+
 _Note: This is an AI-generated dataset so its content may be inaccurate or false_
+
 {content}
+
 **Source of the data:**
+
 The dataset was generated using the [Infinite Dataset Hub]({url}) and {model_id} using the query '{search_query}':
+
 - **Dataset Generation Page**: {dataset_url}
 - **Model**: https://huggingface.co/{model_id}
 - **More Datasets**: https://huggingface.co/datasets?other=infinite-dataset-hub
 """
 
 css = """
-.autorun-section {
-    border: 1px solid var(--border-color-primary);
-    border-radius: 8px;
-    padding: 1rem;
-    margin-top: 1rem;
-}
-.compile-options {
-    margin-top: 1rem;
-}
-.download-prompt {
-    color: var(--color-accent);
-    font-weight: bold;
-    margin-top: 1rem;
-}
 a {
     color: var(--body-text-color);
 }
+
 .datasetButton {
     justify-content: start;
     justify-content: left;
@@ -163,6 +149,7 @@ a {
 .insivibleButtonGroup {
     display: none;
 }
+
 @keyframes placeHolderShimmer{
     0%{
         background-position: -468px 0
@@ -190,6 +177,7 @@ a {
 }
 """
 
+
 with gr.Blocks(css=css) as demo:
     generated_texts_state = gr.State((landing_page_datasets_generated_text,))
     with gr.Column() as search_page:
@@ -221,7 +209,7 @@
                 buttons.append(gr.Button(dataset_name, elem_classes=dataset_name_classes))
                 buttons.append(gr.Button(tags, elem_classes=tags_classes))
 
-            load_more_datasets = gr.Button("Load more datasets")  # TODO: disable when reaching end of page
+            load_more_datasets = gr.Button("Load more datasets")  # TODO: dosable when reaching end of page
             gr.Markdown(f"_powered by [{model_id}](https://huggingface.co/{model_id})_")
         with gr.Column(scale=4, min_width="200px"):
             with gr.Accordion("Settings", open=False, elem_classes="settings"):
@@ -246,25 +234,6 @@
         dataset_share_textbox = gr.Textbox(visible=False, show_copy_button=True, label="Copy this URL:", interactive=False, show_label=True)
         back_button = gr.Button("< Back", size="sm")
 
-        with gr.Column(elem_classes="autorun-section") as autorun_section:
-            with gr.Row():
-                autorun_toggle = gr.Checkbox(label="AutoRun Mode", interactive=True)
-                autorun_status = gr.Markdown("**Status:** Inactive", elem_classes="status")
-
-            with gr.Row():
-                compile_mode = gr.Radio(
-                    ["Combine All", "Keep Separate"],
-                    label="Compilation Mode",
-                    value="Combine All"
-                )
-                processing_options = gr.CheckboxGroup(
-                    ["Clean Data", "Chunk Data", "Summarize Data"],
-                    label="Processing Options"
-                )
-
-            with gr.Row():
-                download_btn = gr.DownloadButton("Download Dataset", visible=False)
-                stop_btn = gr.Button("Stop & Save", variant="stop", visible=False)
     ###################################
     #
     # Utils
@@ -278,6 +247,7 @@
         while batch := list(islice(it, n)):
             yield batch
 
+
     def stream_reponse(msg: str, generated_texts: tuple[str] = (), max_tokens=500) -> Iterator[str]:
         messages = [
             {"role": "user", "content": msg}
@@ -300,59 +270,11 @@
                 ):
                     yield message.choices[0].delta.content
             except requests.exceptions.ConnectionError as e:
-                print(f"{e}\n\nRetrying in 1sec")
+                print(e + "\n\nRetrying in 1sec")
                 time.sleep(1)
                 continue
            break
 
-    def generate_single_dataset(search_query: str) -> pd.DataFrame:
-        """Generate one complete dataset from search query to parsed DataFrame"""
-        # Generate dataset names
-        dataset_lines = []
-        for line in gen_datasets_line_by_line(search_query):
-            dataset_lines.append(line)
-            if len(dataset_lines) >= MAX_NB_ITEMS_PER_GENERATION_CALL:
-                break
-
-        # Process first valid dataset
-        for line in dataset_lines:
-            if line.strip() and line.strip().split(".", 1)[0].isnumeric():
-                try:
-                    dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
-                    break
-                except ValueError:
-                    continue
-
-        # Generate dataset content
-        content = ""
-        for token in gen_dataset_content(search_query, dataset_name, tags):
-            content += token
-
-        # Parse to DataFrame
-        _, preview_df = parse_preview_df(content)
-        return preview_df
-
-    def process_dataset(df: pd.DataFrame, options: List[str]) -> pd.DataFrame:
-        """Apply processing options to dataset"""
-        # Clean
-        if 'Clean Data' in options:
-            df = df.dropna().drop_duplicates()
-
-        # Chunk
-        if 'Chunk Data' in options:
-            if len(df) > 10:
-                df = df.sample(frac=0.5)  # Simple chunking example
-
-        # Summarize
-        if 'Summarize Data' in options:
-            summary = pd.DataFrame({
-                'columns': df.columns,
-                'dtypes': df.dtypes.values,
-                'non_null_count': df.count().values
-            })
-            return summary
-
-        return df
 
     def gen_datasets_line_by_line(search_query: str, generated_texts: tuple[str] = ()) -> Iterator[str]:
         search_query = search_query or ""
@@ -372,6 +294,7 @@
             generated_text += current_line
         print("-----\n\n" + generated_text)
 
+
     def gen_dataset_content(search_query: str, dataset_name: str, tags: str) -> Iterator[str]:
         search_query = search_query or ""
         search_query = search_query[:1000] if search_query.strip() else ""
@@ -385,11 +308,13 @@
             yield generated_text
         print("-----\n\n" + generated_text)
 
+
     def _write_generator_to_queue(queue: Queue, func: Callable[..., Iterable], kwargs: dict) -> None:
         for i, result in enumerate(func(**kwargs)):
             queue.put(result)
         return None
 
+
     def iflatmap_unordered(
         func: Callable[..., Iterable[T]],
         *,
@@ -411,6 +336,7 @@
             # we get the result in case there's an error to raise
             [async_result.get(timeout=0.05) for async_result in async_results]
 
+
     def generate_partial_dataset(title: str, content: str, search_query: str, variant: str, csv_header: str, output: list[dict[str, str]], indices_to_generate: list[int], max_tokens=1500) -> Iterator[int]:
         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
         dataset_name, tags = dataset_name.strip(), tags.strip()
@@ -458,10 +384,14 @@
                         pass
                     current_line = ""
             except requests.exceptions.ConnectionError as e:
-                print(f"{e}\n\nRetrying in 1sec")
+                print(e + "\n\nRetrying in 1sec")
                 time.sleep(1)
                 continue
            break
+        # for debugging
+        # with open(f".output{indices_to_generate[0]}.txt", "w") as f:
+        #     f.write(generated_text)
+
 
     def generate_variants(preview_df: pd.DataFrame):
         label_candidate_columns = [column for column in preview_df.columns if "label" in column.lower()]
@@ -478,6 +408,7 @@
             for rarity in LONG_RARITIES
         ]
 
+
     def parse_preview_df(content: str) -> tuple[str, pd.DataFrame]:
         _in_csv = False
         csv = "\n".join(
@@ -489,6 +420,7 @@
             raise gr.Error("Failed to parse CSV Preview")
         return csv.split("\n")[0], parse_csv_df(csv)
 
+
    def parse_csv_df(csv: str, csv_header: Optional[str] = None) -> pd.DataFrame:
        # Fix generation mistake when providing a list that is not in quotes
        for match in re.finditer(r'''(?!")\[(["'][\w ]+["'][, ]*)+\](?!")''', csv):
@@ -501,12 +433,14 @@
         df = pd.read_csv(io.StringIO(csv), skipinitialspace=True)
         return df
 
+
     ###################################
     #
     # Buttons
     #
     ###################################
 
+
     def _search_datasets(search_query):
         yield {generated_texts_state: []}
         yield {
@@ -542,14 +476,17 @@
             }
             current_item_idx += 1
 
+
     @search_button.click(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
     def search_dataset_from_search_button(search_query):
         yield from _search_datasets(search_query)
 
+
     @search_bar.submit(inputs=search_bar, outputs=button_groups + buttons + [generated_texts_state])
     def search_dataset_from_search_bar(search_query):
         yield from _search_datasets(search_query)
 
+
     @load_more_datasets.click(inputs=[search_bar, generated_texts_state], outputs=button_groups + buttons + [generated_texts_state])
     def search_more_datasets(search_query, generated_texts):
         current_item_idx = initial_item_idx = len(generated_texts) * MAX_NB_ITEMS_PER_GENERATION_CALL
@@ -567,7 +504,7 @@
             try:
                 dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" (", 1)
             except ValueError:
-                dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1)[0], ""
+                dataset_name, tags = line.strip().split(".", 1)[1].strip(" )").split(" ", 1) [0], ""
             dataset_name, tags = dataset_name.strip("()[]* "), tags.strip("()[]* ")
             generated_text += line
             yield {
@@ -577,66 +514,21 @@
             }
             current_item_idx += 1
 
-    def toggle_autorun(active: bool, current_df: pd.DataFrame) -> dict:
-        """Toggle autorun state and UI elements"""
-        new_state = not active
-        updates = {
-            autorun_toggle: gr.Checkbox.update(value=new_state),
-            autorun_status: gr.Markdown.update(value=f"**Status:** {'Active' if new_state else 'Inactive'}"),
-            stop_btn: gr.Button.update(visible=new_state),
-            download_btn: gr.DownloadButton.update(visible=not new_state),
-            accumulated_datasets: current_df  # Maintain current state
-        }
-        if new_state:  # Reset when starting new run
-            updates[accumulated_datasets] = pd.DataFrame()
-        return updates
-
-    def autorun_iteration(
-        search_query: str,
-        current_df: pd.DataFrame,
-        compile_mode: str,
-        process_opts: List[str]
-    ) -> pd.DataFrame:
-        """Single iteration of autorun dataset generation"""
-        try:
-            new_data = generate_single_dataset(search_query)
-            processed = process_dataset(new_data, process_opts)
-
-            if compile_mode == "Combine All":
-                combined = pd.concat([current_df, processed], ignore_index=True)
-                # Return full dataset but only show last 50
-                return combined
-            else:
-                return pd.concat([current_df, processed], ignore_index=True)
-        except Exception as e:
-            print(f"Error in autorun iteration: {e}")
-            return current_df
-
-    def create_download_file(current_df: pd.DataFrame) -> str:
-        """Prepare dataset for download; returns the filename"""
-        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
-        filename = f"autorun-dataset-{timestamp}.csv"
-        current_df.to_csv(filename, index=False)
-        return filename
-
-    # Helper function to update the displayed dataframe (showing last 50 rows)
-    def update_display(df: pd.DataFrame) -> pd.DataFrame:
-        return df.tail(50)
-
     def _show_dataset(search_query, dataset_name, tags):
         yield {
-            search_page: gr.Column.update(visible=False),
-            dataset_page: gr.Column.update(visible=True),
+            search_page: gr.Column(visible=False),
+            dataset_page: gr.Column(visible=True),
             dataset_title: f"# {dataset_name}\n\n tags: {tags}",
-            dataset_share_textbox: gr.Textbox.update(visible=False),
-            dataset_dataframe: gr.DataFrame.update(visible=False),
-            generate_full_dataset_button: gr.Button.update(interactive=True),
-            save_dataset_button: gr.Button.update(visible=False),
-            open_dataset_message: gr.Markdown.update(visible=False)
+            dataset_share_textbox: gr.Textbox(visible=False),
+            dataset_dataframe: gr.DataFrame(visible=False),
+            generate_full_dataset_button: gr.Button(interactive=True),
+            save_dataset_button: gr.Button(visible=False),
+            open_dataset_message: gr.Markdown(visible=False)
         }
         for generated_text in gen_dataset_content(search_query=search_query, dataset_name=dataset_name, tags=tags):
             yield {dataset_content: generated_text}
 
+
     show_dataset_inputs = [search_bar, *buttons]
     show_dataset_outputs = [search_page, dataset_page, dataset_title, dataset_content, generate_full_dataset_button, dataset_dataframe, save_dataset_button, open_dataset_message, dataset_share_textbox]
     scroll_to_top_js = """
@@ -651,33 +543,6 @@
     }
     """
 
-    # Function to update UI when stopping autorun
-    def stop_autorun():
-        return (
-            gr.Checkbox.update(value=False),
-            gr.Markdown.update(value="**Status:** Inactive"),
-            gr.Button.update(visible=False),
-            gr.DownloadButton.update(visible=True)
-        )
-
-    autorun_toggle.change(
-        toggle_autorun,
-        inputs=[autorun_active, accumulated_datasets],
-        outputs=[autorun_toggle, autorun_status, stop_btn, download_btn, accumulated_datasets]
-    )
-
-    stop_btn.click(
-        stop_autorun,
-        inputs=None,
-        outputs=[autorun_toggle, autorun_status, stop_btn, download_btn]
-    )
-
-    download_btn.click(
-        create_download_file,
-        inputs=accumulated_datasets,
-        outputs=download_btn
-    )
-
     def show_dataset_from_button(search_query, *buttons_values, i):
         dataset_name, tags = buttons_values[2 * i : 2 * i + 2]
         yield from _show_dataset(search_query, dataset_name, tags)
@@ -686,10 +551,12 @@
         dataset_name_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
         tags_button.click(partial(show_dataset_from_button, i=i), inputs=show_dataset_inputs, outputs=show_dataset_outputs, js=scroll_to_top_js)
 
+
     @back_button.click(outputs=[search_page, dataset_page], js=scroll_to_top_js)
     def show_search_page():
         return gr.Column(visible=True), gr.Column(visible=False)
 
+
     @generate_full_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, select_namespace_dropdown, visibility_radio], outputs=[dataset_dataframe, generate_full_dataset_button, save_dataset_button])
     def generate_full_dataset(title, content, search_query, namespace, visability):
         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
@@ -709,8 +576,8 @@
         output[:len(preview_df)] = [{"idx": i, **x} for i, x in enumerate(preview_df.to_dict(orient="records"))]
         yield {
             dataset_dataframe: gr.DataFrame(pd.DataFrame([{"idx": i, **x} for i, x in enumerate(output) if x]), visible=True),
-            generate_full_dataset_button: gr.Button.update(interactive=False),
-            save_dataset_button: gr.Button.update(label=f"💾 Save Dataset {namespace}/{dataset_name}" + (" (private)" if visability != "public" else ""), visible=True, interactive=False)
+            generate_full_dataset_button: gr.Button(interactive=False),
+            save_dataset_button: gr.Button(f"💾 Save Dataset {namespace}/{dataset_name}" + (" (private)" if visability != "public" else ""), visible=True, interactive=False)
         }
         kwargs_iterable = [
             {
@@ -726,9 +593,10 @@
         ]
         for _ in iflatmap_unordered(generate_partial_dataset, kwargs_iterable=kwargs_iterable):
             yield {dataset_dataframe: pd.DataFrame([{"idx": i, **{column_name: x.get(column_name) for column_name in columns}} for i, x in enumerate(output) if x])}
-        yield {save_dataset_button: gr.Button.update(interactive=True)}
+        yield {save_dataset_button: gr.Button(interactive=True)}
         print(f"Generated {dataset_name}!")
 
+
     @save_dataset_button.click(inputs=[dataset_title, dataset_content, search_bar, dataset_dataframe, select_namespace_dropdown, visibility_radio], outputs=[save_dataset_button, open_dataset_message])
     def save_dataset(title: str, content: str, search_query: str, df: pd.DataFrame, namespace: str, visability: str, oauth_token: Optional[gr.OAuthToken]):
         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
@@ -737,32 +605,36 @@
         repo_id = f"{namespace}/{dataset_name}"
         dataset_url = f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}"
         gr.Info("Saving dataset...")
-        yield {save_dataset_button: gr.Button.update(interactive=False)}
+        yield {save_dataset_button: gr.Button(interactive=False)}
         create_repo(repo_id=repo_id, repo_type="dataset", private=visability!="public", exist_ok=True, token=token)
         df.to_csv(f"hf://datasets/{repo_id}/data.csv", storage_options={"token": token}, index=False)
         DatasetCard(DATASET_CARD_CONTENT.format(title=title, content=content, url=URL, dataset_url=dataset_url, model_id=model_id, search_query=search_query)).push_to_hub(repo_id=repo_id, repo_type="dataset", token=token)
         gr.Info(f"✅ Dataset saved at {repo_id}")
         additional_message = "PS: You can also save datasets under your account in the Settings ;)"
-        yield {open_dataset_message: gr.Markdown.update(value=f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\nDataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n{additional_message}", visible=True)}
+        yield {open_dataset_message: gr.Markdown(f"# 🎉 Yay ! Your dataset has been saved to [{repo_id}](https://huggingface.co/datasets/{repo_id}) !\n\nDataset link: [https://huggingface.co/datasets/{repo_id}](https://huggingface.co/datasets/{repo_id})\n\n{additional_message}", visible=True)}
         print(f"Saved {dataset_name}!")
 
+
     @dataset_share_button.click(inputs=[dataset_title, search_bar], outputs=[dataset_share_textbox])
     def show_dataset_url(title, search_query):
         dataset_name, tags = title.strip("# ").split("\ntags:", 1)
         dataset_name, tags = dataset_name.strip(), tags.strip()
-        return gr.Textbox.update(value=f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}", visible=True)
+        return gr.Textbox(
+            f"{URL}?q={search_query.replace(' ', '+')}&dataset={dataset_name.replace(' ', '+')}&tags={tags.replace(' ', '+')}",
+            visible=True,
+        )
 
     @demo.load(outputs=show_dataset_outputs + button_groups + buttons + [generated_texts_state] + [select_namespace_dropdown, visibility_radio])
     def load_app(request: gr.Request, oauth_token: Optional[gr.OAuthToken]):
         if oauth_token:
             user_info = whoami(oauth_token.token)
             yield {
-                select_namespace_dropdown: gr.Dropdown.update(
+                select_namespace_dropdown: gr.Dropdown(
                     choices=[user_info["name"]] + [org_info["name"] for org_info in user_info["orgs"]],
                     value=user_info["name"],
                     visible=True,
                 ),
-                visibility_radio: gr.Radio.update(interactive=True),
+                visibility_radio: gr.Radio(interactive=True),
             }
         query_params = dict(request.query_params)
         if "dataset" in query_params:
@@ -777,30 +649,5 @@
         else:
             yield {search_page: gr.Column(visible=True)}
 
-    def run_autorun():
-        while True:
-            # Using the value from autorun_active state
-            if autorun_active.value:
-                # Update full dataset
-                full_data = autorun_iteration(
-                    search_bar.value,
-                    accumulated_datasets.value,
-                    compile_mode.value,
-                    processing_options.value
-                )
-                accumulated_display = gr.DataFrame(
-                    label="Accumulated Data (Last 50 Samples)",
-                    interactive=False,
-                    wrap=True
-                )
-                # Update state with full data and show last 50 rows
-                accumulated_datasets.value = full_data
-                yield {
-                    accumulated_display: update_display(full_data),
-                    accumulated_datasets: full_data
-                }
-                time.sleep(AUTORUN_INTERVAL)
-            else:
-                yield accumulated_display.update(visible=False)
 
-demo.queue(max_size=100).launch(share=True)
+demo.launch()