import asyncio
from importlib import import_module

import pandas as pd
import streamlit as st
import weave
from dotenv import load_dotenv

from guardrails_genie.guardrails import GuardrailManager
from guardrails_genie.llm import OpenAIModel
from guardrails_genie.metrics import AccuracyMetric
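

# Populate st.session_state with defaults so widget reads never raise
# missing-key errors across Streamlit reruns.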
def initialize_session_state():
    load_dotenv()
    if "weave_project_name" not in st.session_state:
        st.session_state.weave_project_name = "guardrails-genie"
    if "uploaded_file" not in st.session_state:
        st.session_state.uploaded_file = None
    if "dataset_name" not in st.session_state:
        st.session_state.dataset_name = None
    if "preview_in_app" not in st.session_state:
        st.session_state.preview_in_app = False
    if "is_dataset_published" not in st.session_state:
        st.session_state.is_dataset_published = False
    if "publish_dataset_button" not in st.session_state:
        st.session_state.publish_dataset_button = False
    if "dataset_ref" not in st.session_state:
        st.session_state.dataset_ref = None
    if "guardrails" not in st.session_state:
        st.session_state.guardrails = []
    if "guardrail_names" not in st.session_state:
        st.session_state.guardrail_names = []
    if "start_evaluations_button" not in st.session_state:
        st.session_state.start_evaluations_button = False
    if "evaluation_name" not in st.session_state:
        st.session_state.evaluation_name = ""
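

# Instantiate the selected guardrails by resolving each class name from the
# guardrails_genie.guardrails module. Guardrails that need extra configuration
# get their own sidebar widgets.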
def initialize_guardrails():
    st.session_state.guardrails = []
    for guardrail_name in st.session_state.guardrail_names:
        if guardrail_name == "PromptInjectionSurveyGuardrail":
            survey_guardrail_model = st.sidebar.selectbox(
                "Survey Guardrail LLM", ["", "gpt-4o-mini", "gpt-4o"]
            )
            if survey_guardrail_model:
                st.session_state.guardrails.append(
                    getattr(
                        import_module("guardrails_genie.guardrails"),
                        guardrail_name,
                    )(llm_model=OpenAIModel(model_name=survey_guardrail_model))
                )
        elif guardrail_name == "PromptInjectionClassifierGuardrail":
            classifier_model_name = st.sidebar.selectbox(
                "Classifier Guardrail Model",
                [
                    "",
                    "ProtectAI/deberta-v3-base-prompt-injection-v2",
                    "wandb://geekyrakshit/guardrails-genie/model-6rwqup9b:v3",
                ],
            )
            if classifier_model_name != "":
                st.session_state.guardrails.append(
                    getattr(
                        import_module("guardrails_genie.guardrails"),
                        guardrail_name,
                    )(model_name=classifier_model_name)
                )
        elif guardrail_name in (
            "PresidioEntityRecognitionGuardrail",
            "RegexEntityRecognitionGuardrail",
            "TransformersEntityRecognitionGuardrail",
            "RestrictedTermsJudge",
        ):
            # These four guardrails share the same constructor signature, so a
            # single branch covers all of them.
            st.session_state.guardrails.append(
                getattr(
                    import_module("guardrails_genie.guardrails"),
                    guardrail_name,
                )(should_anonymize=True)
            )
        elif guardrail_name == "PromptInjectionLlamaGuardrail":
            llama_guard_checkpoint_name = st.sidebar.text_input(
                "Checkpoint Name", value=""
            )
            st.session_state.llama_guard_checkpoint_name = llama_guard_checkpoint_name
            st.session_state.guardrails.append(
                getattr(
                    import_module("guardrails_genie.guardrails"),
                    guardrail_name,
                )(
                    # An empty checkpoint name falls back to the default weights.
                    checkpoint=(
                        None
                        if st.session_state.llama_guard_checkpoint_name == ""
                        else st.session_state.llama_guard_checkpoint_name
                    )
                )
            )
        else:
            # Remaining guardrails take no constructor arguments.
            st.session_state.guardrails.append(
                getattr(
                    import_module("guardrails_genie.guardrails"),
                    guardrail_name,
                )()
            )
    st.session_state.guardrails_manager = GuardrailManager(
        guardrails=st.session_state.guardrails
    )
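

# Page flow: configure the Weave project, publish the uploaded CSV as a Weave
# Dataset, select and configure guardrails, then run the evaluation.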
initialize_session_state()
st.title(":material/monitoring: Evaluation")

weave_project_name = st.sidebar.text_input(
    "Weave project name", value=st.session_state.weave_project_name
)
st.session_state.weave_project_name = weave_project_name
if st.session_state.weave_project_name != "":
    weave.init(project_name=st.session_state.weave_project_name)
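
# Upload the CSV that will be published as a Weave Dataset for evaluation.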
uploaded_file = st.sidebar.file_uploader(
    "Upload the evaluation dataset as a CSV file", type="csv"
)
st.session_state.uploaded_file = uploaded_file
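
# Dataset naming, preview, and publishing are gated on a file being uploaded.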
if st.session_state.uploaded_file is not None:
    dataset_name = st.sidebar.text_input("Evaluation dataset name", value=None)
    st.session_state.dataset_name = dataset_name
    preview_in_app = st.sidebar.toggle("Preview in app", value=False)
    st.session_state.preview_in_app = preview_in_app
    publish_dataset_button = st.sidebar.button("Publish dataset")
    st.session_state.publish_dataset_button = publish_dataset_button

    # Publishing requires a non-empty dataset name.
    if st.session_state.publish_dataset_button and (
        st.session_state.dataset_name is not None
        and st.session_state.dataset_name != ""
    ):
        with st.expander("Evaluation Dataset Preview", expanded=True):
            dataframe = pd.read_csv(st.session_state.uploaded_file)
            data_list = dataframe.to_dict(orient="records")
            dataset = weave.Dataset(name=st.session_state.dataset_name, rows=data_list)
            st.session_state.dataset_ref = weave.publish(dataset)

            # Build a direct link to the published dataset version in the Weave UI.
            entity = st.session_state.dataset_ref.entity
            project = st.session_state.dataset_ref.project
            dataset_name = st.session_state.dataset_name
            digest = st.session_state.dataset_ref._digest
            dataset_url = (
                f"https://wandb.ai/{entity}/{project}/weave/objects/"
                f"{dataset_name}/versions/{digest}"
            )
            st.markdown(f"Dataset published to [**Weave**]({dataset_url})")

            if preview_in_app:
                st.dataframe(dataframe.head(20))
                if len(dataframe) > 20:
                    st.markdown(
                        "⚠️ Dataset is too large to preview in app, "
                        f"please explore it in the [**Weave UI**]({dataset_url})"
                    )

        st.session_state.is_dataset_published = True
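
    # Once a dataset is published, offer every guardrail class exported by
    # guardrails_genie.guardrails (except GuardrailManager itself) for selection.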
    if st.session_state.is_dataset_published:
        guardrail_names = st.sidebar.multiselect(
            "Select Guardrails",
            options=[
                cls_name
                for cls_name, cls_obj in vars(
                    import_module("guardrails_genie.guardrails")
                ).items()
                if isinstance(cls_obj, type) and cls_name != "GuardrailManager"
            ],
        )
        st.session_state.guardrail_names = guardrail_names
        initialize_guardrails()
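
        # Optional run name; it is appended to the Weave call display name below.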
        evaluation_name = st.sidebar.text_input("Evaluation Name", value="")
        st.session_state.evaluation_name = evaluation_name
        start_evaluations_button = st.sidebar.button("Start Evaluations")
        st.session_state.start_evaluations_button = start_evaluations_button
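
        # Kick off the evaluation. Note that streamlit_mode is not a stock
        # weave.Evaluation argument; it appears to be a project-specific
        # extension for rendering progress inside Streamlit.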
        if st.session_state.start_evaluations_button:
            evaluation = weave.Evaluation(
                dataset=st.session_state.dataset_ref,
                scorers=[AccuracyMetric()],
                streamlit_mode=True,
            )
            with st.expander("Evaluation Results", expanded=True):
                # .call() returns both the evaluation summary and the Weave
                # call object, whose ui_url links to the full trace table.
                evaluation_summary, call = asyncio.run(
                    evaluation.evaluate.call(
                        evaluation,
                        GuardrailManager(guardrails=st.session_state.guardrails),
                        __weave={
                            "display_name": (
                                "Evaluation.evaluate"
                                if st.session_state.evaluation_name == ""
                                else "Evaluation.evaluate:"
                                + st.session_state.evaluation_name
                            )
                        },
                    )
                )
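
                # Plot each summary statistic reported by AccuracyMetric as a bar.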
                x_axis = list(evaluation_summary["AccuracyMetric"].keys())
                y_axis = [
                    evaluation_summary["AccuracyMetric"][x_axis_item]
                    for x_axis_item in x_axis
                ]
                st.bar_chart(
                    pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
                    x="Metric",
                    y="Score",
                )
                st.markdown(
                    f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
                )