Spaces:
Sleeping
Sleeping
import gradio as gr | |
# Choices offered by both model dropdowns. The first entry is the placeholder
# that the dropdowns start on; the validation helpers in this module treat it
# as "no model selected".
AVAILABLE_MODELS = ["(Select Model)", "mistralai/Mistral-7B-v0.1"]
def _build_model_panel(slot, ordinal):
    """Build one model's configuration widgets in the current Gradio context.

    Args:
        slot: Model number used in the panel heading and dropdown label (1 or 2).
        ordinal: Word used in the dropdown help text ("first" or "second").

    Returns:
        Tuple of (model dropdown, few-shot slider, regex textbox,
        FlashAttention checkbox).
    """
    with gr.Group(elem_classes=["config-box"]):
        gr.Markdown(f"### Model {slot}")
        dropdown = gr.Dropdown(
            choices=AVAILABLE_MODELS,
            value="(Select Model)",
            label=f"Select Model {slot}",
            info=f"Choose the {ordinal} model for head-to-head comparison",
        )
        shots = gr.Slider(
            minimum=0,
            maximum=5,
            value=5,
            step=1,
            label="Number of Few-shot Examples",
            info="Number of examples to use for few-shot learning (0-5)",
        )
        regex = gr.Textbox(
            label="Regex Pattern",
            placeholder="Optional: Apply regex pattern to model outputs",
            info="Leave empty for no regex pattern",
        )
        flash_attn = gr.Checkbox(
            label="Use FlashAttention",
            value=True,
            info="Use FlashAttention for better performance (if supported by model)",
        )
    return dropdown, shots, regex, flash_attn


def create_model_config_section():
    """Create the "Head to Head - Choose Models" section.

    Lays out two identical model configuration panels side by side, separated
    by a thin divider, with an initially hidden error message underneath.

    Returns:
        dict: The Gradio components the main app needs, keyed by
        'container', 'model{1,2}_dropdown', 'model{1,2}_shots',
        'model{1,2}_regex', 'model{1,2}_flash_attn' and 'error_message'.
    """
    with gr.Column() as model_config_container:
        gr.Markdown("## (C) Head to Head - Choose Models to evaluate against each other")

        with gr.Row():
            # Left column - Model 1 configuration
            with gr.Column(scale=1):
                (model1_dropdown, model1_shots,
                 model1_regex, model1_flash_attn) = _build_model_panel(1, "first")

            # Divider in the middle. Gradio expects an integer ``scale``
            # (a fractional value triggers a warning), so use scale=0 to keep
            # the column from growing and a tiny min_width so it does not
            # fall back to the default minimum column width.
            with gr.Column(scale=0, min_width=1):
                gr.Markdown(
                    '<div style="border-left: 1px solid #ddd; height: 100%;"></div>',
                    elem_classes=["center-divider"],
                )

            # Right column - Model 2 configuration
            with gr.Column(scale=1):
                (model2_dropdown, model2_shots,
                 model2_regex, model2_flash_attn) = _build_model_panel(2, "second")

        # Error message area - initially hidden
        model_config_error = gr.Markdown(
            visible=False,
            value="⚠️ **Error**: Both models and configurations are identical. Please select different models or configurations for comparison.",
            elem_classes=["error-message"],
        )

    return {
        'container': model_config_container,
        'model1_dropdown': model1_dropdown,
        'model1_shots': model1_shots,
        'model1_regex': model1_regex,
        'model1_flash_attn': model1_flash_attn,
        'model2_dropdown': model2_dropdown,
        'model2_shots': model2_shots,
        'model2_regex': model2_regex,
        'model2_flash_attn': model2_flash_attn,
        'error_message': model_config_error,
    }
def validate_model_configs(model1, model1_shots, model1_regex, model1_flash,
                           model2, model2_shots, model2_regex, model2_flash):
    """Check that the two model configurations differ in at least one setting.

    While either model is still the "(Select Model)" placeholder the
    configurations are reported as valid, since there is nothing to compare.

    Returns:
        tuple: ``(is_valid, error_message)``. ``is_valid`` is False only when
        the model, shot count, regex and FlashAttention flag all match;
        ``error_message`` is an empty string whenever ``is_valid`` is True.
    """
    unselected = "(Select Model)"
    if any(choice == unselected for choice in (model1, model2)):
        return True, ""

    # Compare the settings pairwise, in the order model, shots, regex, flash.
    first = (model1, model1_shots, model1_regex, model1_flash)
    second = (model2, model2_shots, model2_regex, model2_flash)
    is_duplicate = all(left == right for left, right in zip(first, second))

    if not is_duplicate:
        return True, ""
    return False, "⚠️ **Error**: Both configurations are identical. Please select different configurations (e.g., number of few-shot examples) for comparison."
def update_eval_button_state(model1, model1_shots, model1_regex, model1_flash,
                             model2, model2_shots, model2_regex, model2_flash):
    """Compute UI updates for the error message and the eval button.

    Returns:
        tuple: Two ``gr.update`` objects. The first sets the error message's
        visibility (and its text when shown); the second sets whether the
        eval button is interactive. The button is enabled only when both
        models are selected and the configurations are not identical.
    """
    configs_ok, message = validate_model_configs(
        model1, model1_shots, model1_regex, model1_flash,
        model2, model2_shots, model2_regex, model2_flash
    )

    unselected = "(Select Model)"
    if any(choice == unselected for choice in (model1, model2)):
        # Nothing to compare yet: keep the error hidden and the button disabled.
        return gr.update(visible=False), gr.update(interactive=False)

    if configs_ok:
        return gr.update(visible=False), gr.update(interactive=True)

    # Identical configurations: surface the message and block evaluation.
    return gr.update(visible=True, value=message), gr.update(interactive=False)
def get_model_configs(model1, model1_shots, model1_regex, model1_flash,
                      model2, model2_shots, model2_regex, model2_flash):
    """Bundle the two models' settings into a nested dict for evaluation.

    Returns:
        dict: ``{"model1": {...}, "model2": {...}}`` where each inner dict has
        the keys "name", "shots", "regex" and "flash_attention".
    """
    field_names = ("name", "shots", "regex", "flash_attention")
    per_model = {
        "model1": (model1, model1_shots, model1_regex, model1_flash),
        "model2": (model2, model2_shots, model2_regex, model2_flash),
    }
    return {
        key: dict(zip(field_names, settings))
        for key, settings in per_model.items()
    }