update-layout-add-evaluation (#17)

Commits:
- add comment to divide functions/ui (c2fbbc311616f60f3dbc1e546d20517329b41486)
- fix typo (3ef1fed40b702eb482f9017241a2d3641cde3946)
- move sign in button to another column (ea29202ac74e3b567f8bbc83f2dd01edeb7e00b5)
- make sign in button smaller (3b5e775206b9e2276d8d5afd14d5a2ae6eb0f852)
- remove repeated import (9c1769a069aa5cadade2e120ebbfbb3524b4a71c)
- move sign in button to the right (5d91425bc46bbed12e76279e32eed828738f2e78)
- modify column width and typos (2b5c2e3fa953fcd1e7bbc5d50fa1eaa18cf51a7f)
- update successful message and pipeline code (4234ad816ad254da4dc0a2bdf4f6ee901f3ab647)
- update dataframe visualizations (7350fc6b3cdabd3c88562f7ebea772ea936b293b)
- update text and order parameter layout (45693e1d9d5340197d0a5298329a1b176836e5e9)
- typo (2673ebc69f8b3ee53ca0bea400abe8b18dcec6c7)
- add temperature for system prompt (857f1ba71f10ddb10f045923601746daed130b19)
- update textcat (separate prompt and labels) and use input parameters (4e193106207eda3f59650448038a680c25075972)
- update sft and use input parameters (dea11022bc5c78e08481e4e90bbb73b0402cdadc)
- update push dataset (49d5948eb076fc8b3354a9d4acdaac477fc0c398)
- add evaluation task (34371d30aa99cdb709c9def84739ab3b8b7fa611)
- hide pipeline ui each time it generates (c26510fcff621c6a144917e1a56d5f87dd41fd41)
- move order hide pipeline ui (1b00519115b913bff86a6f2ba061f97eb860e78a)
- merge remote tracking branch (1c412e2113c3889b13572af931a3be19fc93df5a)

Files changed:
- app.py +3 -3
- pyproject.toml +1 -1
- src/distilabel_dataset_generator/_tabbedinterface.py +4 -2
- src/distilabel_dataset_generator/apps/base.py +16 -33
- src/distilabel_dataset_generator/apps/eval.py +687 -202
- src/distilabel_dataset_generator/apps/sft.py +102 -47
- src/distilabel_dataset_generator/apps/textcat.py +171 -140
- src/distilabel_dataset_generator/pipelines/eval.py +205 -0
- src/distilabel_dataset_generator/pipelines/sft.py +50 -49
- src/distilabel_dataset_generator/pipelines/textcat.py +89 -70
- src/distilabel_dataset_generator/utils.py +97 -8
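
The new src/distilabel_dataset_generator/pipelines/eval.py module (+205 lines, a new file) wraps distilabel's UltraFeedback task behind get_ultrafeedback_evaluator and a custom-template evaluator behind get_custom_evaluator. The snippet below is only an illustrative sketch of how an UltraFeedback evaluator is typically built and called batch-wise with distilabel 1.4, mirroring the evaluator.process(inputs=...) usage in apps/eval.py; the model id and generation settings are placeholders, not this Space's configuration.

# Illustrative sketch only: a distilabel UltraFeedback evaluator run on a small batch.
# Model id and kwargs are placeholders, not the values used in this repository.
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import UltraFeedback

evaluator = UltraFeedback(
    aspect="overall-rating",  # one of the aspects offered in the Evaluation tab
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
        generation_kwargs={"temperature": 0.0, "max_new_tokens": 512},
    ),
)
evaluator.load()

inputs = [
    {
        "instruction": "What is the capital of France?",
        "generations": ["Paris.", "It is Lyon."],
    }
]
# `process` yields batches of rows; each row gains "ratings", "rationales" and "model_name".
result = next(evaluator.process(inputs=inputs))
print(result[0]["ratings"], result[0]["rationales"])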
app.py

@@ -3,6 +3,7 @@ import gradio as gr
 from src.distilabel_dataset_generator._tabbedinterface import TabbedInterface
 from src.distilabel_dataset_generator.apps.faq import app as faq_app
 from src.distilabel_dataset_generator.apps.sft import app as sft_app
+from src.distilabel_dataset_generator.apps.eval import app as eval_app
 from src.distilabel_dataset_generator.apps.textcat import app as textcat_app
 
 theme ='argilla/argilla-theme'
@@ -25,12 +26,11 @@ button.hf-login:hover {background: var(--neutral-700); color: white}
 """
 
 demo = TabbedInterface(
-    [textcat_app, sft_app, faq_app],
-    ["Text Classification", "Supervised Fine-Tuning", "FAQ"],
+    [textcat_app, sft_app, eval_app, faq_app],
+    ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
     css=css,
     title="""
    <h1>Synthetic Data Generator</h1>
-    <h3>Build datasets using natural language</h3>
     """,
     head="Synthetic Data Generator",
     theme=theme,
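
Each tab (including the new Evaluation tab) is a module that exposes a gr.Blocks object named app, which the project's TabbedInterface mounts as a tab. As a rough standalone illustration of that contract using Gradio's built-in TabbedInterface (not the customized subclass in _tabbedinterface.py):

import gradio as gr

# Hypothetical stand-ins for the real tab modules; each only needs to expose a Blocks object.
with gr.Blocks() as eval_app:
    gr.Markdown("Evaluation tab placeholder")

with gr.Blocks() as faq_app:
    gr.Markdown("FAQ tab placeholder")

# Gradio mounts the Blocks as tabs in order, paired with the tab names.
demo = gr.TabbedInterface([eval_app, faq_app], ["Evaluation", "FAQ"])

if __name__ == "__main__":
    demo.launch()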
pyproject.toml

@@ -6,7 +6,7 @@ authors = [
     {name = "davidberenstein1957", email = "[email protected]"},
 ]
 dependencies = [
-    "distilabel[hf-inference-endpoints,argilla,outlines]>=1.4.1",
+    "distilabel[hf-inference-endpoints,argilla,outlines,instructor]>=1.4.1",
     "gradio[oauth]<5.0.0",
     "transformers>=4.44.2",
     "sentence-transformers>=3.2.0",
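
The only dependency change adds the instructor extra to distilabel, which is what the custom-template evaluator relies on for schema-constrained (structured) outputs. For context, the underlying pattern that instructor enables looks roughly like the standalone sketch below; the model name and schema are illustrative and not taken from this repository.

# Standalone illustration of instructor-style structured output (not code from this PR).
import instructor
from openai import OpenAI
from pydantic import BaseModel


class Evaluation(BaseModel):
    score: int       # e.g. a 1-5 rating
    rationale: str   # short justification


client = instructor.from_openai(OpenAI())  # wraps the client so replies are validated

evaluation = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name
    response_model=Evaluation,
    messages=[{"role": "user", "content": "Rate this answer from 1 to 5: 'Paris.'"}],
)
print(evaluation.score, evaluation.rationale)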
src/distilabel_dataset_generator/_tabbedinterface.py

@@ -63,10 +63,12 @@ class TabbedInterface(Blocks):
            if title:
                HTML(value=title)
                with gr.Row():
-                    with gr.Column(scale=…
-                        gr.…
+                    with gr.Column(scale=2):
+                        gr.Markdown("### Build datasets using natural language")
                    with gr.Column(scale=3):
                        pass
+                    with gr.Column(scale=2):
+                        gr.LoginButton(value="Sign in!", variant="hf-login", size="sm", scale=2)
            with Tabs():
                for interface, tab_name in zip(interface_list, tab_names, strict=False):
                    with Tab(label=tab_name):
src/distilabel_dataset_generator/apps/base.py

@@ -15,7 +15,7 @@ from src.distilabel_dataset_generator.utils import (
     get_argilla_client,
     get_login_button,
     list_orgs,
-    …
+    swap_visibility,
 )
 
 TEXTCAT_TASK = "text_classification"
@@ -137,7 +137,7 @@ def get_main_ui(
             show_progress=True,
         )
 
-        app.load(fn=…
+        app.load(fn=swap_visibility, outputs=main_ui)
        app.load(get_org_dropdown, outputs=[org_name])
 
     return (
@@ -300,25 +300,6 @@ def get_iterate_on_sample_dataset_ui(
     )
 
 
-def get_pipeline_code_ui(pipeline_code: str) -> gr.Code:
-    gr.Markdown("## Customize and run with distilabel")
-    gr.HTML("<hr>")
-
-    with gr.Accordion(
-        "Run this pipeline using distilabel",
-        open=False,
-    ):
-        gr.Markdown(
-            "You can run this pipeline locally with distilabel. For more information, please refer to the [distilabel documentation](https://distilabel.argilla.io/) or go to the FAQ tab at the top of the page for more information."
-        )
-        pipeline_code = gr.Code(
-            value=pipeline_code,
-            language="python",
-            label="Distilabel Pipeline Code",
-        )
-    return pipeline_code
-
-
 def get_argilla_tab() -> Tuple[Any]:
     with gr.Tab(label="Argilla"):
         if get_argilla_client() is not None:
@@ -492,7 +473,7 @@ def get_success_message_row() -> gr.Markdown:
     return success_message
 
 
-def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
+def show_success_message(org_name, repo_name) -> gr.Markdown:
     client = get_argilla_client()
     argilla_api_url = client.api_url
     return gr.Markdown(
@@ -500,25 +481,27 @@ def show_success_message_hub(org_name, repo_name) -> gr.Markdown:
         <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
             <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
             <p style="margin-top: 0.5em;">
-                …
-            …
-            …
-            …
+                <strong>
+                    <a href="{argilla_api_url}" target="_blank" style="color: #1565c0; text-decoration: none;">
+                        Open your dataset in the Argilla space
+                    </a>
+                </strong>
             </p>
             <p style="margin-top: 0.5em;">
-                Your dataset is now available …
-                <a href="{…
-                    {…
+                The generated dataset is in the right format for fine-tuning with TRL, AutoTrain, or other frameworks. Your dataset is now available at: 
+                <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
+                    https://huggingface.co/datasets/{org_name}/{repo_name}
                 </a>
-                <br>Unfamiliar with Argilla? Here are some docs to help you get started:
-                <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
-                <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
             </p>
         </div>
+        <p style="margin-top: 1em; font-size: 0.9em; color: #333;">
+            Unfamiliar with Argilla? Here are some docs to help you get started:
+            <br>• <a href="https://docs.argilla.io/latest/how_to_guides/annotate/" target="_blank">How to curate data in Argilla</a>
+            <br>• <a href="https://docs.argilla.io/latest/how_to_guides/import_export/" target="_blank">How to export data once you have reviewed the dataset</a>
+        </p>
         """,
         visible=True,
     )
 
-
 def hide_success_message() -> gr.Markdown:
     return gr.Markdown(value="")
| @@ -1,70 +1,106 @@ | |
| 1 | 
             
            import json
         | 
|  | |
|  | |
| 2 |  | 
|  | |
| 3 | 
             
            import gradio as gr
         | 
|  | |
| 4 | 
             
            import pandas as pd
         | 
| 5 | 
            -
            from datasets import  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 6 | 
             
            from gradio_huggingfacehub_search import HuggingfaceHubSearch
         | 
|  | |
| 7 |  | 
| 8 | 
            -
            from src.distilabel_dataset_generator. | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 9 |  | 
| 10 |  | 
| 11 | 
            -
            def get_iframe(hub_repo_id) -> str:
         | 
| 12 | 
             
                if not hub_repo_id:
         | 
| 13 | 
            -
                    raise gr.Error("Hub  | 
|  | |
| 14 | 
             
                url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
         | 
| 15 | 
             
                iframe = f"""
         | 
| 16 | 
             
                <iframe
         | 
| 17 | 
            -
             | 
| 18 | 
            -
             | 
| 19 | 
            -
             | 
| 20 | 
            -
             | 
| 21 | 
            -
            ></iframe>
         | 
| 22 | 
            -
            """
         | 
| 23 | 
             
                return iframe
         | 
| 24 |  | 
| 25 |  | 
| 26 | 
            -
            def get_valid_columns( | 
| 27 | 
            -
                 | 
| 28 | 
            -
                 | 
| 29 | 
            -
             | 
|  | |
|  | |
| 30 | 
             
                    if isinstance(sample_val, str) or (
         | 
| 31 | 
            -
                        isinstance(sample_val, list)
         | 
| 32 | 
            -
                        and all(isinstance(item, dict) for item in sample_val)
         | 
| 33 | 
             
                    ):
         | 
| 34 | 
            -
                         | 
| 35 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 36 |  | 
|  | |
| 37 |  | 
| 38 | 
            -
             | 
| 39 | 
            -
             | 
| 40 | 
            -
                if not  | 
| 41 | 
             
                    raise gr.Error("Hub repo id is required")
         | 
| 42 | 
            -
                 | 
| 43 | 
            -
                 | 
|  | |
| 44 | 
             
                ds = ds_dict[splits[0]]
         | 
| 45 | 
            -
                if  | 
| 46 | 
            -
                    ds = ds.select(range( | 
| 47 | 
            -
                 | 
| 48 | 
            -
                 | 
| 49 | 
            -
                valid_columns = get_valid_columns(df)
         | 
| 50 | 
             
                return (
         | 
| 51 | 
            -
                     | 
| 52 | 
            -
                    gr.Dropdown(choices= | 
| 53 | 
            -
                    gr.Dropdown(choices= | 
| 54 | 
            -
                    gr.Dropdown(choices=valid_columns, label="Response Column"),
         | 
| 55 | 
             
                )
         | 
| 56 |  | 
| 57 |  | 
| 58 | 
             
            def define_evaluation_aspects(task_type: str):
         | 
| 59 | 
            -
                if task_type == " | 
| 60 | 
            -
                    return gr.Dropdown(
         | 
| 61 | 
            -
                        value=["overall-rating"],
         | 
| 62 | 
            -
                        choices=["complexity", "quality"],
         | 
| 63 | 
            -
                        label="Evaluation Aspects",
         | 
| 64 | 
            -
                        multiselect=True,
         | 
| 65 | 
            -
                        interactive=True,
         | 
| 66 | 
            -
                    )
         | 
| 67 | 
            -
                elif task_type == "instruction-response":
         | 
| 68 | 
             
                    return gr.Dropdown(
         | 
| 69 | 
             
                        value=["overall-rating"],
         | 
| 70 | 
             
                        choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
         | 
| @@ -76,226 +112,635 @@ def define_evaluation_aspects(task_type: str): | |
| 76 | 
             
                    return gr.Dropdown(interactive=False, visible=False)
         | 
| 77 |  | 
| 78 |  | 
| 79 | 
            -
            def evaluate_instruction(df: pd.DataFrame, aspects: list[str], instruction_column: str):
         | 
| 80 | 
            -
                pass
         | 
| 81 | 
            -
             | 
| 82 | 
            -
             | 
| 83 | 
             
            def evaluate_instruction_response(
         | 
| 84 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 85 | 
             
            ):
         | 
| 86 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 87 |  | 
| 88 |  | 
| 89 | 
             
            def evaluate_custom(
         | 
| 90 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 91 | 
             
            ):
         | 
| 92 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 93 |  | 
|  | |
|  | |
|  | |
| 94 |  | 
| 95 | 
            -
             | 
| 96 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 97 | 
             
                eval_type: str,
         | 
| 98 | 
            -
                aspects_instruction: list[str],
         | 
| 99 | 
            -
                instruction_column: str,
         | 
| 100 | 
             
                aspects_instruction_response: list[str],
         | 
| 101 | 
            -
                 | 
| 102 | 
            -
                 | 
| 103 | 
            -
                aspects_custom: list[str],
         | 
| 104 | 
             
                prompt_template: str,
         | 
| 105 | 
             
                structured_output: dict,
         | 
|  | |
|  | |
| 106 | 
             
            ):
         | 
| 107 | 
            -
                if eval_type == " | 
| 108 | 
            -
                     | 
| 109 | 
            -
             | 
| 110 | 
            -
             | 
| 111 | 
            -
                         | 
| 112 | 
            -
                         | 
| 113 | 
            -
                         | 
| 114 | 
            -
                         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 115 | 
             
                    )
         | 
| 116 | 
            -
                 | 
| 117 | 
            -
                    df = evaluate_custom(df, aspects_custom, prompt_template, structured_output)
         | 
| 118 | 
            -
                return df
         | 
| 119 |  | 
| 120 |  | 
| 121 | 
            -
            def  | 
| 122 | 
             
                repo_id: str,
         | 
| 123 | 
             
                eval_type: str,
         | 
| 124 | 
            -
                aspects_instruction: list[str],
         | 
| 125 | 
             
                aspects_instruction_response: list[str],
         | 
| 126 | 
            -
                aspects_custom: list[str],
         | 
| 127 | 
            -
                instruction_instruction: str,
         | 
| 128 | 
             
                instruction_instruction_response: str,
         | 
| 129 | 
             
                response_instruction_response: str,
         | 
| 130 | 
             
                prompt_template: str,
         | 
| 131 | 
             
                structured_output: dict,
         | 
| 132 | 
             
            ):
         | 
| 133 | 
            -
                 | 
| 134 | 
            -
                 | 
| 135 | 
            -
                     | 
| 136 | 
            -
                    eval_type,
         | 
| 137 | 
            -
                     | 
| 138 | 
            -
                     | 
| 139 | 
            -
                     | 
| 140 | 
            -
                     | 
| 141 | 
            -
                     | 
| 142 | 
            -
                     | 
| 143 | 
            -
                     | 
| 144 | 
            -
                    structured_output,
         | 
| 145 | 
             
                )
         | 
| 146 | 
            -
                return  | 
| 147 |  | 
| 148 |  | 
| 149 | 
            -
            def  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 150 | 
             
                org_name: str,
         | 
| 151 | 
             
                repo_name: str,
         | 
| 152 | 
             
                private: bool,
         | 
| 153 | 
            -
                 | 
| 154 | 
             
                original_repo_id: str,
         | 
| 155 | 
             
                eval_type: str,
         | 
| 156 | 
            -
                aspects_instruction: list[str],
         | 
| 157 | 
             
                aspects_instruction_response: list[str],
         | 
| 158 | 
            -
                aspects_custom: list[str],
         | 
| 159 | 
            -
                instruction_instruction: str,
         | 
| 160 | 
             
                instruction_instruction_response: str,
         | 
| 161 | 
             
                response_instruction_response: str,
         | 
| 162 | 
             
                prompt_template: str,
         | 
| 163 | 
             
                structured_output: dict,
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                 | 
| 166 | 
            -
             | 
| 167 | 
            -
             | 
| 168 | 
            -
             | 
| 169 | 
            -
                     | 
| 170 | 
            -
                     | 
| 171 | 
            -
                    aspects_instruction_response,
         | 
| 172 | 
            -
                    instruction_instruction_response,
         | 
| 173 | 
            -
                    response_instruction_response,
         | 
| 174 | 
            -
                     | 
| 175 | 
            -
                     | 
| 176 | 
            -
                     | 
| 177 | 
             
                )
         | 
| 178 | 
            -
                 | 
| 179 | 
            -
             | 
| 180 | 
            -
             | 
| 181 | 
            -
             | 
| 182 | 
            -
             | 
| 183 | 
            -
             | 
| 184 | 
            -
             | 
| 185 | 
            -
                         | 
| 186 | 
            -
                             | 
| 187 | 
            -
             | 
| 188 | 
            -
             | 
| 189 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 190 | 
             
                        )
         | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
                         | 
| 194 | 
            -
             | 
| 195 | 
            -
             | 
| 196 | 
            -
             | 
| 197 | 
            -
             | 
| 198 | 
            -
             | 
| 199 | 
            -
                        eval_type = gr.Dropdown(
         | 
| 200 | 
            -
                            label="Evaluation Type",
         | 
| 201 | 
            -
                            choices=["instruction", "instruction-response", "custom-template"],
         | 
| 202 | 
            -
                            visible=False,
         | 
| 203 | 
             
                        )
         | 
| 204 | 
            -
             | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 208 | 
             
                            )
         | 
| 209 | 
            -
             | 
| 210 | 
            -
             | 
| 211 | 
            -
             | 
| 212 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 213 | 
             
                            )
         | 
| 214 | 
            -
             | 
| 215 | 
            -
             | 
| 216 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 217 | 
             
                            )
         | 
| 218 | 
            -
                             | 
| 219 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 220 | 
             
                            )
         | 
| 221 | 
            -
                             | 
| 222 | 
            -
                                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 223 | 
             
                            )
         | 
| 224 | 
            -
             | 
| 225 | 
            -
             | 
| 226 | 
            -
                                 | 
| 227 | 
            -
                                 | 
|  | |
|  | |
| 228 | 
             
                            )
         | 
| 229 | 
            -
             | 
| 230 | 
            -
             | 
| 231 | 
            -
             | 
| 232 | 
            -
             | 
| 233 | 
            -
             | 
| 234 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
| 235 | 
             
                                interactive=True,
         | 
| 236 | 
             
                            )
         | 
| 237 | 
            -
                             | 
| 238 | 
            -
                                label=" | 
| 239 | 
            -
                                value= | 
| 240 | 
            -
                                language="json",
         | 
| 241 | 
             
                                interactive=True,
         | 
|  | |
| 242 | 
             
                            )
         | 
| 243 | 
            -
                             | 
| 244 | 
            -
                                 | 
| 245 | 
            -
                                 | 
| 246 | 
            -
                                 | 
|  | |
| 247 | 
             
                            )
         | 
| 248 | 
            -
             | 
| 249 | 
            -
             | 
| 250 | 
            -
             | 
| 251 | 
            -
             | 
| 252 | 
            -
             | 
| 253 | 
            -
             | 
| 254 | 
            -
             | 
| 255 | 
            -
             | 
| 256 | 
            -
             | 
| 257 | 
            -
             | 
| 258 | 
            -
             | 
| 259 | 
            -
             | 
| 260 | 
            -
             | 
| 261 | 
            -
             | 
| 262 | 
            -
             | 
| 263 | 
            -
             | 
| 264 | 
            -
             | 
| 265 | 
            -
             | 
| 266 | 
            -
             | 
| 267 | 
            -
             | 
| 268 | 
            -
             | 
| 269 | 
            -
             | 
| 270 | 
            -
             | 
| 271 | 
            -
             | 
| 272 | 
            -
             | 
| 273 | 
            -
                            scale=1,
         | 
| 274 | 
            -
                        )
         | 
| 275 | 
            -
                        btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
         | 
| 276 | 
            -
                    with gr.Column(scale=3):
         | 
| 277 | 
            -
                        success_message = gr.Markdown(visible=False)
         | 
| 278 |  | 
| 279 | 
            -
                search_in.submit(get_iframe, inputs=search_in, outputs=search_out)
         | 
| 280 | 
             
                load_btn.click(
         | 
| 281 | 
            -
                    load_dataset_from_hub,
         | 
| 282 | 
             
                    inputs=[search_in],
         | 
| 283 | 
             
                    outputs=[
         | 
| 284 | 
             
                        dataframe,
         | 
| 285 | 
            -
                        instruction_instruction,
         | 
| 286 | 
             
                        instruction_instruction_response,
         | 
| 287 | 
             
                        response_instruction_response,
         | 
| 288 | 
             
                    ],
         | 
| 289 | 
             
                )
         | 
|  | |
| 290 | 
             
                btn_apply_to_sample_dataset.click(
         | 
| 291 | 
            -
                     | 
| 292 | 
             
                    inputs=[
         | 
| 293 | 
             
                        search_in,
         | 
| 294 | 
             
                        eval_type,
         | 
| 295 | 
            -
                        aspects_instruction,
         | 
| 296 | 
             
                        aspects_instruction_response,
         | 
| 297 | 
            -
                        aspects_custom,
         | 
| 298 | 
            -
                        instruction_instruction,
         | 
| 299 | 
             
                        instruction_instruction_response,
         | 
| 300 | 
             
                        response_instruction_response,
         | 
| 301 | 
             
                        prompt_template,
         | 
| @@ -303,24 +748,64 @@ with gr.Blocks() as app: | |
| 303 | 
             
                    ],
         | 
| 304 | 
             
                    outputs=dataframe,
         | 
| 305 | 
             
                )
         | 
|  | |
| 306 | 
             
                btn_push_to_hub.click(
         | 
| 307 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 308 | 
             
                    inputs=[
         | 
| 309 | 
             
                        org_name,
         | 
| 310 | 
             
                        repo_name,
         | 
| 311 | 
             
                        private,
         | 
| 312 | 
            -
                         | 
| 313 | 
             
                        search_in,
         | 
| 314 | 
             
                        eval_type,
         | 
| 315 | 
            -
                        aspects_instruction,
         | 
| 316 | 
             
                        aspects_instruction_response,
         | 
| 317 | 
            -
                        aspects_custom,
         | 
| 318 | 
            -
                        instruction_instruction,
         | 
| 319 | 
             
                        instruction_instruction_response,
         | 
| 320 | 
             
                        response_instruction_response,
         | 
| 321 | 
             
                        prompt_template,
         | 
| 322 | 
             
                        structured_output,
         | 
| 323 | 
             
                    ],
         | 
| 324 | 
            -
                    outputs=success_message,
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 325 | 
             
                )
         | 
|  | |
|  | |
| 326 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | |
| 1 | 
             
            import json
         | 
| 2 | 
            +
            import uuid
         | 
| 3 | 
            +
            from typing import Union
         | 
| 4 |  | 
| 5 | 
            +
            import argilla as rg
         | 
| 6 | 
             
            import gradio as gr
         | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
             
            import pandas as pd
         | 
| 9 | 
            +
            from datasets import (
         | 
| 10 | 
            +
                Dataset,
         | 
| 11 | 
            +
                get_dataset_config_names,
         | 
| 12 | 
            +
                get_dataset_split_names,
         | 
| 13 | 
            +
                load_dataset,
         | 
| 14 | 
            +
            )
         | 
| 15 | 
            +
            from distilabel.distiset import Distiset
         | 
| 16 | 
             
            from gradio_huggingfacehub_search import HuggingfaceHubSearch
         | 
| 17 | 
            +
            from huggingface_hub import HfApi
         | 
| 18 |  | 
| 19 | 
            +
            from src.distilabel_dataset_generator.apps.base import (
         | 
| 20 | 
            +
                hide_success_message,
         | 
| 21 | 
            +
                show_success_message,
         | 
| 22 | 
            +
                validate_argilla_user_workspace_dataset,
         | 
| 23 | 
            +
                validate_push_to_hub,
         | 
| 24 | 
            +
            )
         | 
| 25 | 
            +
            from src.distilabel_dataset_generator.pipelines.base import (
         | 
| 26 | 
            +
                DEFAULT_BATCH_SIZE,
         | 
| 27 | 
            +
            )
         | 
| 28 | 
            +
            from src.distilabel_dataset_generator.pipelines.embeddings import (
         | 
| 29 | 
            +
                get_embeddings,
         | 
| 30 | 
            +
                get_sentence_embedding_dimensions,
         | 
| 31 | 
            +
            )
         | 
| 32 | 
            +
            from src.distilabel_dataset_generator.pipelines.eval import (
         | 
| 33 | 
            +
                generate_pipeline_code,
         | 
| 34 | 
            +
                get_custom_evaluator,
         | 
| 35 | 
            +
                get_ultrafeedback_evaluator,
         | 
| 36 | 
            +
            )
         | 
| 37 | 
            +
            from src.distilabel_dataset_generator.utils import (
         | 
| 38 | 
            +
                column_to_list,
         | 
| 39 | 
            +
                extract_column_names,
         | 
| 40 | 
            +
                get_argilla_client,
         | 
| 41 | 
            +
                get_org_dropdown,
         | 
| 42 | 
            +
                process_columns,
         | 
| 43 | 
            +
                swap_visibility,
         | 
| 44 | 
            +
                pad_or_truncate_list,
         | 
| 45 | 
            +
            )
         | 
| 46 |  | 
| 47 |  | 
| 48 | 
            +
            def get_iframe(hub_repo_id: str) -> str:
         | 
| 49 | 
             
                if not hub_repo_id:
         | 
| 50 | 
            +
                    raise gr.Error("Hub repository ID is required.")
         | 
| 51 | 
            +
             | 
| 52 | 
             
                url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
         | 
| 53 | 
             
                iframe = f"""
         | 
| 54 | 
             
                <iframe
         | 
| 55 | 
            +
                    src="{url}"
         | 
| 56 | 
            +
                    frameborder="0"
         | 
| 57 | 
            +
                    width="100%"
         | 
| 58 | 
            +
                    height="600px"
         | 
| 59 | 
            +
                ></iframe>
         | 
| 60 | 
            +
                """
         | 
| 61 | 
             
                return iframe
         | 
| 62 |  | 
| 63 |  | 
| 64 | 
            +
            def get_valid_columns(dataframe: pd.DataFrame):
         | 
| 65 | 
            +
                instruction_valid_columns = []
         | 
| 66 | 
            +
                response_valid_columns = []
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                for col in dataframe.columns:
         | 
| 69 | 
            +
                    sample_val = dataframe[col].iloc[0]
         | 
| 70 | 
             
                    if isinstance(sample_val, str) or (
         | 
| 71 | 
            +
                        isinstance(sample_val, (list, np.ndarray))
         | 
| 72 | 
            +
                        and all(isinstance(item, dict) and "role" in item for item in sample_val)
         | 
| 73 | 
             
                    ):
         | 
| 74 | 
            +
                        instruction_valid_columns.append(col)
         | 
| 75 | 
            +
                        response_valid_columns.append(col)
         | 
| 76 | 
            +
                    if isinstance(sample_val, (list, np.ndarray)) and all(
         | 
| 77 | 
            +
                        isinstance(item, str) for item in sample_val
         | 
| 78 | 
            +
                    ):
         | 
| 79 | 
            +
                        response_valid_columns.append(col)
         | 
| 80 |  | 
| 81 | 
            +
                return instruction_valid_columns, response_valid_columns
         | 
| 82 |  | 
| 83 | 
            +
             | 
| 84 | 
            +
            def load_dataset_from_hub(repo_id: str, num_rows: int = 10):
         | 
| 85 | 
            +
                if not repo_id:
         | 
| 86 | 
             
                    raise gr.Error("Hub repo id is required")
         | 
| 87 | 
            +
                subsets = get_dataset_config_names(repo_id)
         | 
| 88 | 
            +
                ds_dict = load_dataset(repo_id, subsets[0])
         | 
| 89 | 
            +
                splits = get_dataset_split_names(repo_id, subsets[0])
         | 
| 90 | 
             
                ds = ds_dict[splits[0]]
         | 
| 91 | 
            +
                if num_rows:
         | 
| 92 | 
            +
                    ds = ds.select(range(num_rows))
         | 
| 93 | 
            +
                dataframe = ds.to_pandas()
         | 
| 94 | 
            +
                instruction_valid_columns, response_valid_columns = get_valid_columns(dataframe)
         | 
|  | |
| 95 | 
             
                return (
         | 
| 96 | 
            +
                    dataframe,
         | 
| 97 | 
            +
                    gr.Dropdown(choices=instruction_valid_columns, label="Instruction column"),
         | 
| 98 | 
            +
                    gr.Dropdown(choices=response_valid_columns, label="Response column"),
         | 
|  | |
| 99 | 
             
                )
         | 
| 100 |  | 
| 101 |  | 
| 102 | 
             
            def define_evaluation_aspects(task_type: str):
         | 
| 103 | 
            +
                if task_type == "ultrafeedback":
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 104 | 
             
                    return gr.Dropdown(
         | 
| 105 | 
             
                        value=["overall-rating"],
         | 
| 106 | 
             
                        choices=["helpfulness", "truthfulness", "overall-rating", "honesty"],
         | 
|  | |
| 112 | 
             
                    return gr.Dropdown(interactive=False, visible=False)
         | 
| 113 |  | 
| 114 |  | 
|  | |
|  | |
|  | |
|  | |
| 115 | 
             
            def evaluate_instruction_response(
         | 
| 116 | 
            +
                dataframe: pd.DataFrame,
         | 
| 117 | 
            +
                aspects: list[str],
         | 
| 118 | 
            +
                instruction_column: str,
         | 
| 119 | 
            +
                response_columns: str,
         | 
| 120 | 
            +
                num_rows: int = 10,
         | 
| 121 | 
            +
                is_sample: bool = False,
         | 
| 122 | 
            +
                progress=gr.Progress(),
         | 
| 123 | 
             
            ):
         | 
| 124 | 
            +
                progress(0.0, desc="Evaluating instructions and responses")
         | 
| 125 | 
            +
                data = process_columns(dataframe, instruction_column, response_columns)
         | 
| 126 | 
            +
                num_generations = len(data[0]["generations"])
         | 
| 127 | 
            +
                evaluated_results = []
         | 
| 128 | 
            +
                for entry in data:
         | 
| 129 | 
            +
                    result_row = {
         | 
| 130 | 
            +
                        "instruction": entry["instruction"],
         | 
| 131 | 
            +
                        "generations": entry["generations"],
         | 
| 132 | 
            +
                    }
         | 
| 133 | 
            +
                    for aspect in aspects:
         | 
| 134 | 
            +
                        result_row[f"ratings_{aspect}"] = None
         | 
| 135 | 
            +
                        result_row[f"rationale_for_ratings_{aspect}"] = None
         | 
| 136 | 
            +
                        if aspect in ["truthfulness", "helpfulness"]:
         | 
| 137 | 
            +
                            result_row[f"type_{aspect}"] = None
         | 
| 138 | 
            +
                            result_row[f"rationale_for_type_{aspect}"] = None
         | 
| 139 | 
            +
                    result_row["model_name"] = None
         | 
| 140 | 
            +
                    evaluated_results.append(result_row)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                batch_size = DEFAULT_BATCH_SIZE
         | 
| 143 | 
            +
                total_steps: int = len(aspects) * num_rows
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                # evaluate instructions and responses
         | 
| 146 | 
            +
                for aspect in aspects:
         | 
| 147 | 
            +
                    ultrafeedback_evaluator = get_ultrafeedback_evaluator(aspect, is_sample)
         | 
| 148 | 
            +
                    n_processed = 0
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    while n_processed < num_rows:
         | 
| 151 | 
            +
                        progress(
         | 
| 152 | 
            +
                            (len(aspects) * n_processed) / total_steps,
         | 
| 153 | 
            +
                            total=total_steps,
         | 
| 154 | 
            +
                            desc=f"Evaluating aspect: {aspect}",
         | 
| 155 | 
            +
                        )
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                        remaining_rows = num_rows - n_processed
         | 
| 158 | 
            +
                        batch_size = min(batch_size, remaining_rows)
         | 
| 159 | 
            +
                        inputs = data[n_processed : n_processed + batch_size]
         | 
| 160 | 
            +
                        batch_results = list(ultrafeedback_evaluator.process(inputs=inputs))
         | 
| 161 | 
            +
                        for j, result in enumerate(batch_results[0]):
         | 
| 162 | 
            +
                            idx = n_processed + j
         | 
| 163 | 
            +
                            evaluated_results[idx][f"ratings_{aspect}"] = pad_or_truncate_list(
         | 
| 164 | 
            +
                                result.get("ratings"), num_generations
         | 
| 165 | 
            +
                            )
         | 
| 166 | 
            +
                            evaluated_results[idx]["model_name"] = result.get("model_name")
         | 
| 167 | 
            +
                            if aspect in ["truthfulness", "helpfulness"]:
         | 
| 168 | 
            +
                                evaluated_results[idx][f"type_{aspect}"] = pad_or_truncate_list(
         | 
| 169 | 
            +
                                    result.get("types"), num_generations
         | 
| 170 | 
            +
                                )
         | 
| 171 | 
            +
                                evaluated_results[idx][f"rationale_for_type_{aspect}"] = (
         | 
| 172 | 
            +
                                    pad_or_truncate_list(result.get("rationales"), num_generations)
         | 
| 173 | 
            +
                                )
         | 
| 174 | 
            +
                                evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
         | 
| 175 | 
            +
                                    pad_or_truncate_list(
         | 
| 176 | 
            +
                                        result.get("rationales-for-ratings"), num_generations
         | 
| 177 | 
            +
                                    )
         | 
| 178 | 
            +
                                )
         | 
| 179 | 
            +
                            else:
         | 
| 180 | 
            +
                                evaluated_results[idx][f"rationale_for_ratings_{aspect}"] = (
         | 
| 181 | 
            +
                                    pad_or_truncate_list(result.get("rationales"), num_generations)
         | 
| 182 | 
            +
                                )
         | 
| 183 | 
            +
                        n_processed += batch_size
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                # create final dataset
         | 
| 186 | 
            +
                dataframe = pd.DataFrame(evaluated_results)
         | 
| 187 | 
            +
                progress(1.0, desc="Dataset evaluation completed")
         | 
| 188 | 
            +
                return dataframe
         | 
| 189 |  | 
| 190 |  | 
 def evaluate_custom(
+    dataframe: pd.DataFrame,
+    prompt_template: str,
+    structured_output: dict,
+    num_rows: int = 10,
+    is_sample: bool = False,
+    progress=gr.Progress(),
 ):
+    progress(0.0, desc="Evaluating dataset")
+    columns = extract_column_names(prompt_template)
+    input_columns = {column: column_to_list(dataframe, column) for column in columns}
+
+    custom_evaluator = get_custom_evaluator(
+        prompt_template, structured_output, columns, is_sample
+    )
+    batch_size = DEFAULT_BATCH_SIZE
+
+    # evaluate the data
+    n_processed = 0
+    evaluation_results = []
+    while n_processed < num_rows:
+        progress(
+            n_processed / num_rows,
+            desc="Evaluating dataset",
+        )
+        remaining_rows = num_rows - n_processed
+        batch_size = min(batch_size, remaining_rows)
+
+        inputs = []
+        for idx in range(n_processed, n_processed + batch_size):
+            input = {column: input_columns[column][idx] for column in input_columns}
+            inputs.append(input)

+        batch = list(custom_evaluator.process(inputs=inputs))
+        evaluation_results.extend(batch[0])
+        n_processed += batch_size

+    # create final dataset
+    distiset_results = []
+    for result in evaluation_results:
+        record = {key: result[key] for key in result if key != "distilabel_metadata"}
+        distiset_results.append(record)
+
+    dataframe = pd.DataFrame(distiset_results)
+    progress(1.0, desc="Dataset evaluation completed")
+    return dataframe
+
+
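evaluate_custom leans on two small helpers from utils that are not shown in this diff; a rough sketch of what they are assumed to do for templates such as "Evaluate {{column_1}} based on {{column_2}}." (this is an assumption about their behaviour, not the project's actual implementation):

import re
import pandas as pd

def extract_column_names(prompt_template: str) -> list:
    # pull the {{placeholder}} names out of the Jinja-style template
    return re.findall(r"{{\s*(\w+)\s*}}", prompt_template)

def column_to_list(dataframe: pd.DataFrame, column: str) -> list:
    # one plain Python list per referenced dataframe column
    return dataframe[column].tolist()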
            +
            def _evaluate_dataset(
         | 
| 240 | 
            +
                dataframe: pd.DataFrame,
         | 
| 241 | 
             
                eval_type: str,
         | 
|  | |
|  | |
| 242 | 
             
                aspects_instruction_response: list[str],
         | 
| 243 | 
            +
                instruction_instruction_response: str,
         | 
| 244 | 
            +
                response_instruction_response: str,
         | 
|  | |
| 245 | 
             
                prompt_template: str,
         | 
| 246 | 
             
                structured_output: dict,
         | 
| 247 | 
            +
                num_rows: int = 10,
         | 
| 248 | 
            +
                is_sample: bool = False,
         | 
| 249 | 
             
            ):
         | 
| 250 | 
            +
                if eval_type == "ultrafeedback":
         | 
| 251 | 
            +
                    dataframe = evaluate_instruction_response(
         | 
| 252 | 
            +
                        dataframe=dataframe,
         | 
| 253 | 
            +
                        aspects=aspects_instruction_response,
         | 
| 254 | 
            +
                        instruction_column=instruction_instruction_response,
         | 
| 255 | 
            +
                        response_columns=response_instruction_response,
         | 
| 256 | 
            +
                        num_rows=num_rows,
         | 
| 257 | 
            +
                        is_sample=is_sample,
         | 
| 258 | 
            +
                    )
         | 
| 259 | 
            +
                else:
         | 
| 260 | 
            +
                    dataframe = evaluate_custom(
         | 
| 261 | 
            +
                        dataframe=dataframe,
         | 
| 262 | 
            +
                        prompt_template=prompt_template,
         | 
| 263 | 
            +
                        structured_output=structured_output,
         | 
| 264 | 
            +
                        num_rows=num_rows,
         | 
| 265 | 
            +
                        is_sample=is_sample,
         | 
| 266 | 
             
                    )
         | 
| 267 | 
            +
                return dataframe
         | 
|  | |
|  | |
| 268 |  | 
| 269 |  | 
| 270 | 
            +
            def evaluate_sample_dataset(
         | 
| 271 | 
             
                repo_id: str,
         | 
| 272 | 
             
                eval_type: str,
         | 
|  | |
| 273 | 
             
                aspects_instruction_response: list[str],
         | 
|  | |
|  | |
| 274 | 
             
                instruction_instruction_response: str,
         | 
| 275 | 
             
                response_instruction_response: str,
         | 
| 276 | 
             
                prompt_template: str,
         | 
| 277 | 
             
                structured_output: dict,
         | 
| 278 | 
             
            ):
         | 
| 279 | 
            +
                dataframe, _, _ = load_dataset_from_hub(repo_id, num_rows=10)
         | 
| 280 | 
            +
                dataframe = _evaluate_dataset(
         | 
| 281 | 
            +
                    dataframe=dataframe,
         | 
| 282 | 
            +
                    eval_type=eval_type,
         | 
| 283 | 
            +
                    aspects_instruction_response=aspects_instruction_response,
         | 
| 284 | 
            +
                    instruction_instruction_response=instruction_instruction_response,
         | 
| 285 | 
            +
                    response_instruction_response=response_instruction_response,
         | 
| 286 | 
            +
                    prompt_template=prompt_template,
         | 
| 287 | 
            +
                    structured_output=structured_output,
         | 
| 288 | 
            +
                    num_rows=10,
         | 
| 289 | 
            +
                    is_sample=True,
         | 
|  | |
| 290 | 
             
                )
         | 
| 291 | 
            +
                return dataframe
         | 
| 292 |  | 
| 293 |  | 
+def push_dataset_to_hub(
+    dataframe: pd.DataFrame, org_name: str, repo_name: str, oauth_token, private
+):
+    repo_id = validate_push_to_hub(org_name, repo_name)
+    distiset = Distiset({"default": Dataset.from_pandas(dataframe)})
+    distiset.push_to_hub(
+        repo_id=repo_id,
+        private=private,
+        include_script=False,
+        token=oauth_token.token,
+        create_pr=False,
+    )
+
+
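Once pushed, the Distiset ends up on the Hub as a regular datasets repository with a "default" configuration, so it can be pulled back with the datasets library; the repo id below is a placeholder, not one used by this app:

from datasets import load_dataset

# assuming the default "train" split created on push
evaluated = load_dataset("my-org/my-distiset", "default", split="train")
print(evaluated.column_names)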
+def push_dataset(
     org_name: str,
     repo_name: str,
     private: bool,
+    num_rows: int,
     original_repo_id: str,
     eval_type: str,
     aspects_instruction_response: list[str],
     instruction_instruction_response: str,
     response_instruction_response: str,
     prompt_template: str,
     structured_output: dict,
+    oauth_token: Union[gr.OAuthToken, None] = None,
+    progress=gr.Progress(),
+) -> pd.DataFrame:
+    dataframe, _, _ = load_dataset_from_hub(original_repo_id, num_rows=num_rows)
+    dataframe = _evaluate_dataset(
+        dataframe=dataframe,
+        eval_type=eval_type,
+        aspects_instruction_response=aspects_instruction_response,
+        instruction_instruction_response=instruction_instruction_response,
+        response_instruction_response=response_instruction_response,
+        prompt_template=prompt_template,
+        structured_output=structured_output,
+        num_rows=num_rows,
     )
+    push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
+    try:
+        progress(0.1, desc="Setting up user and workspace")
+        client = get_argilla_client()
+        hf_user = HfApi().whoami(token=oauth_token.token)["name"]
+        if eval_type == "ultrafeedback":
+            num_generations = len((dataframe["generations"][0]))
+            fields = [
+                rg.ChatField(
+                    name=f"chat_{i}",
+                    title=f"Chat {i+1}",
+                    description=f"User and assistant conversation for generation {i+1}",
+                )
+                for i in range(num_generations)
+            ]
+            questions = []
+            for i in range(num_generations):
+                for aspect in aspects_instruction_response:
+                    questions.append(
+                        rg.RatingQuestion(
+                            name=f"ratings_{aspect}_{i}",
+                            values=list(range(11)),
+                            title=f"Ratings for {aspect} for response {i+1}",
+                            required=True,
+                        )
+                    )
+                    questions.append(
+                        rg.TextQuestion(
+                            name=f"rationale_for_ratings_{aspect}_{i}",
+                            title=f"Rationale for ratings for {aspect} for response {i+1}",
+                            required=False,
+                            use_markdown=True,
+                        )
+                    )
+                    if aspect in ["truthfulness", "helpfulness"]:
+                        questions.append(
+                            rg.RatingQuestion(
+                                name=f"type_{aspect}_{i}",
+                                values=list(range(1, 6)),
+                                title=f"The type of the response {i+1} for {aspect}",
+                                required=True,
+                            )
+                        )
+                        questions.append(
+                            rg.TextQuestion(
+                                name=f"rationale_for_type_{aspect}_{i}",
+                                title=f"Rationale for type of the response {i+1} for {aspect}",
+                                required=False,
+                                use_markdown=True,
+                            )
+                        )
+            metadata = [
+                rg.IntegerMetadataProperty(
+                    name="instruction_length", title="Instruction length"
+                ),
+            ]
+            for i in range(num_generations):
+                metadata.append(
+                    rg.IntegerMetadataProperty(
+                        name=f"response_{i}_length", title=f"Response {i+1} length"
+                    )
+                )
+            vectors = [
+                rg.VectorField(
+                    name="instruction_embeddings",
+                    dimensions=get_sentence_embedding_dimensions(),
+                )
+            ]
+            settings = rg.Settings(
+                fields=fields,
+                questions=questions,
+                metadata=metadata,
+                vectors=vectors,
+                guidelines="Please review the conversation and provide an evaluation.",
             )
+
+            dataframe["instruction_length"] = dataframe["instruction"].apply(len)
+            for i in range(num_generations):
+                dataframe[f"response_{i}_length"] = dataframe["generations"].apply(
+                    lambda gens: len(gens[i]) if i < len(gens) else 0
+                )
+            dataframe["instruction_embeddings"] = get_embeddings(
+                dataframe["instruction"].to_list()
             )
+
+            progress(0.5, desc="Creating dataset")
+            rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
+            if rg_dataset is None:
+                rg_dataset = rg.Dataset(
+                    name=repo_name,
+                    workspace=hf_user,
+                    settings=settings,
+                    client=client,
+                )
+                rg_dataset = rg_dataset.create()
+
+            progress(0.7, desc="Pushing dataset to Argilla")
+            hf_dataset = Dataset.from_pandas(dataframe)
+            records = []
+            for sample in hf_dataset:
+                fields = {}
+                metadata = {"instruction_length": sample.get("instruction_length", 0)}
+                vectors = {
+                    "instruction_embeddings": sample.get("instruction_embeddings", [])
+                }
+                suggestions = []
+                generations = sample.get("generations", [])
+                for i in range(num_generations):
+                    fields[f"chat_{i}"] = [
+                        {"role": "user", "content": sample.get("instruction", "")},
+                        {"role": "assistant", "content": generations[i]},
+                    ]
+                    metadata[f"response_{i}_length"] = sample.get(
+                        f"response_{i}_length", 0
+                    )
+
+                    for aspect in aspects_instruction_response:
+                        ratings = sample.get(f"ratings_{aspect}", [])
+                        rationales = sample.get(f"rationale_for_ratings_{aspect}", [])
+
+                        rating_value = (
+                            ratings[i]
+                            if ratings and isinstance(ratings[i], int)
+                            else None
+                        )
+                        rationale_value = (
+                            rationales[i]
+                            if rationales and isinstance(rationales[i], str)
+                            else None
+                        )
+
+                        if rating_value is not None:
+                            suggestions.append(
+                                rg.Suggestion(
+                                    question_name=f"ratings_{aspect}_{i}",
+                                    value=rating_value,
+                                )
+                            )
+                        if rationale_value is not None:
+                            suggestions.append(
+                                rg.Suggestion(
+                                    question_name=f"rationale_for_ratings_{aspect}_{i}",
+                                    value=rationale_value,
+                                )
+                            )
+
+                        if aspect in ["truthfulness", "helpfulness"]:
+                            types = sample.get(f"type_{aspect}", [])
+                            rationale_types = sample.get(
+                                f"rationale_for_type_{aspect}", []
+                            )
+
+                            type_value = (
+                                types[i]
+                                if types and isinstance(types[i], int)
+                                else None
+                            )
+                            rationale_type_value = (
+                                rationale_types[i]
+                                if rationale_types
+                                and isinstance(rationale_types[i], str)
+                                else None
+                            )
+                            if type_value is not None:
+                                suggestions.append(
+                                    rg.Suggestion(
+                                        question_name=f"type_{aspect}_{i}",
+                                        value=type_value,
+                                    )
+                                )
+                            if rationale_type_value is not None:
+                                suggestions.append(
+                                    rg.Suggestion(
+                                        question_name=f"rationale_for_type_{aspect}_{i}",
+                                        value=rationale_type_value,
+                                    )
+                                )
+                records.append(
+                    rg.Record(
+                        fields=fields,
+                        metadata=metadata,
+                        vectors=vectors,
+                        suggestions=suggestions,
+                    )
                 )
+            rg_dataset.records.log(records=records)
+            progress(1.0, desc="Dataset pushed to Argilla")
+        else:
+            columns = extract_column_names(prompt_template)
+            settings = rg.Settings(
+                fields=[
+                    rg.TextField(
+                        name=column,
+                        title=column.capitalize(),
+                        description="The column content",
+                    )
+                    for column in columns
+                ],
+                questions=[
+                    rg.TextQuestion(
+                        name="evaluation",
+                        title="Evaluation",
+                        description="The generated evaluation",
+                        use_markdown=True,
+                    ),
+                ],
+                metadata=[
+                    rg.IntegerMetadataProperty(
+                        name=f"{column}_length", title=f"{column.capitalize()} length"
+                    )
+                    for column in columns
+                ],
+                vectors=[
+                    rg.VectorField(
+                        name=f"{column}_embeddings",
+                        dimensions=get_sentence_embedding_dimensions(),
+                    )
+                    for column in columns
+                ],
+                guidelines="Please review, correct and provide an accurate evaluation.",
+            )
+            for column in columns:
+                dataframe[f"{column}_length"] = dataframe[column].apply(len)
+                dataframe[f"{column}_embeddings"] = get_embeddings(dataframe[column])
+
+            progress(0.5, desc="Creating dataset")
+            rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
+            if rg_dataset is None:
+                rg_dataset = rg.Dataset(
+                    name=repo_name,
+                    workspace=hf_user,
+                    settings=settings,
+                    client=client,
                 )
+                rg_dataset = rg_dataset.create()
+            progress(0.7, desc="Pushing dataset to Argilla")
+            hf_dataset = Dataset.from_pandas(dataframe)
+            rg_dataset.records.log(
+                records=hf_dataset, mapping={"generation": "evaluation"}
+            )
+            progress(1.0, desc="Dataset pushed to Argilla")
+    except Exception as e:
+        raise gr.Error(f"Error pushing dataset to Argilla: {e}")
+    return ""
+
+
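The Argilla block above reduces to the same three steps in both branches: build rg.Settings, create the rg.Dataset if it does not exist, then log rg.Record objects with suggestions. A stripped-down sketch of that shape with made-up names (demo-eval, a single field and question, placeholder URL and key), not the app's exact configuration:

import argilla as rg

client = rg.Argilla(api_url="https://my-argilla.example", api_key="my-api-key")
settings = rg.Settings(
    fields=[rg.TextField(name="instruction")],
    questions=[rg.RatingQuestion(name="quality", values=list(range(11)))],
)
dataset = rg.Dataset(name="demo-eval", settings=settings, client=client).create()
dataset.records.log(
    [
        rg.Record(
            fields={"instruction": "Explain gravity to a child."},
            suggestions=[rg.Suggestion(question_name="quality", value=8)],
        )
    ]
)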
+def show_pipeline_code_visibility():
+    return {pipeline_code_ui: gr.Accordion(visible=True)}
+
+def hide_pipeline_code_visibility():
+    return {pipeline_code_ui: gr.Accordion(visible=False)}
+
+
+######################
+# Gradio UI
+######################
+
+
+with gr.Blocks() as app:
+    with gr.Column() as main_ui:
+        gr.Markdown("## 1. Select your input dataset")
+        with gr.Row(equal_height=False):
+            with gr.Column(scale=1):
+                search_in = HuggingfaceHubSearch(
+                    label="Search",
+                    placeholder="Search for a dataset",
+                    search_type="dataset",
+                    sumbit_on_select=True,
                 )
+                load_btn = gr.Button("Load dataset", variant="primary")
+            with gr.Column(scale=3):
+                search_out = gr.HTML(label="Dataset preview")
+
+        gr.HTML(value="<hr>")
+        gr.Markdown(value="## 2. Configure your task")
+        with gr.Row(equal_height=False):
+            with gr.Column(scale=1):
+                eval_type = gr.Dropdown(
+                    label="Evaluation type",
+                    choices=["ultrafeedback", "custom"],
+                    value="ultrafeedback",
+                    multiselect=False,
+                    visible=False,
                 )
+                with gr.Tab("ultrafeedback") as tab_instruction_response:
+                    aspects_instruction_response = define_evaluation_aspects(
+                        "ultrafeedback"
+                    )
+                    instruction_instruction_response = gr.Dropdown(
+                        label="Instruction Column",
+                        interactive=True,
+                        multiselect=False,
+                        allow_custom_value=False,
+                    )
+                    response_instruction_response = gr.Dropdown(
+                        label="Response Column",
+                        interactive=True,
+                        multiselect=True,
+                        allow_custom_value=False,
+                    )
+                    tab_instruction_response.select(
+                        fn=lambda: "ultrafeedback",
+                        inputs=[],
+                        outputs=[eval_type],
+                    )
+                with gr.Tab("custom") as tab_custom:
+                    aspects_custom = define_evaluation_aspects("custom")
+                    prompt_template = gr.Code(
+                        label="Prompt template",
+                        value="Evaluate {{column_1}} based on {{column_2}}.",
+                        language="markdown",
+                        interactive=True,
+                    )
+                    structured_output = gr.Code(
+                        label="Structured output",
+                        value=json.dumps(
+                            {
+                                "type": "object",
+                                "properties": {
+                                    "quality": {"type": "integer"},
+                                    "clarity": {"type": "integer"},
+                                    "relevance": {"type": "integer"},
+                                },
+                            },
+                            indent=4,
+                        ),
+                        language="json",
+                        interactive=True,
+                    )
+                    tab_custom.select(
+                        fn=lambda: "custom",
+                        inputs=[],
+                        outputs=[eval_type],
+                    )
+                btn_apply_to_sample_dataset = gr.Button(
+                    "Refresh dataset", variant="secondary", size="sm"
                 )
+            with gr.Column(scale=3):
+                dataframe = gr.Dataframe(
+                    headers=["prompt", "completion", "evaluation"],
+                    wrap=False,
+                    height=500,
+                    interactive=False,
                 )
+
+        gr.HTML(value="<hr>")
+        gr.Markdown(value="## 3. Evaluate your dataset")
+        with gr.Row(equal_height=False):
+            with gr.Column(scale=2):
+                org_name = get_org_dropdown()
+                repo_name = gr.Textbox(
+                    label="Repo name",
+                    placeholder="dataset_name",
+                    value=f"my-distiset-{str(uuid.uuid4())[:8]}",
                     interactive=True,
                 )
+                num_rows = gr.Number(
+                    label="Number of rows",
+                    value=10,
                     interactive=True,
+                    scale=1,
                 )
+                private = gr.Checkbox(
+                    label="Private dataset",
+                    value=False,
+                    interactive=True,
+                    scale=1,
                 )
+                btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
+            with gr.Column(scale=3):
+                success_message = gr.Markdown(visible=True)
+                with gr.Accordion(
+                    "Do you want to go further? Customize and run with Distilabel",
+                    open=False,
+                    visible=False,
+                ) as pipeline_code_ui:
+                    code = generate_pipeline_code(
+                            repo_id=search_in.value,
+                            aspects=aspects_instruction_response.value,
+                            instruction_column=instruction_instruction_response,
+                            response_columns=response_instruction_response,
+                            prompt_template=prompt_template.value,
+                            structured_output=structured_output.value,
+                            num_rows=num_rows.value,
+                            eval_type=eval_type.value,
+                        )
+                    pipeline_code = gr.Code(
+                        value=code,
+                        language="python",
+                        label="Distilabel Pipeline Code",
+                    )
+
+    search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out)

     load_btn.click(
+        fn=load_dataset_from_hub,
         inputs=[search_in],
         outputs=[
             dataframe,
             instruction_instruction_response,
             response_instruction_response,
         ],
     )
+
     btn_apply_to_sample_dataset.click(
+        fn=evaluate_sample_dataset,
         inputs=[
             search_in,
             eval_type,
             aspects_instruction_response,
             instruction_instruction_response,
             response_instruction_response,
             prompt_template,
             structured_output,
         ],
         outputs=dataframe,
     )
+
     btn_push_to_hub.click(
+        fn=validate_argilla_user_workspace_dataset,
+        inputs=[repo_name],
+        outputs=[success_message],
+        show_progress=True,
+    ).then(
+        fn=validate_push_to_hub,
+        inputs=[org_name, repo_name],
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=hide_success_message,
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=hide_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
+    ).success(
+        fn=push_dataset,
         inputs=[
             org_name,
             repo_name,
             private,
+            num_rows,
             search_in,
             eval_type,
             aspects_instruction_response,
             instruction_instruction_response,
             response_instruction_response,
             prompt_template,
             structured_output,
         ],
+        outputs=[success_message],
+        show_progress=True,
+    ).success(
+        fn=show_success_message,
+        inputs=[org_name, repo_name],
+        outputs=[success_message],
+    ).success(
+        fn=generate_pipeline_code,
+        inputs=[
+            search_in,
+            aspects_instruction_response,
+            instruction_instruction_response,
+            response_instruction_response,
+            prompt_template,
+            structured_output,
+            num_rows,
+            eval_type,
+        ],
+        outputs=[pipeline_code],
+    ).success(
+        fn=show_pipeline_code_visibility,
+        inputs=[],
+        outputs=[pipeline_code_ui],
     )
+
+    app.load(fn=swap_visibility, outputs=main_ui)
     app.load(fn=get_org_dropdown, outputs=[org_name])
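The push flow above gates the actual push behind the validation steps by chaining click/then/success events; each success callback only fires if the previous step did not raise. A self-contained toy example of the same pattern (component and function names invented, not part of the app):

import gradio as gr

def validate(text):
    if not text:
        raise gr.Error("Please type something first")
    return "Validated."

def do_work(text):
    return f"Processed: {text}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    status = gr.Markdown()
    result = gr.Markdown()
    btn = gr.Button("Run")
    # do_work only runs if validate completed without raising
    btn.click(fn=validate, inputs=[box], outputs=[status]).success(
        fn=do_work, inputs=[box], outputs=[result]
    )

demo.launch()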
src/distilabel_dataset_generator/apps/sft.py

@@ -10,10 +10,8 @@ from distilabel.distiset import Distiset
 from huggingface_hub import HfApi

 from src.distilabel_dataset_generator.apps.base import (
-    get_argilla_client,
-    get_pipeline_code_ui,
     hide_success_message,
-
     validate_argilla_user_workspace_dataset,
     validate_push_to_hub,
 )
| 26 | 
             
            )
         | 
| 27 | 
             
            from src.distilabel_dataset_generator.pipelines.sft import (
         | 
| 28 | 
             
                DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 29 | 
            -
                PROMPT_CREATION_PROMPT,
         | 
| 30 | 
             
                generate_pipeline_code,
         | 
| 31 | 
             
                get_magpie_generator,
         | 
| 32 | 
             
                get_prompt_generator,
         | 
| @@ -36,7 +33,7 @@ from src.distilabel_dataset_generator.utils import ( | |
| 36 | 
             
                _LOGGED_OUT_CSS,
         | 
| 37 | 
             
                get_argilla_client,
         | 
| 38 | 
             
                get_org_dropdown,
         | 
| 39 | 
            -
                 | 
| 40 | 
             
            )
         | 
| 41 |  | 
| 42 |  | 
| @@ -55,35 +52,33 @@ def convert_dataframe_messages(dataframe: pd.DataFrame) -> pd.DataFrame: | |
| 55 | 
             
                return dataframe
         | 
| 56 |  | 
| 57 |  | 
| 58 | 
            -
            def generate_system_prompt(dataset_description, progress=gr.Progress()):
         | 
| 59 | 
             
                progress(0.0, desc="Generating system prompt")
         | 
| 60 | 
            -
             | 
| 61 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
| 62 | 
            -
                generate_description = get_prompt_generator()
         | 
| 63 | 
             
                progress(0.7, desc="Generating system prompt")
         | 
| 64 | 
             
                result = next(
         | 
| 65 | 
             
                    generate_description.process(
         | 
| 66 | 
             
                        [
         | 
| 67 | 
             
                            {
         | 
| 68 | 
            -
                                "system_prompt": PROMPT_CREATION_PROMPT,
         | 
| 69 | 
             
                                "instruction": dataset_description,
         | 
| 70 | 
             
                            }
         | 
| 71 | 
             
                        ]
         | 
| 72 | 
             
                    )
         | 
| 73 | 
             
                )[0]["generation"]
         | 
| 74 | 
             
                progress(1.0, desc="System prompt generated")
         | 
| 75 | 
            -
                return result | 
| 76 |  | 
| 77 |  | 
| 78 | 
            -
            def generate_sample_dataset(system_prompt, progress=gr.Progress()):
         | 
| 79 | 
            -
                 | 
| 80 | 
             
                    system_prompt=system_prompt,
         | 
| 81 | 
            -
                    num_turns= | 
| 82 | 
             
                    num_rows=10,
         | 
| 83 | 
             
                    progress=progress,
         | 
| 84 | 
             
                    is_sample=True,
         | 
| 85 | 
             
                )
         | 
| 86 | 
            -
                return  | 
| 87 |  | 
| 88 |  | 
| 89 | 
             
            def generate_dataset(
         | 
| @@ -94,10 +89,8 @@ def generate_dataset( | |
| 94 | 
             
                progress=gr.Progress(),
         | 
| 95 | 
             
            ) -> pd.DataFrame:
         | 
| 96 | 
             
                progress(0.0, desc="(1/2) Generating instructions")
         | 
| 97 | 
            -
                magpie_generator = get_magpie_generator(
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                )
         | 
| 100 | 
            -
                response_generator = get_response_generator(num_turns, system_prompt, is_sample)
         | 
| 101 | 
             
                total_steps: int = num_rows * 2
         | 
| 102 | 
             
                batch_size = DEFAULT_BATCH_SIZE
         | 
| 103 |  | 
| @@ -209,12 +202,12 @@ def push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private): | |
| 209 | 
             
                return original_dataframe
         | 
| 210 |  | 
| 211 |  | 
| 212 | 
            -
            def  | 
| 213 | 
             
                org_name: str,
         | 
| 214 | 
             
                repo_name: str,
         | 
| 215 | 
             
                system_prompt: str,
         | 
| 216 | 
             
                num_turns: int = 1,
         | 
| 217 | 
            -
                 | 
| 218 | 
             
                private: bool = False,
         | 
| 219 | 
             
                oauth_token: Union[gr.OAuthToken, None] = None,
         | 
| 220 | 
             
                progress=gr.Progress(),
         | 
| @@ -222,7 +215,7 @@ def push_dataset_to_argilla( | |
| 222 | 
             
                dataframe = generate_dataset(
         | 
| 223 | 
             
                    system_prompt=system_prompt,
         | 
| 224 | 
             
                    num_turns=num_turns,
         | 
| 225 | 
            -
                    num_rows= | 
| 226 | 
             
                )
         | 
| 227 | 
             
                push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
         | 
| 228 | 
             
                try:
         | 
| @@ -344,29 +337,54 @@ def push_dataset_to_argilla( | |
| 344 | 
             
                return ""
         | 
| 345 |  | 
| 346 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 347 | 
             
            with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
         | 
| 348 | 
             
                with gr.Column() as main_ui:
         | 
| 349 | 
             
                    gr.Markdown(value="## 1. Describe the dataset you want")
         | 
| 350 | 
             
                    with gr.Row():
         | 
| 351 | 
            -
                        with gr.Column(scale= | 
| 352 | 
             
                            dataset_description = gr.Textbox(
         | 
| 353 | 
             
                                label="Dataset description",
         | 
| 354 | 
             
                                placeholder="Give a precise description of your desired dataset.",
         | 
| 355 | 
             
                            )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 356 | 
             
                            examples = gr.Examples(
         | 
| 357 | 
             
                                examples=DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 358 | 
             
                                inputs=[dataset_description],
         | 
| 359 | 
             
                                cache_examples=False,
         | 
| 360 | 
            -
                                label=" | 
| 361 | 
             
                            )
         | 
| 362 | 
            -
             | 
| 363 | 
            -
                            load_btn = gr.Button("Load dataset", variant="primary")
         | 
| 364 | 
            -
                        with gr.Column(scale=3):
         | 
| 365 | 
             
                            pass
         | 
| 366 |  | 
| 367 | 
             
                    gr.HTML(value="<hr>")
         | 
| 368 | 
            -
                    gr.Markdown(value="## 2. Configure your  | 
| 369 | 
            -
                    with gr.Row():
         | 
| 370 | 
             
                        with gr.Column(scale=1):
         | 
| 371 | 
             
                            system_prompt = gr.Textbox(
         | 
| 372 | 
             
                                label="System prompt",
         | 
| @@ -381,14 +399,21 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app: | |
| 381 | 
             
                                interactive=True,
         | 
| 382 | 
             
                                info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
         | 
| 383 | 
             
                            )
         | 
| 384 | 
            -
                            btn_apply_to_sample_dataset = gr.Button( | 
|  | |
|  | |
| 385 | 
             
                        with gr.Column(scale=3):
         | 
| 386 | 
            -
                            dataframe = gr.Dataframe( | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 387 |  | 
| 388 | 
             
                    gr.HTML(value="<hr>")
         | 
| 389 | 
             
                    gr.Markdown(value="## 3. Generate your dataset")
         | 
| 390 | 
            -
                    with gr.Row():
         | 
| 391 | 
            -
                        with gr.Column(scale= | 
| 392 | 
             
                            org_name = get_org_dropdown()
         | 
| 393 | 
             
                            repo_name = gr.Textbox(
         | 
| 394 | 
             
                                label="Repo name",
         | 
| @@ -396,7 +421,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app: | |
| 396 | 
             
                                value=f"my-distiset-{str(uuid.uuid4())[:8]}",
         | 
| 397 | 
             
                                interactive=True,
         | 
| 398 | 
             
                            )
         | 
| 399 | 
            -
                             | 
| 400 | 
             
                                label="Number of rows",
         | 
| 401 | 
             
                                value=10,
         | 
| 402 | 
             
                                interactive=True,
         | 
| @@ -410,21 +435,38 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app: | |
| 410 | 
             
                            )
         | 
| 411 | 
             
                            btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
         | 
| 412 | 
             
                        with gr.Column(scale=3):
         | 
| 413 | 
            -
                            success_message = gr.Markdown()
         | 
| 414 | 
            -
             | 
| 415 | 
            -
             | 
| 416 | 
            -
             | 
| 417 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 418 |  | 
| 419 | 
            -
                 | 
| 420 | 
            -
                    triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
         | 
| 421 | 
             
                    fn=generate_system_prompt,
         | 
| 422 | 
            -
                    inputs=[dataset_description],
         | 
| 423 | 
            -
                    outputs=[system_prompt | 
| 424 | 
             
                    show_progress=True,
         | 
| 425 | 
             
                ).then(
         | 
| 426 | 
             
                    fn=generate_sample_dataset,
         | 
| 427 | 
            -
                    inputs=[system_prompt],
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 428 | 
             
                    outputs=[dataframe],
         | 
| 429 | 
             
                    show_progress=True,
         | 
| 430 | 
             
                )
         | 
| @@ -444,21 +486,34 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app: | |
| 444 | 
             
                    outputs=[success_message],
         | 
| 445 | 
             
                    show_progress=True,
         | 
| 446 | 
             
                ).success(
         | 
| 447 | 
            -
                    fn= | 
|  | |
|  | |
|  | |
|  | |
| 448 | 
             
                    inputs=[
         | 
| 449 | 
             
                        org_name,
         | 
| 450 | 
             
                        repo_name,
         | 
| 451 | 
             
                        system_prompt,
         | 
| 452 | 
             
                        num_turns,
         | 
| 453 | 
            -
                         | 
| 454 | 
             
                        private,
         | 
| 455 | 
             
                    ],
         | 
| 456 | 
             
                    outputs=[success_message],
         | 
| 457 | 
             
                    show_progress=True,
         | 
| 458 | 
             
                ).success(
         | 
| 459 | 
            -
                    fn= | 
| 460 | 
             
                    inputs=[org_name, repo_name],
         | 
| 461 | 
             
                    outputs=[success_message],
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 462 | 
             
                )
         | 
| 463 | 
            -
             | 
|  | |
| 464 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | |
| 10 |  from huggingface_hub import HfApi
| 11 |
| 12 |  from src.distilabel_dataset_generator.apps.base import (
| 13 |      hide_success_message,
| 14 | +    show_success_message,
| 15 |      validate_argilla_user_workspace_dataset,
| 16 |      validate_push_to_hub,
| 17 |  )
| 24 |  )
| 25 |  from src.distilabel_dataset_generator.pipelines.sft import (
| 26 |      DEFAULT_DATASET_DESCRIPTIONS,
| 27 |      generate_pipeline_code,
| 28 |      get_magpie_generator,
| 29 |      get_prompt_generator,
| 33 |      _LOGGED_OUT_CSS,
| 34 |      get_argilla_client,
| 35 |      get_org_dropdown,
| 36 | +    swap_visibility,
| 37 |  )
| 38 |
| 39 |
| 52 |      return dataframe
| 53 |
| 54 |
| 55 | + def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
| 56 |      progress(0.0, desc="Generating system prompt")
| 57 |      progress(0.3, desc="Initializing text generation")
| 58 | +    generate_description = get_prompt_generator(temperature)
| 59 |      progress(0.7, desc="Generating system prompt")
| 60 |      result = next(
| 61 |          generate_description.process(
| 62 |              [
| 63 |                  {
| 64 |                      "instruction": dataset_description,
| 65 |                  }
| 66 |              ]
| 67 |          )
| 68 |      )[0]["generation"]
| 69 |      progress(1.0, desc="System prompt generated")
| 70 | +    return result
| 71 |
| 72 |
| 73 | + def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()):
| 74 | +    dataframe = generate_dataset(
| 75 |          system_prompt=system_prompt,
| 76 | +        num_turns=num_turns,
| 77 |          num_rows=10,
| 78 |          progress=progress,
| 79 |          is_sample=True,
| 80 |      )
| 81 | +    return dataframe
| 82 |
| 83 |
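The updated `generate_system_prompt` threads the new `temperature` value into `get_prompt_generator`, which is defined in `pipelines/sft.py` and not shown in this excerpt. A rough sketch of what such a factory might look like with distilabel, assuming a `TextGeneration` task over `InferenceEndpointsLLM`; the model id and generation kwargs below are placeholders, not taken from this diff:

```python
# Hypothetical sketch only; the real factory lives in pipelines/sft.py and may differ.
from distilabel.llms import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration

MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # placeholder, not from the diff


def get_prompt_generator(temperature: float) -> TextGeneration:
    task = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id=MODEL_ID,
            tokenizer_id=MODEL_ID,
            # The UI slider value ends up here, controlling sampling randomness.
            generation_kwargs={"temperature": temperature, "max_new_tokens": 2048},
        ),
    )
    task.load()  # distilabel steps must be loaded before .process() is called
    return task
```

`process()` yields batches, which is why the handler above consumes it with `next(...)[0]["generation"]`.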
| 84 |   def generate_dataset(
| 89 |       progress=gr.Progress(),
| 90 |   ) -> pd.DataFrame:
| 91 |       progress(0.0, desc="(1/2) Generating instructions")
| 92 | +     magpie_generator = get_magpie_generator(system_prompt, num_turns, is_sample)
| 93 | +     response_generator = get_response_generator(system_prompt, num_turns, is_sample)
| 94 |       total_steps: int = num_rows * 2
| 95 |       batch_size = DEFAULT_BATCH_SIZE
| 96 |
| 202 |      return original_dataframe
| 203 |
| 204 |
| 205 | +  def push_dataset(
| 206 |      org_name: str,
| 207 |      repo_name: str,
| 208 |      system_prompt: str,
| 209 |      num_turns: int = 1,
| 210 | +    num_rows: int = 10,
| 211 |      private: bool = False,
| 212 |      oauth_token: Union[gr.OAuthToken, None] = None,
| 213 |      progress=gr.Progress(),
| 215 |      dataframe = generate_dataset(
| 216 |          system_prompt=system_prompt,
| 217 |          num_turns=num_turns,
| 218 | +        num_rows=num_rows,
| 219 |      )
| 220 |      push_dataset_to_hub(dataframe, org_name, repo_name, oauth_token, private)
| 221 |      try:
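With this change the SFT push entry point is simply `push_dataset`, and the requested row count is passed straight through as `num_rows`. A minimal usage sketch outside the UI; organization, repo name and token handling are illustrative only, and a real Hugging Face token is required for the actual push:

```python
# Illustrative call; inside the Space the gr.OAuthToken comes from the login flow.
from src.distilabel_dataset_generator.apps.sft import push_dataset

push_dataset(
    org_name="my-org",                # placeholder organization
    repo_name="my-distiset-example",  # placeholder repository name
    system_prompt="You are an assistant that answers questions about gardening.",
    num_turns=1,       # 1 -> single turn with 'instruction-response' columns
    num_rows=10,       # forwarded directly into generate_dataset(...)
    private=False,
    oauth_token=None,  # replace with the injected gr.OAuthToken when running in the Space
)
```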
| 337 |     return ""
| 338 |
| 339 |
| 340 | + def show_pipeline_code_visibility():
| 341 | +     return {pipeline_code_ui: gr.Accordion(visible=True)}
| 342 | +
| 343 | +
| 344 | + def hide_pipeline_code_visibility():
| 345 | +     return {pipeline_code_ui: gr.Accordion(visible=False)}
| 346 | +
| 347 | +
| 348 | + ######################
| 349 | + # Gradio UI
| 350 | + ######################
| 351 | +
| 352 | +
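These two helpers return a dict that maps the `pipeline_code_ui` accordion (created further down inside the Blocks) to a `gr.Accordion(visible=...)` update, which is how the app shows or hides the Distilabel code around a push. A self-contained sketch of the same pattern with placeholder component names:

```python
# Minimal, standalone illustration of toggling an accordion's visibility from handlers.
import gradio as gr

with gr.Blocks() as demo:
    with gr.Accordion("Pipeline code", open=False, visible=False) as code_ui:
        gr.Code(value="print('hello distilabel')", language="python")
    show_btn = gr.Button("Show code")
    hide_btn = gr.Button("Hide code")

    # Returning an updated gr.Accordion for the mapped output applies the change.
    show_btn.click(fn=lambda: {code_ui: gr.Accordion(visible=True)}, outputs=[code_ui])
    hide_btn.click(fn=lambda: {code_ui: gr.Accordion(visible=False)}, outputs=[code_ui])

if __name__ == "__main__":
    demo.launch()
```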
| 353 |  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 354 |      with gr.Column() as main_ui:
| 355 |          gr.Markdown(value="## 1. Describe the dataset you want")
| 356 |          with gr.Row():
| 357 | +            with gr.Column(scale=2):
| 358 |                  dataset_description = gr.Textbox(
| 359 |                      label="Dataset description",
| 360 |                      placeholder="Give a precise description of your desired dataset.",
| 361 |                  )
| 362 | +                with gr.Accordion("Temperature", open=False):
| 363 | +                    temperature = gr.Slider(
| 364 | +                        minimum=0.1,
| 365 | +                        maximum=1,
| 366 | +                        value=0.8,
| 367 | +                        step=0.1,
| 368 | +                        interactive=True,
| 369 | +                        show_label=False,
| 370 | +                    )
| 371 | +                load_btn = gr.Button(
| 372 | +                    "Create dataset",
| 373 | +                    variant="primary",
| 374 | +                )
| 375 | +            with gr.Column(scale=2):
| 376 |                  examples = gr.Examples(
| 377 |                      examples=DEFAULT_DATASET_DESCRIPTIONS,
| 378 |                      inputs=[dataset_description],
| 379 |                      cache_examples=False,
| 380 | +                    label="Examples",
| 381 |                  )
| 382 | +            with gr.Column(scale=1):
| 383 |                  pass
| 384 |
| 385 |          gr.HTML(value="<hr>")
| 386 | +        gr.Markdown(value="## 2. Configure your dataset")
| 387 | +        with gr.Row(equal_height=False):
| 388 |              with gr.Column(scale=1):
| 389 |                  system_prompt = gr.Textbox(
| 390 |                      label="System prompt",
| 399 |                      interactive=True,
| 400 |                      info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
| 401 |                  )
| 402 | +                btn_apply_to_sample_dataset = gr.Button(
| 403 | +                    "Refresh dataset", variant="secondary", size="sm"
| 404 | +                )
| 405 |              with gr.Column(scale=3):
| 406 | +                dataframe = gr.Dataframe(
| 407 | +                    headers=["prompt", "completion"],
| 408 | +                    wrap=True,
| 409 | +                    height=500,
| 410 | +                    interactive=False,
| 411 | +                )
| 412 |
| 413 |          gr.HTML(value="<hr>")
| 414 |          gr.Markdown(value="## 3. Generate your dataset")
| 415 | +        with gr.Row(equal_height=False):
| 416 | +            with gr.Column(scale=2):
| 417 |                  org_name = get_org_dropdown()
| 418 |                  repo_name = gr.Textbox(
| 419 |                      label="Repo name",
| 421 |                      value=f"my-distiset-{str(uuid.uuid4())[:8]}",
| 422 |                      interactive=True,
| 423 |                  )
| 424 | +                num_rows = gr.Number(
| 425 |                      label="Number of rows",
| 426 |                      value=10,
| 427 |                      interactive=True,
| 435 |                  )
| 436 |                  btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
| 437 |              with gr.Column(scale=3):
| 438 | +                success_message = gr.Markdown(visible=True)
| 439 | +                with gr.Accordion(
| 440 | +                    "Do you want to go further? Customize and run with Distilabel",
| 441 | +                    open=False,
| 442 | +                    visible=False,
| 443 | +                ) as pipeline_code_ui:
| 444 | +                    code = generate_pipeline_code(
| 445 | +                        system_prompt=system_prompt.value,
| 446 | +                        num_turns=num_turns.value,
| 447 | +                        num_rows=num_rows.value,
| 448 | +                    )
| 449 | +                    pipeline_code = gr.Code(
| 450 | +                        value=code,
| 451 | +                        language="python",
| 452 | +                        label="Distilabel Pipeline Code",
| 453 | +                    )
| 454 |
| 455 | +    load_btn.click(
| 456 |          fn=generate_system_prompt,
| 457 | +        inputs=[dataset_description, temperature],
| 458 | +        outputs=[system_prompt],
| 459 |          show_progress=True,
| 460 |      ).then(
| 461 |          fn=generate_sample_dataset,
| 462 | +        inputs=[system_prompt, num_turns],
| 463 | +        outputs=[dataframe],
| 464 | +        show_progress=True,
| 465 | +    )
| 466 | +
| 467 | +    btn_apply_to_sample_dataset.click(
| 468 | +        fn=generate_sample_dataset,
| 469 | +        inputs=[system_prompt, num_turns],
| 470 |          outputs=[dataframe],
| 471 |          show_progress=True,
| 472 |      )
| 486 |          outputs=[success_message],
| 487 |          show_progress=True,
| 488 |      ).success(
| 489 | +        fn=hide_pipeline_code_visibility,
| 490 | +        inputs=[],
| 491 | +        outputs=[pipeline_code_ui],
| 492 | +    ).success(
| 493 | +        fn=push_dataset,
| 494 |          inputs=[
| 495 |              org_name,
| 496 |              repo_name,
| 497 |              system_prompt,
| 498 |              num_turns,
| 499 | +            num_rows,
| 500 |              private,
| 501 |          ],
| 502 |          outputs=[success_message],
| 503 |          show_progress=True,
| 504 |      ).success(
| 505 | +        fn=show_success_message,
| 506 |          inputs=[org_name, repo_name],
| 507 |          outputs=[success_message],
| 508 | +    ).success(
| 509 | +        fn=generate_pipeline_code,
| 510 | +        inputs=[system_prompt, num_turns, num_rows],
| 511 | +        outputs=[pipeline_code],
| 512 | +    ).success(
| 513 | +        fn=show_pipeline_code_visibility,
| 514 | +        inputs=[],
| 515 | +        outputs=[pipeline_code_ui],
| 516 |      )
| 517 | +
| 518 | +    app.load(fn=swap_visibility, outputs=main_ui)
| 519 |      app.load(fn=get_org_dropdown, outputs=[org_name])
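The wiring above leans on Gradio's event chaining: `.then()` always queues the next handler, while `.success()` only continues if the previous handler returned without raising, so if an earlier validation step fails, the push and the pipeline-code refresh are skipped. A short sketch of that distinction with illustrative handler names:

```python
# Illustrative only: .success() stops the chain when the previous step raises gr.Error.
import gradio as gr


def validate(description: str) -> str:
    if not description.strip():
        raise gr.Error("Please provide a dataset description.")  # halts the .success() chain
    return description


with gr.Blocks() as demo:
    description = gr.Textbox(label="Dataset description")
    status = gr.Markdown()
    run_btn = gr.Button("Run")

    run_btn.click(
        fn=validate, inputs=[description], outputs=[status]
    ).success(
        fn=lambda d: f"Validated: {d}", inputs=[description], outputs=[status]
    )

if __name__ == "__main__":
    demo.launch()
```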
src/distilabel_dataset_generator/apps/textcat.py
@@ -1,4 +1,4 @@
| 1 | - import re
| 2 |   import uuid
| 3 |   from typing import List, Union
| 4 |
@@ -10,10 +10,8 @@ from distilabel.distiset import Distiset
| 10 |  from huggingface_hub import HfApi
| 11 |
| 12 |  from src.distilabel_dataset_generator.apps.base import (
| 13 | -    get_argilla_client,
| 14 | -    get_pipeline_code_ui,
| 15 |      hide_success_message,
| 16 | -
| 17 |      validate_argilla_user_workspace_dataset,
| 18 |      validate_push_to_hub,
| 19 |  )
@@ -26,7 +24,6 @@ from src.distilabel_dataset_generator.pipelines.embeddings import (
| 26 |  )
| 27 |  from src.distilabel_dataset_generator.pipelines.textcat import (
| 28 |      DEFAULT_DATASET_DESCRIPTIONS,
| 29 | -    PROMPT_CREATION_PROMPT,
| 30 |      generate_pipeline_code,
| 31 |      get_labeller_generator,
| 32 |      get_prompt_generator,
@@ -37,45 +34,42 @@ from src.distilabel_dataset_generator.utils import (
| 37 |      get_argilla_client,
| 38 |      get_org_dropdown,
| 39 |      get_preprocess_labels,
| 40 | -
| 41 |  )
| 42 |
| 43 |
| 44 | - def generate_system_prompt(dataset_description, progress=gr.Progress()):
| 45 |      progress(0.0, desc="Generating text classification task")
| 46 |      progress(0.3, desc="Initializing text generation")
| 47 | -    generate_description = get_prompt_generator()
| 48 |      progress(0.7, desc="Generating text classification task")
| 49 | -
| 50 |          generate_description.process(
| 51 |              [
| 52 |                  {
| 53 | -                    "system_prompt": PROMPT_CREATION_PROMPT,
| 54 |                      "instruction": dataset_description,
| 55 |                  }
| 56 |              ]
| 57 |          )
| 58 |      )[0]["generation"]
| 59 |      progress(1.0, desc="Text classification task generated")
| 60 | -
| 61 | -
| 62 |
| 63 | - def generate_sample_dataset(system_prompt, progress=gr.Progress()):
| 64 | -
| 65 |          system_prompt=system_prompt,
| 66 | -        difficulty=
| 67 | -        clarity=
| 68 | -        labels=
| 69 | -        num_labels=
| 70 |          num_rows=10,
| 71 |          progress=progress,
| 72 |          is_sample=True,
| 73 |      )
| 74 | -    if "label" in df.columns:
| 75 | -        df = df[["label", "text"]]
| 76 | -    elif "labels" in df.columns:
| 77 | -        df = df[["labels", "text"]]
| 78 | -    return df
| 79 |
| 80 |
| 81 |  def generate_dataset(
@@ -88,17 +82,13 @@ def generate_dataset(
| 88 |      is_sample: bool = False,
| 89 |      progress=gr.Progress(),
| 90 |  ) -> pd.DataFrame:
| 91 | -    if is_sample:
| 92 | -        multiplier = 1
| 93 | -    else:
| 94 | -        multiplier = 2
| 95 |      progress(0.0, desc="(1/2) Generating text classification data")
| 96 |      labels = get_preprocess_labels(labels)
| 97 |      textcat_generator = get_textcat_generator(
| 98 |          difficulty=difficulty, clarity=clarity, is_sample=is_sample
| 99 |      )
| 100 |     labeller_generator = get_labeller_generator(
| 101 | -        system_prompt=system_prompt,
| 102 |         labels=labels,
| 103 |         num_labels=num_labels,
| 104 |     )
@@ -110,13 +100,15 @@ def generate_dataset(
| 110 |     textcat_results = []
| 111 |     while n_processed < num_rows:
| 112 |         progress(
| 113 | -
| 114 |             total=total_steps,
| 115 |             desc="(1/2) Generating text classification data",
| 116 |         )
| 117 |         remaining_rows = num_rows - n_processed
| 118 |         batch_size = min(batch_size, remaining_rows)
| 119 | -        inputs = [
| 120 |         batch = list(textcat_generator.process(inputs=inputs))
| 121 |         textcat_results.extend(batch[0])
| 122 |         n_processed += batch_size
@@ -124,58 +116,41 @@ def generate_dataset(
| 124 |         result["text"] = result["input_text"]
| 125 |
| 126 |     # label text classification data
| 127 | -    progress(
| 128 | -
| 129 | -
| 130 | -
| 131 | -        while n_processed < num_rows:
| 132 | -            progress(
| 133 | -                0.5 + 0.5 * n_processed / num_rows,
| 134 | -                total=total_steps,
| 135 | -                desc="(1/2) Labeling text classification data",
| 136 | -            )
| 137 | -            batch = textcat_results[n_processed : n_processed + batch_size]
| 138 | -            labels_batch = list(labeller_generator.process(inputs=batch))
| 139 | -            labeller_results.extend(labels_batch[0])
| 140 | -            n_processed += batch_size
| 141 |         progress(
| 142 | -
| 143 |             total=total_steps,
| 144 | -            desc="(
| 145 |         )
| 146 |
| 147 |     # create final dataset
| 148 |     distiset_results = []
| 149 | -
| 150 | -    for result in source_results:
| 151 |         record = {
| 152 |             key: result[key]
| 153 | -            for key in ["
| 154 |             if key in result
| 155 |         }
| 156 |         distiset_results.append(record)
| 157 |
| 158 |     dataframe = pd.DataFrame(distiset_results)
| 159 | -    if
| 160 | -
| 161 | -
| 162 | -
| 163 | -
| 164 | -            )
| 165 | -        else:
| 166 | -            dataframe["labels"] = dataframe["labels"].apply(
| 167 | -                lambda x: (
| 168 | -                    list(
| 169 | -                        set(
| 170 | -                            label.lower().strip()
| 171 | -                            for label in x
| 172 | -                            if label.lower().strip() in labels
| 173 | -                        )
| 174 | -                    )
| 175 | -                    if isinstance(x, list)
| 176 | -                    else None
| 177 | -                )
| 178 | -            )
| 179 |     progress(1.0, desc="Dataset generation completed")
| 180 |     return dataframe
| 181 |
@@ -213,14 +188,14 @@ def push_dataset_to_hub(
| 213 |     )
| 214 |
| 215 |
| 216 | - def push_dataset_to_argilla(
| 217 |     org_name: str,
| 218 |     repo_name: str,
| 219 |     system_prompt: str,
| 220 |     difficulty: str,
| 221 |     clarity: str,
| 222 |     num_labels: int = 1,
| 223 | -
| 224 |     labels: List[str] = None,
| 225 |     private: bool = False,
| 226 |     oauth_token: Union[gr.OAuthToken, None] = None,
@@ -232,7 +207,7 @@ def push_dataset_to_argilla(
| 232 |         clarity=clarity,
| 233 |         num_labels=num_labels,
| 234 |         labels=labels,
| 235 | -        num_rows=
| 236 |     )
| 237 |     push_dataset_to_hub(
| 238 |         dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
@@ -283,7 +258,7 @@ def push_dataset_to_argilla(
| 283 |         )
| 284 |
| 285 |         dataframe["text_length"] = dataframe["text"].apply(len)
| 286 | -        dataframe["text_embeddings"] = get_embeddings(dataframe["text"])
| 287 |
| 288 |         progress(0.5, desc="Creating dataset")
| 289 |         rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
@@ -332,15 +307,6 @@ def push_dataset_to_argilla(
| 332 |     return ""
| 333 |
| 334 |
| 335 | - def update_suggested_labels(system_prompt):
| 336 | -    new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
| 337 | -    if not new_labels:
| 338 | -        return gr.Warning(
| 339 | -            "No labels found in the system prompt. Please add labels manually."
| 340 | -        )
| 341 | -    return gr.update(choices=new_labels, value=new_labels)
| 342 | -
| 343 | -
| 344 |  def validate_input_labels(labels):
| 345 |      if not labels or len(labels) < 2:
| 346 |          raise gr.Error(
@@ -353,44 +319,74 @@ def update_max_num_labels(labels):
| 353 |     return gr.update(maximum=len(labels) if labels else 1)
| 354 |
| 355 |
| 356 |  with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 357 |      with gr.Column() as main_ui:
| 358 |          gr.Markdown("## 1. Describe the dataset you want")
| 359 |          with gr.Row():
| 360 | -            with gr.Column(scale=
| 361 |                  dataset_description = gr.Textbox(
| 362 |                      label="Dataset description",
| 363 |                      placeholder="Give a precise description of your desired dataset.",
| 364 |                  )
| 365 |                  examples = gr.Examples(
| 366 |                      examples=DEFAULT_DATASET_DESCRIPTIONS,
| 367 |                      inputs=[dataset_description],
| 368 |                      cache_examples=False,
| 369 | -                    label="
| 370 |                  )
| 371 | -
| 372 | -            with gr.Column(scale=3):
| 373 |                  pass
| 374 |
| 375 |          gr.HTML("<hr>")
| 376 | -        gr.Markdown("## 2. Configure your
| 377 | -        with gr.Row():
| 378 |              with gr.Column(scale=1):
| 379 |                  system_prompt = gr.Textbox(
| 380 |                      label="System prompt",
| 381 |                      placeholder="You are a helpful assistant.",
| 382 |                      visible=True,
| 383 |                  )
| 384 | -
| 385 | -                    choices=[
| 386 | -
| 387 | -
| 388 | -
| 389 | -
| 390 | -
| 391 | -
| 392 | -
| 393 | -
| 394 |                      interactive=True,
| 395 |                  )
| 396 |                  clarity = gr.Dropdown(
@@ -408,30 +404,30 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 408 |                      info="Set how easily the correct label or labels can be identified.",
| 409 |                      interactive=True,
| 410 |                  )
| 411 | -
| 412 | -                    choices=[
| 413 | -
| 414 |                      interactive=True,
| 415 | -                    label="Labels",
| 416 | -                    multiselect=True,
| 417 | -                    info="Add the labels to classify the text.",
| 418 |                  )
| 419 | -
| 420 | -
| 421 | -                    value=1,
| 422 | -                    minimum=1,
| 423 | -                    maximum=10,
| 424 | -                    info="Select 1 for single-label and >1 for multi-label.",
| 425 | -                    interactive=True,
| 426 |                  )
| 427 | -                btn_apply_to_sample_dataset = gr.Button("Refresh dataset")
| 428 |              with gr.Column(scale=3):
| 429 | -                dataframe = gr.Dataframe(
| 430 |
| 431 |          gr.HTML("<hr>")
| 432 |          gr.Markdown("## 3. Generate your dataset")
| 433 | -        with gr.Row():
| 434 | -            with gr.Column(scale=
| 435 |                  org_name = get_org_dropdown()
| 436 |                  repo_name = gr.Textbox(
| 437 |                      label="Repo name",
@@ -439,7 +435,7 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 439 |                      value=f"my-distiset-{str(uuid.uuid4())[:8]}",
| 440 |                      interactive=True,
| 441 |                  )
| 442 | -
| 443 |                      label="Number of rows",
| 444 |                      value=10,
| 445 |                      interactive=True,
@@ -454,39 +450,54 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 454 |                  btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
| 455 |              with gr.Column(scale=3):
| 456 |                  success_message = gr.Markdown(visible=True)
| 457 |
| 458 | -
| 459 | -        generate_pipeline_code(
| 460 | -            system_prompt.value,
| 461 | -            difficulty=difficulty.value,
| 462 | -            clarity=clarity.value,
| 463 | -            labels=labels.value,
| 464 | -            num_labels=num_labels.value,
| 465 | -            num_rows=n_rows.value,
| 466 | -        )
| 467 | -    )
| 468 | -
| 469 | -    gr.on(
| 470 | -        triggers=[load_btn.click, btn_apply_to_sample_dataset.click],
| 471 |          fn=generate_system_prompt,
| 472 | -        inputs=[dataset_description],
| 473 | -        outputs=[system_prompt,
| 474 |          show_progress=True,
| 475 |      ).then(
| 476 |          fn=generate_sample_dataset,
| 477 | -        inputs=[system_prompt],
| 478 |          outputs=[dataframe],
| 479 |          show_progress=True,
| 480 | -    ).then(
| 481 | -        fn=update_suggested_labels,
| 482 | -        inputs=[system_prompt],
| 483 | -        outputs=labels,
| 484 |      ).then(
| 485 |          fn=update_max_num_labels,
| 486 |          inputs=[labels],
| 487 |          outputs=[num_labels],
| 488 |      )
| 489 |
| 490 |      btn_push_to_hub.click(
| 491 |          fn=validate_argilla_user_workspace_dataset,
| 492 |          inputs=[repo_name],
@@ -502,7 +513,11 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 502 |          outputs=[success_message],
| 503 |          show_progress=True,
| 504 |      ).success(
| 505 | -        fn=
| 506 |          inputs=[
| 507 |              org_name,
| 508 |              repo_name,
@@ -510,16 +525,32 @@ with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
| 510 |              difficulty,
| 511 |              clarity,
| 512 |              num_labels,
| 513 | -
| 514 |              labels,
| 515 |              private,
| 516 |          ],
| 517 |          outputs=[success_message],
| 518 |          show_progress=True,
| 519 |      ).success(
| 520 | -        fn=
| 521 |          inputs=[org_name, repo_name],
| 522 |          outputs=[success_message],
| 523 |      )
| 524 | -
| 525 |      app.load(fn=get_org_dropdown, outputs=[org_name])
| 1 | + import json
| 2 |   import uuid
| 3 |   from typing import List, Union
| 4 |
| 10 |  from huggingface_hub import HfApi
| 11 |
| 12 |  from src.distilabel_dataset_generator.apps.base import (
| 13 |      hide_success_message,
| 14 | +    show_success_message,
| 15 |      validate_argilla_user_workspace_dataset,
| 16 |      validate_push_to_hub,
| 17 |  )
| 24 |  )
| 25 |  from src.distilabel_dataset_generator.pipelines.textcat import (
| 26 |      DEFAULT_DATASET_DESCRIPTIONS,
| 27 |      generate_pipeline_code,
| 28 |      get_labeller_generator,
| 29 |      get_prompt_generator,
| 34 |      get_argilla_client,
| 35 |      get_org_dropdown,
| 36 |      get_preprocess_labels,
| 37 | +    swap_visibility,
| 38 |  )
| 39 |
| 40 |
| 41 | + def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
| 42 |      progress(0.0, desc="Generating text classification task")
| 43 |      progress(0.3, desc="Initializing text generation")
| 44 | +    generate_description = get_prompt_generator(temperature)
| 45 |      progress(0.7, desc="Generating text classification task")
| 46 | +    result = next(
| 47 |          generate_description.process(
| 48 |              [
| 49 |                  {
| 50 |                      "instruction": dataset_description,
| 51 |                  }
| 52 |              ]
| 53 |          )
| 54 |      )[0]["generation"]
| 55 |      progress(1.0, desc="Text classification task generated")
| 56 | +    data = json.loads(result)
| 57 | +    system_prompt = data["classification_task"]
| 58 | +    labels = data["labels"]
| 59 | +    return system_prompt, labels
| 60 |
            +
            def generate_sample_dataset(system_prompt, difficulty, clarity, labels, num_labels, progress=gr.Progress()):
         | 
| 62 | 
            +
                dataframe = generate_dataset(
         | 
| 63 | 
             
                    system_prompt=system_prompt,
         | 
| 64 | 
            +
                    difficulty=difficulty,
         | 
| 65 | 
            +
                    clarity=clarity,
         | 
| 66 | 
            +
                    labels=labels,
         | 
| 67 | 
            +
                    num_labels=num_labels,
         | 
| 68 | 
             
                    num_rows=10,
         | 
| 69 | 
             
                    progress=progress,
         | 
| 70 | 
             
                    is_sample=True,
         | 
| 71 | 
             
                )
         | 
| 72 | 
            +
                return dataframe
         | 
|  | |
|  | |
|  | |
|  | |
| 73 |  | 
| 74 |  | 
| 75 | 
             
            def generate_dataset(
         | 
|  | |
| 82 | 
             
                is_sample: bool = False,
         | 
| 83 | 
             
                progress=gr.Progress(),
         | 
| 84 | 
             
            ) -> pd.DataFrame:
         | 
|  | |
|  | |
|  | |
|  | |
| 85 | 
             
                progress(0.0, desc="(1/2) Generating text classification data")
         | 
| 86 | 
             
                labels = get_preprocess_labels(labels)
         | 
| 87 | 
             
                textcat_generator = get_textcat_generator(
         | 
| 88 | 
             
                    difficulty=difficulty, clarity=clarity, is_sample=is_sample
         | 
| 89 | 
             
                )
         | 
| 90 | 
             
                labeller_generator = get_labeller_generator(
         | 
| 91 | 
            +
                    system_prompt=f"{system_prompt} {', '.join(labels)}",
         | 
| 92 | 
             
                    labels=labels,
         | 
| 93 | 
             
                    num_labels=num_labels,
         | 
| 94 | 
             
                )
         | 
|  | |
| 100 | 
             
                textcat_results = []
         | 
| 101 | 
             
                while n_processed < num_rows:
         | 
| 102 | 
             
                    progress(
         | 
| 103 | 
            +
            0.5 * n_processed / num_rows,
         | 
| 104 | 
             
                        total=total_steps,
         | 
| 105 | 
             
                        desc="(1/2) Generating text classification data",
         | 
| 106 | 
             
                    )
         | 
| 107 | 
             
                    remaining_rows = num_rows - n_processed
         | 
| 108 | 
             
                    batch_size = min(batch_size, remaining_rows)
         | 
| 109 | 
            +
                    inputs = [
         | 
| 110 | 
            +
                        {"task": f"{system_prompt} {', '.join(labels)}"} for _ in range(batch_size)
         | 
| 111 | 
            +
                    ]
         | 
| 112 | 
             
                    batch = list(textcat_generator.process(inputs=inputs))
         | 
| 113 | 
             
                    textcat_results.extend(batch[0])
         | 
| 114 | 
             
                    n_processed += batch_size
         | 
|  | |
| 116 | 
             
                    result["text"] = result["input_text"]
         | 
| 117 |  | 
| 118 | 
             
                # label text classification data
         | 
| 119 | 
            +
    progress(0.5, desc="(1/2) Generating text classification data")
         | 
| 120 | 
            +
                n_processed = 0
         | 
| 121 | 
            +
                labeller_results = []
         | 
| 122 | 
            +
                while n_processed < num_rows:
         | 
| 123 | 
             
                    progress(
         | 
| 124 | 
            +
                        0.5 + 0.5 * n_processed / num_rows,
         | 
| 125 | 
             
                        total=total_steps,
         | 
| 126 | 
            +
                        desc="(1/2) Labeling text classification data",
         | 
| 127 | 
             
                    )
         | 
| 128 | 
            +
                    batch = textcat_results[n_processed : n_processed + batch_size]
         | 
| 129 | 
            +
                    labels_batch = list(labeller_generator.process(inputs=batch))
         | 
| 130 | 
            +
                    labeller_results.extend(labels_batch[0])
         | 
| 131 | 
            +
                    n_processed += batch_size
         | 
| 132 | 
            +
                progress(
         | 
| 133 | 
            +
                    1,
         | 
| 134 | 
            +
                    total=total_steps,
         | 
| 135 | 
            +
                    desc="(2/2) Creating dataset",
         | 
| 136 | 
            +
                )
         | 
| 137 |  | 
| 138 | 
             
                # create final dataset
         | 
| 139 | 
             
                distiset_results = []
         | 
| 140 | 
            +
                for result in labeller_results:
         | 
|  | |
| 141 | 
             
                    record = {
         | 
| 142 | 
             
                        key: result[key]
         | 
| 143 | 
            +
                        for key in ["labels", "text"]
         | 
| 144 | 
             
                        if key in result
         | 
| 145 | 
             
                    }
         | 
| 146 | 
             
                    distiset_results.append(record)
         | 
| 147 |  | 
| 148 | 
             
                dataframe = pd.DataFrame(distiset_results)
         | 
| 149 | 
            +
                if num_labels == 1:
         | 
| 150 | 
            +
                    dataframe = dataframe.rename(columns={"labels": "label"})
         | 
| 151 | 
            +
                    dataframe["label"] = dataframe["label"].apply(
         | 
| 152 | 
            +
                        lambda x: x.lower().strip() if x.lower().strip() in labels else None
         | 
| 153 | 
            +
                    )
         | 
| 154 | 
             
                progress(1.0, desc="Dataset generation completed")
         | 
| 155 | 
             
                return dataframe
         | 
| 156 |  | 
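As a quick illustration of the single-label clean-up at the end of `generate_dataset`, here is the same transformation on two made-up rows (label names are examples only):

```python
import pandas as pd

labels = ["positive", "negative"]  # example labels, already lowercased by get_preprocess_labels
df = pd.DataFrame({"labels": ["Positive ", "neutral"], "text": ["Great movie!", "It was fine."]})

df = df.rename(columns={"labels": "label"})
df["label"] = df["label"].apply(
    lambda x: x.lower().strip() if x.lower().strip() in labels else None
)
print(df["label"].tolist())  # ['positive', None] -- labels outside the allowed set become None
```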
|  | |
| 188 | 
             
                )
         | 
| 189 |  | 
| 190 |  | 
| 191 | 
            +
            def push_dataset(
         | 
| 192 | 
             
                org_name: str,
         | 
| 193 | 
             
                repo_name: str,
         | 
| 194 | 
             
                system_prompt: str,
         | 
| 195 | 
             
                difficulty: str,
         | 
| 196 | 
             
                clarity: str,
         | 
| 197 | 
             
                num_labels: int = 1,
         | 
| 198 | 
            +
                num_rows: int = 10,
         | 
| 199 | 
             
                labels: List[str] = None,
         | 
| 200 | 
             
                private: bool = False,
         | 
| 201 | 
             
                oauth_token: Union[gr.OAuthToken, None] = None,
         | 
|  | |
| 207 | 
             
                    clarity=clarity,
         | 
| 208 | 
             
                    num_labels=num_labels,
         | 
| 209 | 
             
                    labels=labels,
         | 
| 210 | 
            +
                    num_rows=num_rows,
         | 
| 211 | 
             
                )
         | 
| 212 | 
             
                push_dataset_to_hub(
         | 
| 213 | 
             
                    dataframe, org_name, repo_name, num_labels, labels, oauth_token, private
         | 
|  | |
| 258 | 
             
                    )
         | 
| 259 |  | 
| 260 | 
             
                    dataframe["text_length"] = dataframe["text"].apply(len)
         | 
| 261 | 
            +
                    dataframe["text_embeddings"] = get_embeddings(dataframe["text"].to_list())
         | 
| 262 |  | 
| 263 | 
             
                    progress(0.5, desc="Creating dataset")
         | 
| 264 | 
             
                    rg_dataset = client.datasets(name=repo_name, workspace=hf_user)
         | 
|  | |
| 307 | 
             
                return ""
         | 
| 308 |  | 
| 309 |  | 
| 310 | 
             
            def validate_input_labels(labels):
         | 
| 311 | 
             
                if not labels or len(labels) < 2:
         | 
| 312 | 
             
                    raise gr.Error(
         | 
|  | |
| 319 | 
             
                return gr.update(maximum=len(labels) if labels else 1)
         | 
| 320 |  | 
| 321 |  | 
| 322 | 
            +
            def show_pipeline_code_visibility():
         | 
| 323 | 
            +
                return {pipeline_code_ui: gr.Accordion(visible=True)}
         | 
| 324 | 
            +
             | 
| 325 | 
            +
             | 
| 326 | 
            +
            def hide_pipeline_code_visibility():
         | 
| 327 | 
            +
                return {pipeline_code_ui: gr.Accordion(visible=False)}
         | 
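The two helpers above use Gradio's dict-style updates to toggle the accordion. A self-contained sketch of the same pattern (component names are made up):

```python
import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Show details")
    with gr.Accordion("Details", open=False, visible=False) as details_ui:
        gr.Markdown("Now visible.")

    def show_details():
        # Returning {component: updated component} toggles visibility,
        # the same pattern as show/hide_pipeline_code_visibility above.
        return {details_ui: gr.Accordion(visible=True)}

    btn.click(fn=show_details, inputs=[], outputs=[details_ui])
```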
| 328 | 
            +
             | 
| 329 | 
            +
             | 
| 330 | 
            +
            ######################
         | 
| 331 | 
            +
            # Gradio UI
         | 
| 332 | 
            +
            ######################
         | 
| 333 | 
            +
             | 
| 334 | 
            +
             | 
| 335 | 
             
            with gr.Blocks(css=_LOGGED_OUT_CSS) as app:
         | 
| 336 | 
             
                with gr.Column() as main_ui:
         | 
| 337 | 
             
                    gr.Markdown("## 1. Describe the dataset you want")
         | 
| 338 | 
             
                    with gr.Row():
         | 
| 339 | 
            +
                        with gr.Column(scale=2):
         | 
| 340 | 
             
                            dataset_description = gr.Textbox(
         | 
| 341 | 
             
                                label="Dataset description",
         | 
| 342 | 
             
                                placeholder="Give a precise description of your desired dataset.",
         | 
| 343 | 
             
                            )
         | 
| 344 | 
            +
                            with gr.Accordion("Temperature", open=False):
         | 
| 345 | 
            +
                                temperature = gr.Slider(
         | 
| 346 | 
            +
                                    minimum=0.1,
         | 
| 347 | 
            +
                                    maximum=1,
         | 
| 348 | 
            +
                                    value=0.8,
         | 
| 349 | 
            +
                                    step=0.1,
         | 
| 350 | 
            +
                                    interactive=True,
         | 
| 351 | 
            +
                                    show_label=False,
         | 
| 352 | 
            +
                                )
         | 
| 353 | 
            +
                            load_btn = gr.Button(
         | 
| 354 | 
            +
                                "Create dataset",
         | 
| 355 | 
            +
                                variant="primary",
         | 
| 356 | 
            +
                            )
         | 
| 357 | 
            +
                        with gr.Column(scale=2):
         | 
| 358 | 
             
                            examples = gr.Examples(
         | 
| 359 | 
             
                                examples=DEFAULT_DATASET_DESCRIPTIONS,
         | 
| 360 | 
             
                                inputs=[dataset_description],
         | 
| 361 | 
             
                                cache_examples=False,
         | 
| 362 | 
            +
                                label="Examples",
         | 
| 363 | 
             
                            )
         | 
| 364 | 
            +
                        with gr.Column(scale=1):
         | 
|  | |
| 365 | 
             
                            pass
         | 
| 366 |  | 
| 367 | 
             
                    gr.HTML("<hr>")
         | 
| 368 | 
            +
                    gr.Markdown("## 2. Configure your dataset")
         | 
| 369 | 
            +
                    with gr.Row(equal_height=False):
         | 
| 370 | 
             
                        with gr.Column(scale=1):
         | 
| 371 | 
             
                            system_prompt = gr.Textbox(
         | 
| 372 | 
             
                                label="System prompt",
         | 
| 373 | 
             
                                placeholder="You are a helpful assistant.",
         | 
| 374 | 
             
                                visible=True,
         | 
| 375 | 
             
                            )
         | 
| 376 | 
            +
                            labels = gr.Dropdown(
         | 
| 377 | 
            +
                                choices=[],
         | 
| 378 | 
            +
                                allow_custom_value=True,
         | 
| 379 | 
            +
                                interactive=True,
         | 
| 380 | 
            +
                                label="Labels",
         | 
| 381 | 
            +
                                multiselect=True,
         | 
| 382 | 
            +
                                info="Add the labels to classify the text.",
         | 
| 383 | 
            +
                            )
         | 
| 384 | 
            +
                            num_labels = gr.Number(
         | 
| 385 | 
            +
                                label="Number of labels per text",
         | 
| 386 | 
            +
                                value=1,
         | 
| 387 | 
            +
                                minimum=1,
         | 
| 388 | 
            +
                                maximum=10,
         | 
| 389 | 
            +
                                info="Select 1 for single-label and >1 for multi-label.",
         | 
| 390 | 
             
                                interactive=True,
         | 
| 391 | 
             
                            )
         | 
| 392 | 
             
                            clarity = gr.Dropdown(
         | 
|  | |
| 404 | 
             
                                info="Set how easily the correct label or labels can be identified.",
         | 
| 405 | 
             
                                interactive=True,
         | 
| 406 | 
             
                            )
         | 
| 407 | 
            +
                            difficulty = gr.Dropdown(
         | 
| 408 | 
            +
                                choices=[
         | 
| 409 | 
            +
                                    ("High School", "high school"),
         | 
| 410 | 
            +
                                    ("College", "college"),
         | 
| 411 | 
            +
                                    ("PhD", "PhD"),
         | 
| 412 | 
            +
                                    ("Mixed", "mixed"),
         | 
| 413 | 
            +
                                ],
         | 
| 414 | 
            +
                                value="mixed",
         | 
| 415 | 
            +
                                label="Difficulty",
         | 
| 416 | 
            +
                                info="Select the comprehension level for the text. Ensure it matches the task context.",
         | 
| 417 | 
             
                                interactive=True,
         | 
|  | |
|  | |
|  | |
| 418 | 
             
                            )
         | 
| 419 | 
            +
                            btn_apply_to_sample_dataset = gr.Button(
         | 
| 420 | 
            +
                                "Refresh dataset", variant="secondary", size="sm"
         | 
| 421 | 
             
                            )
         | 
|  | |
| 422 | 
             
                        with gr.Column(scale=3):
         | 
| 423 | 
            +
                            dataframe = gr.Dataframe(
         | 
| 424 | 
            +
                                headers=["labels", "text"], wrap=True, height=500, interactive=False
         | 
| 425 | 
            +
                            )
         | 
| 426 |  | 
| 427 | 
             
                    gr.HTML("<hr>")
         | 
| 428 | 
             
                    gr.Markdown("## 3. Generate your dataset")
         | 
| 429 | 
            +
                    with gr.Row(equal_height=False):
         | 
| 430 | 
            +
                        with gr.Column(scale=2):
         | 
| 431 | 
             
                            org_name = get_org_dropdown()
         | 
| 432 | 
             
                            repo_name = gr.Textbox(
         | 
| 433 | 
             
                                label="Repo name",
         | 
|  | |
| 435 | 
             
                                value=f"my-distiset-{str(uuid.uuid4())[:8]}",
         | 
| 436 | 
             
                                interactive=True,
         | 
| 437 | 
             
                            )
         | 
| 438 | 
            +
                            num_rows = gr.Number(
         | 
| 439 | 
             
                                label="Number of rows",
         | 
| 440 | 
             
                                value=10,
         | 
| 441 | 
             
                                interactive=True,
         | 
|  | |
| 450 | 
             
                            btn_push_to_hub = gr.Button("Push to Hub", variant="primary", scale=2)
         | 
| 451 | 
             
                        with gr.Column(scale=3):
         | 
| 452 | 
             
                            success_message = gr.Markdown(visible=True)
         | 
| 453 | 
            +
                            with gr.Accordion(
         | 
| 454 | 
            +
                                "Do you want to go further? Customize and run with Distilabel",
         | 
| 455 | 
            +
                                open=False,
         | 
| 456 | 
            +
                                visible=False,
         | 
| 457 | 
            +
                            ) as pipeline_code_ui:
         | 
| 458 | 
            +
                                code = generate_pipeline_code(
         | 
| 459 | 
            +
                                    system_prompt.value,
         | 
| 460 | 
            +
                                    difficulty=difficulty.value,
         | 
| 461 | 
            +
                                    clarity=clarity.value,
         | 
| 462 | 
            +
                                    labels=labels.value,
         | 
| 463 | 
            +
                                    num_labels=num_labels.value,
         | 
| 464 | 
            +
                                    num_rows=num_rows.value,
         | 
| 465 | 
            +
                                )
         | 
| 466 | 
            +
                                pipeline_code = gr.Code(
         | 
| 467 | 
            +
                                    value=code,
         | 
| 468 | 
            +
                                    language="python",
         | 
| 469 | 
            +
                                    label="Distilabel Pipeline Code",
         | 
| 470 | 
            +
                                )
         | 
| 471 |  | 
| 472 | 
            +
                load_btn.click(
         | 
| 473 | 
             
                    fn=generate_system_prompt,
         | 
| 474 | 
            +
                    inputs=[dataset_description, temperature],
         | 
| 475 | 
            +
                    outputs=[system_prompt, labels],
         | 
| 476 | 
             
                    show_progress=True,
         | 
| 477 | 
             
                ).then(
         | 
| 478 | 
             
                    fn=generate_sample_dataset,
         | 
| 479 | 
            +
                    inputs=[system_prompt, difficulty, clarity, labels, num_labels],
         | 
| 480 | 
             
                    outputs=[dataframe],
         | 
| 481 | 
             
                    show_progress=True,
         | 
| 482 | 
             
                ).then(
         | 
| 483 | 
             
                    fn=update_max_num_labels,
         | 
| 484 | 
             
                    inputs=[labels],
         | 
| 485 | 
             
                    outputs=[num_labels],
         | 
| 486 | 
             
                )
         | 
| 487 |  | 
| 488 | 
            +
                labels.input(
         | 
| 489 | 
            +
                    fn=update_max_num_labels,
         | 
| 490 | 
            +
                    inputs=[labels],
         | 
| 491 | 
            +
                    outputs=[num_labels],
         | 
| 492 | 
            +
                )
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                btn_apply_to_sample_dataset.click(
         | 
| 495 | 
            +
                    fn=generate_sample_dataset,
         | 
| 496 | 
            +
                    inputs=[system_prompt, difficulty, clarity, labels, num_labels],
         | 
| 497 | 
            +
                    outputs=[dataframe],
         | 
| 498 | 
            +
                    show_progress=True,
         | 
| 499 | 
            +
                )
         | 
| 500 | 
            +
             | 
| 501 | 
             
                btn_push_to_hub.click(
         | 
| 502 | 
             
                    fn=validate_argilla_user_workspace_dataset,
         | 
| 503 | 
             
                    inputs=[repo_name],
         | 
|  | |
| 513 | 
             
                    outputs=[success_message],
         | 
| 514 | 
             
                    show_progress=True,
         | 
| 515 | 
             
                ).success(
         | 
| 516 | 
            +
                    fn=hide_pipeline_code_visibility,
         | 
| 517 | 
            +
                    inputs=[],
         | 
| 518 | 
            +
                    outputs=[pipeline_code_ui],
         | 
| 519 | 
            +
                ).success(
         | 
| 520 | 
            +
                    fn=push_dataset,
         | 
| 521 | 
             
                    inputs=[
         | 
| 522 | 
             
                        org_name,
         | 
| 523 | 
             
                        repo_name,
         | 
|  | |
| 525 | 
             
                        difficulty,
         | 
| 526 | 
             
                        clarity,
         | 
| 527 | 
             
                        num_labels,
         | 
| 528 | 
            +
                        num_rows,
         | 
| 529 | 
             
                        labels,
         | 
| 530 | 
             
                        private,
         | 
| 531 | 
             
                    ],
         | 
| 532 | 
             
                    outputs=[success_message],
         | 
| 533 | 
             
                    show_progress=True,
         | 
| 534 | 
             
                ).success(
         | 
| 535 | 
            +
                    fn=show_success_message,
         | 
| 536 | 
             
                    inputs=[org_name, repo_name],
         | 
| 537 | 
             
                    outputs=[success_message],
         | 
| 538 | 
            +
                ).success(
         | 
| 539 | 
            +
                    fn=generate_pipeline_code,
         | 
| 540 | 
            +
                    inputs=[
         | 
| 541 | 
            +
                        system_prompt,
         | 
| 542 | 
            +
                        difficulty,
         | 
| 543 | 
            +
                        clarity,
         | 
| 544 | 
            +
                        labels,
         | 
| 545 | 
            +
                        num_labels,
         | 
| 546 | 
            +
                        num_rows,
         | 
| 547 | 
            +
                    ],
         | 
| 548 | 
            +
                    outputs=[pipeline_code],
         | 
| 549 | 
            +
                ).success(
         | 
| 550 | 
            +
                    fn=show_pipeline_code_visibility,
         | 
| 551 | 
            +
                    inputs=[],
         | 
| 552 | 
            +
                    outputs=[pipeline_code_ui],
         | 
| 553 | 
             
                )
         | 
| 554 | 
            +
             | 
| 555 | 
            +
                app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 556 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
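The wiring above relies on Gradio's event chaining: `.then()` always runs the next callback, while `.success()` only continues when the previous step did not raise, which is why validation and UI clean-up steps sit at the front of the push-to-Hub chain. A minimal standalone sketch of that difference (handlers are hypothetical):

```python
import gradio as gr

def validate(name):
    if not name:
        raise gr.Error("Please provide a repo name.")
    return name

def push(name):
    return f"Pushed {name}"

with gr.Blocks() as demo:
    repo = gr.Textbox(label="Repo name")
    status = gr.Markdown()
    btn = gr.Button("Push")
    # push() only runs if validate() did not raise; a .then() would run it regardless.
    btn.click(fn=validate, inputs=[repo], outputs=[repo]).success(
        fn=push, inputs=[repo], outputs=[status]
    )
```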
| @@ -0,0 +1,205 @@ | |
| 1 | 
            +
            from typing import List
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from datasets import get_dataset_config_names, get_dataset_split_names
         | 
| 4 | 
            +
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 5 | 
            +
            from distilabel.steps.tasks import (
         | 
| 6 | 
            +
                UltraFeedback,
         | 
| 7 | 
            +
                TextGeneration,
         | 
| 8 | 
            +
            )
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from src.distilabel_dataset_generator.pipelines.base import (
         | 
| 11 | 
            +
                MODEL,
         | 
| 12 | 
            +
                _get_next_api_key,
         | 
| 13 | 
            +
            )
         | 
| 14 | 
            +
            from src.distilabel_dataset_generator.utils import extract_column_names
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def get_ultrafeedback_evaluator(aspect, is_sample):
         | 
| 18 | 
            +
                ultrafeedback_evaluator = UltraFeedback(
         | 
| 19 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 20 | 
            +
                        model_id=MODEL,
         | 
| 21 | 
            +
                        tokenizer_id=MODEL,
         | 
| 22 | 
            +
                        api_key=_get_next_api_key(),
         | 
| 23 | 
            +
                        generation_kwargs={
         | 
| 24 | 
            +
                            "temperature": 0.7,
         | 
| 25 | 
            +
                            "max_new_tokens": 256 if is_sample else 2048,
         | 
| 26 | 
            +
                        },
         | 
| 27 | 
            +
                    ),
         | 
| 28 | 
            +
                    aspect=aspect,
         | 
| 29 | 
            +
                )
         | 
| 30 | 
            +
                ultrafeedback_evaluator.load()
         | 
| 31 | 
            +
                return ultrafeedback_evaluator
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def get_custom_evaluator(prompt_template, structured_output, columns, is_sample):
         | 
| 35 | 
            +
                custom_evaluator = TextGeneration(
         | 
| 36 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 37 | 
            +
                        model_id=MODEL,
         | 
| 38 | 
            +
                        tokenizer_id=MODEL,
         | 
| 39 | 
            +
                        api_key=_get_next_api_key(),
         | 
| 40 | 
            +
                        structured_output={"format": "json", "schema": structured_output},
         | 
| 41 | 
            +
                        generation_kwargs={
         | 
| 42 | 
            +
                            "temperature": 0.7,
         | 
| 43 | 
            +
                            "max_new_tokens": 256 if is_sample else 2048,
         | 
| 44 | 
            +
                        },
         | 
| 45 | 
            +
                    ),
         | 
| 46 | 
            +
                    template=prompt_template,
         | 
| 47 | 
            +
                    columns=columns
         | 
| 48 | 
            +
                )
         | 
| 49 | 
            +
                custom_evaluator.load()
         | 
| 50 | 
            +
                return custom_evaluator
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
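A quick, hypothetical smoke test for the two factories above: `UltraFeedback` expects one `instruction` plus a list of candidate `generations` per row, and the sample flag only shortens `max_new_tokens`.

```python
# Assumes this module's imports and API keys are available; the rows below are made up.
evaluator = get_ultrafeedback_evaluator(aspect="overall-rating", is_sample=True)
rows = [
    {
        "instruction": "What is the capital of France?",
        "generations": ["Paris.", "The capital is Lyon."],
    }
]
scored = next(evaluator.process(rows))
print(scored[0].keys())  # includes the ratings/rationales columns added by UltraFeedback
```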
| 53 | 
            +
            def generate_ultrafeedback_pipeline_code(
         | 
| 54 | 
            +
                repo_id, subset, split, aspects, instruction_column, response_columns, num_rows
         | 
| 55 | 
            +
            ):
         | 
| 56 | 
            +
                if len(aspects) == 1:
         | 
| 57 | 
            +
                    code = f"""
         | 
| 58 | 
            +
            # Requirements: `pip install distilabel[hf-inference-endpoints]`
         | 
| 59 | 
            +
            import os
         | 
| 60 | 
            +
            from datasets import load_dataset
         | 
| 61 | 
            +
            from distilabel.pipeline import Pipeline
         | 
| 62 | 
            +
            from distilabel.steps import LoadDataFromDicts
         | 
| 63 | 
            +
            from distilabel.steps.tasks import UltraFeedback
         | 
| 64 | 
            +
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            MODEL = "{MODEL}"
         | 
| 67 | 
            +
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}[:{num_rows}]")
         | 
| 70 | 
            +
            data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            with Pipeline(name="ultrafeedback") as pipeline:
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                load_the_dataset = LoadDataFromDicts(
         | 
| 75 | 
            +
                    data = data,
         | 
| 76 | 
            +
                )
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                ultrafeedback_evaluator = UltraFeedback(
         | 
| 79 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 80 | 
            +
                        model_id=MODEL,
         | 
| 81 | 
            +
                        tokenizer_id=MODEL,
         | 
| 82 | 
            +
                        api_key=os.environ["HF_TOKEN"],
         | 
| 83 | 
            +
                        generation_kwargs={{
         | 
| 84 | 
            +
                            "temperature": 0.7,
         | 
| 85 | 
            +
                            "max_new_tokens": 2048,
         | 
| 86 | 
            +
                        }},
         | 
| 87 | 
            +
                    ),
         | 
| 88 | 
            +
        aspect="{aspects[0]}",
         | 
| 89 | 
            +
                )
         | 
| 90 | 
            +
                
         | 
| 91 | 
            +
                load_the_dataset >> ultrafeedback_evaluator
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            if __name__ == "__main__":
         | 
| 94 | 
            +
                distiset = pipeline.run()
         | 
| 95 | 
            +
            """
         | 
| 96 | 
            +
                else:
         | 
| 97 | 
            +
                    code = f"""
         | 
| 98 | 
            +
            # Requirements: `pip install distilabel[hf-inference-endpoints]`
         | 
| 99 | 
            +
import os
from datasets import load_dataset
         | 
| 100 | 
            +
            from distilabel.pipeline import Pipeline
         | 
| 101 | 
            +
            from distilabel.steps import LoadDataFromDicts, CombineOutputs
         | 
| 102 | 
            +
            from distilabel.steps.tasks import UltraFeedback
         | 
| 103 | 
            +
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            MODEL = "{MODEL}"
         | 
| 106 | 
            +
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            hf_ds = load_dataset("{repo_id}", "{subset}", split="{split}")
         | 
| 109 | 
            +
            data = preprocess_data(hf_ds, "{instruction_column}", "{response_columns}") # to get a list of dictionaries
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            with Pipeline(name="ultrafeedback") as pipeline:
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                load_the_dataset = LoadDataFromDicts(
         | 
| 114 | 
            +
                    data = data,
         | 
| 115 | 
            +
                )
         | 
| 116 | 
            +
                
         | 
| 117 | 
            +
                tasks = []
         | 
| 118 | 
            +
                for aspect in aspects:
         | 
| 119 | 
            +
                    evaluate_responses = UltraFeedback(
         | 
| 120 | 
            +
                        name=f"evaluate-responses-{{aspect}}",
         | 
| 121 | 
            +
                        aspect=aspect,
         | 
| 122 | 
            +
                        llm=InferenceEndpointsLLM(
         | 
| 123 | 
            +
                            model_id=MODEL,
         | 
| 124 | 
            +
                            tokenizer_id=MODEL,
         | 
| 125 | 
            +
                            api_key=os.environ["HF_TOKEN"],
         | 
| 126 | 
            +
                            generation_kwargs={{
         | 
| 127 | 
            +
                                "temperature": 0.7,
         | 
| 128 | 
            +
                                "max_new_tokens": 2048,
         | 
| 129 | 
            +
                            }},
         | 
| 130 | 
            +
                        output_mappings={{
         | 
| 131 | 
            +
                            "ratings": f"ratings_{{aspect}}",
         | 
| 132 | 
            +
                            "types": f"type_{{aspect}}",
         | 
| 133 | 
            +
                            "rationales": f"rationales_for_types_{{aspect}}",
         | 
| 134 | 
            +
                            "rationales-for-ratings": f"rationales_for_ratings_{{aspect}}",
         | 
| 135 | 
            +
                        }} if aspect in ["truthfulness", "helpfulness"] else {{"rationales": f"rationales_{{aspect}}", "ratings": f"ratings_{{aspect}}"}},
         | 
| 136 | 
            +
                    )
         | 
| 137 | 
            +
                    tasks.append(evaluate_responses)
         | 
| 138 | 
            +
                
         | 
| 139 | 
            +
                combine_outputs = CombineOutputs()
         | 
| 140 | 
            +
                
         | 
| 141 | 
            +
                load_the_dataset >> tasks >> combine_outputs
         | 
| 142 | 
            +
             | 
| 143 | 
            +
            if __name__ == "__main__":
         | 
| 144 | 
            +
                distiset = pipeline.run()
         | 
| 145 | 
            +
            """
         | 
| 146 | 
            +
                return code
         | 
| 147 | 
            +
             | 
| 148 | 
            +
             | 
| 149 | 
            +
            def generate_custom_pipeline_code(
         | 
| 150 | 
            +
                repo_id, subset, split, prompt_template, structured_output, num_rows
         | 
| 151 | 
            +
            ):
         | 
| 152 | 
            +
                columns = extract_column_names(structured_output)
         | 
| 153 | 
            +
                code = f"""
         | 
| 154 | 
            +
            # Requirements: `pip install distilabel[hf-inference-endpoints, instructor]`
         | 
| 155 | 
            +
            import os
         | 
| 156 | 
            +
            from distilabel.pipeline import Pipeline
         | 
| 157 | 
            +
            from distilabel.steps import LoadDataFromHub
         | 
| 158 | 
            +
            from distilabel.steps.tasks import TextGeneration
         | 
| 159 | 
            +
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 160 | 
            +
             | 
| 161 | 
            +
            MODEL = "{MODEL}"
         | 
| 162 | 
            +
            CUSTOM_TEMPLATE = "{prompt_template}"
         | 
| 163 | 
            +
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 164 | 
            +
             | 
| 165 | 
            +
            with Pipeline(name="custom-evaluation") as pipeline:
         | 
| 166 | 
            +
                load_the_dataset = LoadDataFromHub(
         | 
| 167 | 
            +
                    repo_id="{repo_id}",
         | 
| 168 | 
            +
                    config="{subset}",
         | 
| 169 | 
            +
                    split="{split}",
         | 
| 170 | 
            +
                    num_examples={num_rows},
         | 
| 171 | 
            +
                    batch_size=2
         | 
| 172 | 
            +
                )
         | 
| 173 | 
            +
                custom_evaluator = TextGeneration(
         | 
| 174 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 175 | 
            +
                        model_id=MODEL,
         | 
| 176 | 
            +
                        tokenizer_id=MODEL,
         | 
| 177 | 
            +
                        api_key=os.environ["HF_TOKEN"],
         | 
| 178 | 
            +
                        structured_output={{"format": "json", "schema": {structured_output}}},
         | 
| 179 | 
            +
                        generation_kwargs={{
         | 
| 180 | 
            +
                            "temperature": 0.7,
         | 
| 181 | 
            +
                            "max_new_tokens": 2048,
         | 
| 182 | 
            +
                        }},
         | 
| 183 | 
            +
                    ),
         | 
| 184 | 
            +
                    template=CUSTOM_TEMPLATE,
         | 
| 185 | 
            +
                    columns={columns}
         | 
| 186 | 
            +
                )
         | 
| 187 | 
            +
                
         | 
| 188 | 
            +
                load_the_dataset >> custom_evaluator
         | 
| 189 | 
            +
             | 
| 190 | 
            +
            if __name__ == "__main__":
         | 
| 191 | 
            +
                distiset = pipeline.run()
         | 
| 192 | 
            +
            """
         | 
| 193 | 
            +
                return code
         | 
| 194 | 
            +
             | 
| 195 | 
            +
             | 
| 196 | 
            +
            def generate_pipeline_code(repo_id, aspects, instruction_column, response_columns, prompt_template, structured_output, num_rows, eval_type):
         | 
| 197 | 
            +
                if repo_id is None:
         | 
| 198 | 
            +
                    subset = "default"
         | 
| 199 | 
            +
                    split = "train"
         | 
| 200 | 
            +
                else:
         | 
| 201 | 
            +
                    subset = get_dataset_config_names(repo_id)[0]
         | 
| 202 | 
            +
                    split = get_dataset_split_names(repo_id, subset)[0]
         | 
| 203 | 
            +
                if eval_type == "ultrafeedback":
         | 
| 204 | 
            +
                    return generate_ultrafeedback_pipeline_code(repo_id, subset, split, aspects, instruction_column, response_columns, num_rows)
         | 
| 205 | 
            +
                return generate_custom_pipeline_code(repo_id, subset, split, prompt_template, structured_output, num_rows)
         | 
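For context, the dispatcher above can be called directly; this hypothetical invocation fetches the first config/split of the repo (Hub access is required) and returns a ready-to-edit script for the chosen evaluation type:

```python
code = generate_pipeline_code(
    repo_id="my-org/my-eval-dataset",      # hypothetical repo id
    aspects=["overall-rating"],
    instruction_column="instruction",
    response_columns="response_1,response_2",
    prompt_template="",
    structured_output={},
    num_rows=10,
    eval_type="ultrafeedback",
)
print(code)  # the generated UltraFeedback pipeline shown above
```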
| @@ -138,52 +138,26 @@ def _get_output_mappings(num_turns): | |
| 138 | 
             
                    return {"conversation": "messages"}
         | 
| 139 |  | 
| 140 |  | 
| 141 | 
            -
            def  | 
| 142 | 
            -
                 | 
| 143 | 
            -
                code = f"""
         | 
| 144 | 
            -
            # Requirements: `pip install distilabel[hf-inference-endpoints]`
         | 
| 145 | 
            -
            import os
         | 
| 146 | 
            -
            from distilabel.pipeline import Pipeline
         | 
| 147 | 
            -
            from distilabel.steps import KeepColumns
         | 
| 148 | 
            -
            from distilabel.steps.tasks import MagpieGenerator
         | 
| 149 | 
            -
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 150 | 
            -
             | 
| 151 | 
            -
            MODEL = "{MODEL}"
         | 
| 152 | 
            -
            SYSTEM_PROMPT = "{system_prompt}"
         | 
| 153 | 
            -
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 154 | 
            -
             | 
| 155 | 
            -
            with Pipeline(name="sft") as pipeline:
         | 
| 156 | 
            -
                magpie = MagpieGenerator(
         | 
| 157 | 
             
                    llm=InferenceEndpointsLLM(
         | 
|  | |
| 158 | 
             
                        model_id=MODEL,
         | 
| 159 | 
             
                        tokenizer_id=MODEL,
         | 
| 160 | 
            -
                         | 
| 161 | 
            -
             | 
| 162 | 
            -
                            "temperature": 0.9,
         | 
| 163 | 
            -
                            "do_sample": True,
         | 
| 164 | 
             
                            "max_new_tokens": 2048,
         | 
| 165 | 
            -
                            " | 
| 166 | 
            -
                        } | 
| 167 | 
            -
                        api_key=os.environ["HF_TOKEN"],
         | 
| 168 | 
             
                    ),
         | 
| 169 | 
            -
                     | 
| 170 | 
            -
                     | 
| 171 | 
            -
                    batch_size=1,
         | 
| 172 | 
            -
                    system_prompt=SYSTEM_PROMPT,
         | 
| 173 | 
            -
                    output_mappings={input_mappings},
         | 
| 174 | 
            -
                )
         | 
| 175 | 
            -
                keep_columns = KeepColumns(
         | 
| 176 | 
            -
                    columns={list(input_mappings.values())} + ["model_name"],
         | 
| 177 | 
             
                )
         | 
| 178 | 
            -
                 | 
| 179 | 
            -
             | 
| 180 | 
            -
            if __name__ == "__main__":
         | 
| 181 | 
            -
                distiset = pipeline.run()
         | 
| 182 | 
            -
            """
         | 
| 183 | 
            -
                return code
         | 
| 184 |  | 
| 185 |  | 
| 186 | 
            -
            def get_magpie_generator( | 
| 187 | 
             
                input_mappings = _get_output_mappings(num_turns)
         | 
| 188 | 
             
                output_mappings = input_mappings.copy()
         | 
| 189 | 
             
                if num_turns == 1:
         | 
| @@ -228,7 +202,7 @@ def get_magpie_generator(num_turns, num_rows, system_prompt, is_sample): | |
| 228 | 
             
                return magpie_generator
         | 
| 229 |  | 
| 230 |  | 
| 231 | 
            -
            def get_response_generator( | 
| 232 | 
             
                if num_turns == 1:
         | 
| 233 | 
             
                    response_generator = TextGeneration(
         | 
| 234 | 
             
                        llm=InferenceEndpointsLLM(
         | 
| @@ -262,19 +236,46 @@ def get_response_generator(num_turns, system_prompt, is_sample): | |
| 262 | 
             
                return response_generator
         | 
| 263 |  | 
| 264 |  | 
| 265 | 
            -
            def  | 
| 266 | 
            -
                 | 
| 267 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 268 | 
            -
                        api_key=_get_next_api_key(),
         | 
| 269 | 
             
                        model_id=MODEL,
         | 
| 270 | 
             
                        tokenizer_id=MODEL,
         | 
| 271 | 
            -
                         | 
| 272 | 
            -
             | 
| 273 | 
            -
                            " | 
| 274 | 
             
                            "do_sample": True,
         | 
| 275 | 
            -
             | 
|  | |
|  | |
|  | |
| 276 | 
             
                    ),
         | 
| 277 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
| 278 | 
             
                )
         | 
| 279 | 
            -
                 | 
| 280 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 138 | 
             
                    return {"conversation": "messages"}
         | 
| 139 |  | 
| 140 |  | 
| 141 | 
            +
            def get_prompt_generator(temperature):
         | 
| 142 | 
            +
                prompt_generator = TextGeneration(
         | 
| 143 | 
             
                    llm=InferenceEndpointsLLM(
         | 
| 144 | 
            +
                        api_key=_get_next_api_key(),
         | 
| 145 | 
             
                        model_id=MODEL,
         | 
| 146 | 
             
                        tokenizer_id=MODEL,
         | 
| 147 | 
            +
                        generation_kwargs={
         | 
| 148 | 
            +
                            "temperature": temperature,
         | 
|  | |
|  | |
| 149 | 
             
                            "max_new_tokens": 2048,
         | 
| 150 | 
            +
                            "do_sample": True,
         | 
| 151 | 
            +
                        },
         | 
|  | |
| 152 | 
             
                    ),
         | 
| 153 | 
            +
                    system_prompt=PROMPT_CREATION_PROMPT,
         | 
| 154 | 
            +
                    use_system_prompt=True,
         | 
| 155 | 
             
                )
         | 
| 156 | 
            +
                prompt_generator.load()
         | 
| 157 | 
            +
                return prompt_generator
         | 
| 158 |  | 
| 159 |  | 
| 160 | 
            +
            def get_magpie_generator(system_prompt, num_turns, is_sample):
         | 
| 161 | 
             
                input_mappings = _get_output_mappings(num_turns)
         | 
| 162 | 
             
                output_mappings = input_mappings.copy()
         | 
| 163 | 
             
                if num_turns == 1:
         | 
|  | |
| 202 | 
             
                return magpie_generator
         | 
| 203 |  | 
| 204 |  | 
| 205 | 
            +
            def get_response_generator(system_prompt, num_turns, is_sample):
         | 
| 206 | 
             
                if num_turns == 1:
         | 
| 207 | 
             
                    response_generator = TextGeneration(
         | 
| 208 | 
             
                        llm=InferenceEndpointsLLM(
         | 
|  | |
| 236 | 
             
                return response_generator
         | 
| 237 |  | 
| 238 |  | 
| 239 | 
            +
            def generate_pipeline_code(system_prompt, num_turns, num_rows):
         | 
| 240 | 
            +
                input_mappings = _get_output_mappings(num_turns)
         | 
| 241 | 
            +
                code = f"""
         | 
| 242 | 
            +
            # Requirements: `pip install distilabel[hf-inference-endpoints]`
         | 
| 243 | 
            +
            import os
         | 
| 244 | 
            +
            from distilabel.pipeline import Pipeline
         | 
| 245 | 
            +
            from distilabel.steps import KeepColumns
         | 
| 246 | 
            +
            from distilabel.steps.tasks import MagpieGenerator
         | 
| 247 | 
            +
            from distilabel.llms import InferenceEndpointsLLM
         | 
| 248 | 
            +
             | 
| 249 | 
            +
            MODEL = "{MODEL}"
         | 
| 250 | 
            +
            SYSTEM_PROMPT = "{system_prompt}"
         | 
| 251 | 
            +
            os.environ["HF_TOKEN"] = "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
         | 
| 252 | 
            +
             | 
| 253 | 
            +
            with Pipeline(name="sft") as pipeline:
         | 
| 254 | 
            +
                magpie = MagpieGenerator(
         | 
| 255 | 
             
                    llm=InferenceEndpointsLLM(
         | 
|  | |
| 256 | 
             
                        model_id=MODEL,
         | 
| 257 | 
             
                        tokenizer_id=MODEL,
         | 
| 258 | 
            +
                        magpie_pre_query_template="llama3",
         | 
| 259 | 
            +
                        generation_kwargs={{
         | 
| 260 | 
            +
                            "temperature": 0.9,
         | 
| 261 | 
             
                            "do_sample": True,
         | 
| 262 | 
            +
                            "max_new_tokens": 2048,
         | 
| 263 | 
            +
                            "stop_sequences": {_STOP_SEQUENCES}
         | 
| 264 | 
            +
                        }},
         | 
| 265 | 
            +
                        api_key=os.environ["HF_TOKEN"],
         | 
| 266 | 
             
                    ),
         | 
| 267 | 
            +
                    n_turns={num_turns},
         | 
| 268 | 
            +
                    num_rows={num_rows},
         | 
| 269 | 
            +
                    batch_size=1,
         | 
| 270 | 
            +
                    system_prompt=SYSTEM_PROMPT,
         | 
| 271 | 
            +
                    output_mappings={input_mappings},
         | 
| 272 | 
             
                )
         | 
| 273 | 
            +
                keep_columns = KeepColumns(
         | 
| 274 | 
            +
                    columns={list(input_mappings.values())} + ["model_name"],
         | 
| 275 | 
            +
                )
         | 
| 276 | 
            +
                magpie.connect(keep_columns)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
            if __name__ == "__main__":
         | 
| 279 | 
            +
                distiset = pipeline.run()
         | 
| 280 | 
            +
            """
         | 
| 281 | 
            +
                return code
         | 
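One way to use the generator above outside the UI (assuming it is imported from the SFT pipelines module) is to write the returned script to disk and run it with a valid `HF_TOKEN`; the system prompt below is just an example:

```python
code = generate_pipeline_code(
    system_prompt="You are a customer support assistant for a phone company.",
    num_turns=1,
    num_rows=100,
)
with open("sft_pipeline.py", "w") as f:
    f.write(code)
# then: HF_TOKEN=hf_xxx python sft_pipeline.py
```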
| @@ -1,4 +1,5 @@ | |
| 1 | 
             
            import random
         | 
|  | |
| 2 | 
             
            from typing import List
         | 
| 3 |  | 
| 4 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
| @@ -22,25 +23,27 @@ The prompt you write should follow the same style and structure as the following | |
| 22 |  | 
| 23 | 
             
            If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
         | 
| 24 |  | 
| 25 | 
            -
            Classify the following customer review of a cinema as  | 
| 26 |  | 
| 27 | 
            -
             | 
| 28 |  | 
| 29 | 
            -
             | 
| 30 |  | 
| 31 | 
            -
             | 
| 32 |  | 
| 33 | 
            -
             | 
| 34 |  | 
| 35 | 
            -
             | 
| 36 |  | 
| 37 | 
            -
            Categorize the following  | 
| 38 |  | 
| 39 | 
            -
            Classify the following  | 
| 40 |  | 
| 41 | 
            -
             | 
| 42 |  | 
| 43 | 
            -
            Classify the following  | 
|  | |
|  | |
| 44 |  | 
| 45 | 
             
            User dataset description:
         | 
| 46 | 
             
            """
         | 
| @@ -51,6 +54,82 @@ DEFAULT_DATASET_DESCRIPTIONS = [ | |
| 51 | 
             
            ]
         | 
| 52 |  | 
| 53 |  | 
| 54 | 
             
            def generate_pipeline_code(
         | 
| 55 | 
             
                system_prompt: str,
         | 
| 56 | 
             
                difficulty: str = None,
         | 
| @@ -146,63 +225,3 @@ with Pipeline(name="textcat") as pipeline: | |
| 146 | 
             
                    distiset = pipeline.run()
         | 
| 147 | 
             
                """
         | 
| 148 | 
             
                )
         | 
| 149 | 
            -
             | 
| 150 | 
            -
             | 
| 151 | 
            -
            def get_textcat_generator(difficulty, clarity, is_sample):
         | 
| 152 | 
            -
                textcat_generator = GenerateTextClassificationData(
         | 
| 153 | 
            -
                    llm=InferenceEndpointsLLM(
         | 
| 154 | 
            -
                        model_id=MODEL,
         | 
| 155 | 
            -
                        tokenizer_id=MODEL,
         | 
| 156 | 
            -
                        api_key=_get_next_api_key(),
         | 
| 157 | 
            -
                        generation_kwargs={
         | 
| 158 | 
            -
                            "temperature": 0.9,
         | 
| 159 | 
            -
                            "max_new_tokens": 256 if is_sample else 2048,
         | 
| 160 | 
            -
                            "do_sample": True,
         | 
| 161 | 
            -
                            "top_k": 50,
         | 
| 162 | 
            -
                            "top_p": 0.95,
         | 
| 163 | 
            -
                        },
         | 
| 164 | 
            -
                    ),
         | 
| 165 | 
            -
                    difficulty=None if difficulty == "mixed" else difficulty,
         | 
| 166 | 
            -
                    clarity=None if clarity == "mixed" else clarity,
         | 
| 167 | 
            -
                    seed=random.randint(0, 2**32 - 1),
         | 
| 168 | 
            -
                )
         | 
| 169 | 
            -
                textcat_generator.load()
         | 
| 170 | 
            -
                return textcat_generator
         | 
| 171 | 
            -
             | 
| 172 | 
            -
             | 
| 173 | 
            -
            def get_labeller_generator(system_prompt, labels, num_labels):
         | 
| 174 | 
            -
                labeller_generator = TextClassification(
         | 
| 175 | 
            -
                    llm=InferenceEndpointsLLM(
         | 
| 176 | 
            -
                        model_id=MODEL,
         | 
| 177 | 
            -
                        tokenizer_id=MODEL,
         | 
| 178 | 
            -
                        api_key=_get_next_api_key(),
         | 
| 179 | 
            -
                        generation_kwargs={
         | 
| 180 | 
            -
                            "temperature": 0.7,
         | 
| 181 | 
            -
                            "max_new_tokens": 2048,
         | 
| 182 | 
            -
                        },
         | 
| 183 | 
            -
                    ),
         | 
| 184 | 
            -
                    context=system_prompt,
         | 
| 185 | 
            -
                    available_labels=labels,
         | 
| 186 | 
            -
                    n=num_labels,
         | 
| 187 | 
            -
                    default_label="unknown",
         | 
| 188 | 
            -
                )
         | 
| 189 | 
            -
                labeller_generator.load()
         | 
| 190 | 
            -
                return labeller_generator
         | 
| 191 | 
            -
             | 
| 192 | 
            -
             | 
| 193 | 
            -
            def get_prompt_generator():
         | 
| 194 | 
            -
                prompt_generator = TextGeneration(
         | 
| 195 | 
            -
                    llm=InferenceEndpointsLLM(
         | 
| 196 | 
            -
                        api_key=_get_next_api_key(),
         | 
| 197 | 
            -
                        model_id=MODEL,
         | 
| 198 | 
            -
                        tokenizer_id=MODEL,
         | 
| 199 | 
            -
                        generation_kwargs={
         | 
| 200 | 
            -
                            "temperature": 0.8,
         | 
| 201 | 
            -
                            "max_new_tokens": 2048,
         | 
| 202 | 
            -
                            "do_sample": True,
         | 
| 203 | 
            -
                        },
         | 
| 204 | 
            -
                    ),
         | 
| 205 | 
            -
                    use_system_prompt=True,
         | 
| 206 | 
            -
                )
         | 
| 207 | 
            -
                prompt_generator.load()
         | 
| 208 | 
            -
                return prompt_generator
         | 
|  | |
| 1 | 
             
            import random
         | 
| 2 | 
            +
            from pydantic import BaseModel, Field
         | 
| 3 | 
             
            from typing import List
         | 
| 4 |  | 
| 5 | 
             
            from distilabel.llms import InferenceEndpointsLLM
         | 
|  | |
| 23 |  | 
| 24 | 
             
            If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
         | 
| 25 |  | 
| 26 | 
            +
            {"classification_task": "Classify the following customer review of a cinema as", "labels": ["positive", "negative"]}
         | 
| 27 |  | 
| 28 | 
            +
            {"classification_task": "Categorize the following news article into one or more of the following categories:", "labels": ["politics", "sports", "technology", "entertainment", "health", "business", "environment", "education", "science", "international"]}
         | 
| 29 |  | 
| 30 | 
            +
            {"classification_task": "Classify the following news article into one or more of the following categories:", "labels": ['politics', 'sports', 'technology', 'entertainment', 'health', 'business', 'environment', 'education', 'science', 'international']}
         | 
| 31 |  | 
| 32 | 
            +
            {"classification_task": "Determine the sentiment of the following social media post:", "labels": ['ambiguous', 'sarcastic', 'informative', 'emotional']}
         | 
| 33 |  | 
| 34 | 
            +
            {"classification_task": "Identify the issue category for the following technical support ticket:", "labels": ['billing', 'technical', 'account', 'shipping', 'returns', 'installation', 'subscription']}
         | 
| 35 |  | 
| 36 | 
            +
            {"classification_task": "Classify the following movie review into one of the following categories:", "labels": ['critical', 'praise', 'disappointed', 'enthusiastic']}
         | 
| 37 |  | 
| 38 | 
            +
            {"classification_task": "Categorize the following customer service transcript into one of the following categories:", "labels": ['satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent']}
         | 
| 39 |  | 
| 40 | 
            +
            {"classification_task": "Classify the following product description into one of the following product types:", "labels": ['smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones']}
         | 
| 41 |  | 
| 42 | 
            +
            {"classification_task": "Categorize the following tweet expressing the political event discussed as", "labels": ['support', 'opposition']}
         | 
| 43 |  | 
| 44 | 
            +
            {"classification_task": "Classify the following restaurant review into one of the following categories:", "labels": ['food-quality', 'service', 'ambiance', 'price']}
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            {"classification_task": "Categorize the following blog post based on its primary fashion trend or style:", "labels": ['casual', 'formal', 'streetwear', 'vintage', 'sustainable-fashion']}
         | 
| 47 |  | 
| 48 | 
             
            User dataset description:
         | 
| 49 | 
             
            """
         | 
|  | |
| 54 | 
             
            ]
         | 
| 55 |  | 
| 56 |  | 
| 57 | 
            +
            class TextClassificationTask(BaseModel):
         | 
| 58 | 
            +
                classification_task: str = Field(
         | 
| 59 | 
            +
                    ...,
         | 
| 60 | 
            +
                    title="classification_task",
         | 
| 61 | 
            +
                    description="The classification task to be performed.",
         | 
| 62 | 
            +
                )
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                labels: list[str] = Field(
         | 
| 65 | 
            +
                    ...,
         | 
| 66 | 
            +
                    title="Labels",
         | 
| 67 | 
            +
                    description="The possible labels for the classification task.",
         | 
| 68 | 
            +
                )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            def get_prompt_generator(temperature):
         | 
| 72 | 
            +
                prompt_generator = TextGeneration(
         | 
| 73 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 74 | 
            +
                        api_key=_get_next_api_key(),
         | 
| 75 | 
            +
                        model_id=MODEL,
         | 
| 76 | 
            +
                        tokenizer_id=MODEL,
         | 
| 77 | 
            +
                        structured_output={"format": "json", "schema": TextClassificationTask},
         | 
| 78 | 
            +
                        generation_kwargs={
         | 
| 79 | 
            +
                            "temperature": temperature,
         | 
| 80 | 
            +
                            "max_new_tokens": 2048,
         | 
| 81 | 
            +
                            "do_sample": True,
         | 
| 82 | 
            +
                        },
         | 
| 83 | 
            +
                    ),
         | 
| 84 | 
            +
                    system_prompt=PROMPT_CREATION_PROMPT,
         | 
| 85 | 
            +
                    use_system_prompt=True,
         | 
| 86 | 
            +
                )
         | 
| 87 | 
            +
                prompt_generator.load()
         | 
| 88 | 
            +
                return prompt_generator
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def get_textcat_generator(difficulty, clarity, is_sample):
         | 
| 92 | 
            +
                textcat_generator = GenerateTextClassificationData(
         | 
| 93 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 94 | 
            +
                        model_id=MODEL,
         | 
| 95 | 
            +
                        tokenizer_id=MODEL,
         | 
| 96 | 
            +
                        api_key=_get_next_api_key(),
         | 
| 97 | 
            +
                        generation_kwargs={
         | 
| 98 | 
            +
                            "temperature": 0.9,
         | 
| 99 | 
            +
                            "max_new_tokens": 256 if is_sample else 2048,
         | 
| 100 | 
            +
                            "do_sample": True,
         | 
| 101 | 
            +
                            "top_k": 50,
         | 
| 102 | 
            +
                            "top_p": 0.95,
         | 
| 103 | 
            +
                        },
         | 
| 104 | 
            +
                    ),
         | 
| 105 | 
            +
                    difficulty=None if difficulty == "mixed" else difficulty,
         | 
| 106 | 
            +
                    clarity=None if clarity == "mixed" else clarity,
         | 
| 107 | 
            +
                    seed=random.randint(0, 2**32 - 1),
         | 
| 108 | 
            +
                )
         | 
| 109 | 
            +
                textcat_generator.load()
         | 
| 110 | 
            +
                return textcat_generator
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            def get_labeller_generator(system_prompt, labels, num_labels):
         | 
| 114 | 
            +
                labeller_generator = TextClassification(
         | 
| 115 | 
            +
                    llm=InferenceEndpointsLLM(
         | 
| 116 | 
            +
                        model_id=MODEL,
         | 
| 117 | 
            +
                        tokenizer_id=MODEL,
         | 
| 118 | 
            +
                        api_key=_get_next_api_key(),
         | 
| 119 | 
            +
                        generation_kwargs={
         | 
| 120 | 
            +
                            "temperature": 0.7,
         | 
| 121 | 
            +
                            "max_new_tokens": 2048,
         | 
| 122 | 
            +
                        },
         | 
| 123 | 
            +
                    ),
         | 
| 124 | 
            +
                    context=system_prompt,
         | 
| 125 | 
            +
                    available_labels=labels,
         | 
| 126 | 
            +
                    n=num_labels,
         | 
| 127 | 
            +
                    default_label="unknown",
         | 
| 128 | 
            +
                )
         | 
| 129 | 
            +
                labeller_generator.load()
         | 
| 130 | 
            +
                return labeller_generator
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
             
            def generate_pipeline_code(
         | 
| 134 | 
             
                system_prompt: str,
         | 
| 135 | 
             
                difficulty: str = None,
         | 
|  | |
| 225 | 
             
                    distiset = pipeline.run()
         | 
| 226 | 
             
                """
         | 
| 227 | 
             
                )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
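For reference, the reworked prompt generator now constrains its output with structured_output={"format": "json", "schema": TextClassificationTask}, so the textcat app can split the generated prompt into a system prompt plus a label list instead of parsing free text. The sketch below is illustrative only and not part of this diff: the dataset description is made up, and the process() call pattern and the pydantic v2 model_validate_json call are assumptions based on how the rest of the Space drives distilabel tasks (running it also needs the Space's HF API key environment).

# Illustrative sketch (not part of the diff): drive the structured prompt generator and
# validate its JSON output against the TextClassificationTask schema.
from src.distilabel_dataset_generator.pipelines.textcat import (
    TextClassificationTask,
    get_prompt_generator,
)

prompt_generator = get_prompt_generator(temperature=0.8)

# Hypothetical dataset description; Task.process() is assumed to take a list of input
# dicts and yield batches whose items carry the generated text under "generation".
dataset_description = "A dataset of hotel reviews labelled by overall sentiment."
batch = next(prompt_generator.process([{"instruction": dataset_description}]))
raw_generation = batch[0]["generation"]

# With structured output enabled, the generation should already be valid JSON for the
# schema, so pydantic v2 can parse it directly (assumed API).
task = TextClassificationTask.model_validate_json(raw_generation)
system_prompt = task.classification_task  # reused as the textcat system prompt
labels = task.labels                      # pre-fills the labels input in the UI
print(system_prompt, labels)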
src/distilabel_dataset_generator/utils.py

@@ -1,8 +1,11 @@
+import json
 import os
 from typing import List, Optional, Union
 
 import argilla as rg
 import gradio as gr
+import numpy as np
+import pandas as pd
 from gradio.oauth import (
     OAUTH_CLIENT_ID,
     OAUTH_CLIENT_SECRET,
@@ -11,6 +14,7 @@ from gradio.oauth import (
     get_space,
 )
 from huggingface_hub import whoami
+from jinja2 import Environment, meta
 
 _LOGGED_OUT_CSS = ".main_ui_logged_out{opacity: 0.3; pointer-events: none}"
 
@@ -50,22 +54,22 @@ def list_orgs(oauth_token: OAuthToken = None):
             return []
         data = whoami(oauth_token.token)
         if data["auth"]["type"] == "oauth":
-
+            organizations = [data["name"]] + [org["name"] for org in data["orgs"]]
         elif data["auth"]["type"] == "access_token":
-
+            organizations = [org["name"] for org in data["orgs"]]
         else:
-
+            organizations = [
                 entry["entity"]["name"]
                 for entry in data["auth"]["accessToken"]["fineGrained"]["scoped"]
                 if "repo.write" in entry["permissions"]
             ]
-
-
+            organizations = [org for org in organizations if org != data["name"]]
+            organizations = [data["name"]] + organizations
     except Exception as e:
         raise gr.Error(
             f"Failed to get organizations: {e}. See if you are logged and connected: https://huggingface.co/settings/connected-applications."
         )
-    return
+    return organizations
 
 
 def get_org_dropdown(oauth_token: OAuthToken = None):
@@ -89,7 +93,7 @@ def get_token(oauth_token: OAuthToken = None):
         return ""
 
 
-def
+def swap_visibility(oauth_token: Optional[OAuthToken] = None):
     if oauth_token:
         return gr.update(elem_classes=["main_ui_logged_in"])
     else:
@@ -132,6 +136,91 @@ def get_argilla_client() -> Union[rg.Argilla, None]:
     except Exception:
         return None
 
-
 def get_preprocess_labels(labels: Optional[List[str]]) -> List[str]:
     return list(set([label.lower().strip() for label in labels])) if labels else []
+
+
+def column_to_list(dataframe: pd.DataFrame, column_name: str) -> List[str]:
+    if column_name in dataframe.columns:
+        return dataframe[column_name].tolist()
+    else:
+        raise ValueError(f"Column '{column_name}' does not exist.")
+
+
+def process_columns(
+    dataframe,
+    instruction_column: str,
+    response_columns: Union[str, List[str]],
+) -> List[dict]:
+    instruction_column = [instruction_column]
+    if isinstance(response_columns, str):
+        response_columns = [response_columns]
+
+    data = []
+    for _, row in dataframe.iterrows():
+        instruction = ""
+        for col in instruction_column:
+            value = row[col]
+            if isinstance(value, (list, np.ndarray)):
+                user_contents = [d["content"] for d in value if d.get("role") == "user"]
+                if user_contents:
+                    instruction = user_contents[-1]
+            elif isinstance(value, str):
+                try:
+                    parsed_message = json.loads(value)
+                    user_contents = [
+                        d["content"] for d in parsed_message if d.get("role") == "user"
+                    ]
+                    if user_contents:
+                        instruction = user_contents[-1]
+                except json.JSONDecodeError:
+                    instruction = value
+            else:
+                instruction = ""
+
+        generations = []
+        for col in response_columns:
+            value = row[col]
+            if isinstance(value, (list, np.ndarray)):
+                if all(isinstance(item, dict) and "role" in item for item in value):
+                    assistant_contents = [
+                        d["content"] for d in value if d.get("role") == "assistant"
+                    ]
+                    if assistant_contents:
+                        generations.append(assistant_contents[-1])
+                else:
+                    generations.extend(value)
+            elif isinstance(value, str):
+                try:
+                    parsed_message = json.loads(value)
+                    assistant_contents = [
+                        d["content"]
+                        for d in parsed_message
+                        if d.get("role") == "assistant"
+                    ]
+                    if assistant_contents:
+                        generations.append(assistant_contents[-1])
+                except json.JSONDecodeError:
+                    generations.append(value)
+            else:
+                pass
+
+        data.append({"instruction": instruction, "generations": generations})
+
+    return data
+
+
+def extract_column_names(prompt_template: str) -> List[str]:
+    env = Environment()
+    parsed_content = env.parse(prompt_template)
+    variables = meta.find_undeclared_variables(parsed_content)
+    return list(variables)
+
+
+def pad_or_truncate_list(lst, target_length):
+    lst = lst or []
+    lst_length = len(lst)
+    if lst_length >= target_length:
+        return lst[-target_length:]
+    else:
+        return lst + [None] * (target_length - lst_length)
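Taken together, the new utils helpers back the evaluation tab: process_columns flattens chat-formatted or JSON-encoded message columns into instruction/generations pairs, extract_column_names pulls the Jinja placeholders out of a custom evaluation template, and pad_or_truncate_list fits that list into a fixed number of UI slots. Below is a small illustrative run, not part of this diff; the column names, the template, and the import path are assumptions (run from the Space's repository root).

# Illustrative sketch (not part of the diff) exercising the new helpers on a toy dataframe.
import pandas as pd

from src.distilabel_dataset_generator.utils import (
    extract_column_names,
    pad_or_truncate_list,
    process_columns,
)

df = pd.DataFrame(
    {
        "messages": [
            [
                {"role": "user", "content": "What is the capital of France?"},
                {"role": "assistant", "content": "Paris."},
            ]
        ],
        "other_response": ["The capital of France is Paris."],
    }
)

# Chat-formatted cells reduce to the last user turn (instruction) and the last
# assistant turn per response column; plain strings are kept as-is.
rows = process_columns(df, instruction_column="messages", response_columns=["messages", "other_response"])
# -> [{'instruction': 'What is the capital of France?', 'generations': ['Paris.', 'The capital of France is Paris.']}]

# A custom evaluation template is scanned for its Jinja placeholders ...
columns = extract_column_names("Rate {{ response }} against {{ reference }} on a 1-5 scale.")
# ... and the result is padded/truncated to a fixed number of dropdown slots.
print(pad_or_truncate_list(columns, 3))  # e.g. ['response', 'reference', None]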
 
			

