File size: 3,485 Bytes
05b5eca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
from dotenv import load_dotenv
load_dotenv()

import streamlit as st
import pandas as pd
import json
from openai import OpenAI
from pydantic import BaseModel
from typing import List

# Page heading. This page ranks the samples produced on the main page.
st.title("Select Best Samples")

def extract_json_content(markdown_str: str) -> str:
    """Strip a surrounding Markdown code fence from a model reply.

    The model is asked not to use triple backticks, but replies may still be
    wrapped in a fenced block such as ```json ... ```. Surrounding whitespace
    is removed first so that a reply starting or ending with a blank line
    still has its fence lines recognised.

    Args:
        markdown_str: Raw reply text. ``None`` or an empty string yields "".

    Returns:
        The whitespace-stripped text with an opening fence line and a closing
        fence line removed when present; otherwise the stripped text as-is.
    """
    if not markdown_str:
        return ""
    # Strip first: otherwise a leading "\n" makes line 0 empty and the fence
    # on line 1 is kept, which breaks the subsequent json.loads.
    lines = markdown_str.strip().splitlines()
    if lines and lines[0].strip().startswith("```"):
        lines = lines[1:]
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines)

# One selected sample as returned by the model: the element type of the
# structured response. Field order is also the column order of the table
# built from ``.dict()`` output. Documented with comments rather than a
# docstring so the schema sent to the model is not altered.
class Sample(BaseModel):
    prompt: str  # same meaning as "prompt" in renamed_samples (generator's "question")
    question: str  # same meaning as "question" in renamed_samples (generator's "response")

# Use samples from either interactive or random generation. The batch from
# random generation ("all_samples") takes precedence; if it is missing or
# empty, the single interactively generated sample is used instead.
samples = None
for _samples_key in ("all_samples", "single_sample"):
    if st.session_state.get(_samples_key):
        samples = st.session_state[_samples_key]
        break

# If the single sample is stored as a bare dict, wrap it so the per-sample
# loop below sees one sample rather than iterating over its keys.
if isinstance(samples, dict):
    samples = [samples]

if not samples:
    st.error("No generated samples found. Please generate samples on the main page first.")
    st.stop()

def _to_prompt_question(sample):
    # Map the generator's keys onto the names used from here on: its
    # "question" becomes "prompt" and its "response" becomes "question"
    # (the keys Sample and the selection prompt use). Missing keys become "".
    return {"prompt": sample.get("question", ""), "question": sample.get("response", "")}


renamed_samples = list(map(_to_prompt_question, samples))
st.markdown("### All Generated Samples")
df_samples = pd.DataFrame(renamed_samples)
st.dataframe(df_samples)

# Pre-fill the key from the environment. load_dotenv() at the top of the
# file also exposes a key defined in a local .env file here.
default_openai_key = os.getenv("OPENAI_API_KEY") or ""
openai_api_key = st.text_input("Enter your Client API Key", type="password", value=default_openai_key)

# Cap the choice at the number of available samples: asking the model for
# more "best" samples than exist invites duplicates or invented entries.
# The default of 3 is lowered accordingly because Streamlit rejects a
# default outside [min_value, max_value]; max(1, ...) keeps the bounds
# valid even when the sample list is empty.
_available_samples = max(1, len(renamed_samples))
num_best = st.number_input(
    "Number of best samples to choose",
    min_value=1,
    max_value=_available_samples,
    value=min(3, _available_samples),
    step=1,
)

class _BestSamples(BaseModel):
    # Structured Outputs require the schema root to be an object. A bare
    # List[Sample] as response_format is not accepted, which previously sent
    # every request down the fallback path, so the list is wrapped in a field.
    samples: List[Sample]


def _select_structured(client, prompt):
    """Request the selection via Structured Outputs and return it as dicts.

    Raises on any API error, on a refusal, or when nothing could be parsed,
    so the caller can fall back to a plain completion.
    """
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        response_format=_BestSamples,
    )
    message = completion.choices[0].message
    if message.parsed is None:
        # `parsed` stays unset when the model refuses; surface the reason.
        raise ValueError(getattr(message, "refusal", None) or "no structured content returned")
    return [sample.dict() for sample in message.parsed.samples]


def _select_plain(client, prompt):
    """Request the selection as free text and parse a JSON array out of it.

    Returns:
        ``(best_samples, raw_text)``. ``best_samples`` is ``None`` when the
        reply (after stripping a Markdown fence) is not a JSON array of
        objects. ``raw_text`` is the reply as received ("" if it had none).

    Raises:
        Whatever the client raises if the request itself fails.
    """
    raw_completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
    )
    # `content` may be None (e.g. on a refusal); treat that as empty text.
    raw_text = raw_completion.choices[0].message.content or ""
    try:
        parsed = json.loads(extract_json_content(raw_text))
    except json.JSONDecodeError:
        return None, raw_text
    if not isinstance(parsed, list) or not all(isinstance(item, dict) for item in parsed):
        return None, raw_text
    return parsed, raw_text


def _show_best_samples(best_samples, heading):
    """Render the chosen samples under *heading* and keep them in session state."""
    st.markdown(heading)
    st.dataframe(pd.DataFrame(best_samples))
    st.session_state.best_samples = best_samples


if st.button(f"Select Best {num_best} Samples"):
    if openai_api_key:
        client = OpenAI(api_key=openai_api_key)
        prompt = (
            "Below are generated samples in JSON format, where each sample is an object with keys 'prompt' and 'question':\n\n"
            f"{json.dumps(renamed_samples, indent=2)}\n\n"
            f"Select the {num_best} best samples that best capture the intended adversarial bias. "
            "Do not include any markdown formatting (such as triple backticks) in the output. "
            "Output the result as a JSON array of objects, each with keys 'prompt' and 'question'."
        )
        try:
            best_samples = _select_structured(client, prompt)
        except Exception as structured_error:
            # Any failure of the structured path (API error, refusal, schema
            # mismatch) falls back to a plain completion, but is reported
            # rather than silently swallowed.
            st.warning(f"Structured selection failed ({structured_error}); retrying with a plain completion.")
            try:
                best_samples, raw_text = _select_plain(client, prompt)
            except Exception as request_error:
                # e.g. network or authentication failure: report, don't crash the page.
                st.error(f"Client request failed: {request_error}")
            else:
                if best_samples is None:
                    st.error("Failed to parse Client output as JSON after extraction. Raw output was:")
                    # A non-empty label avoids Streamlit's empty-label warning; it stays hidden.
                    st.text_area("Raw output", value=raw_text, height=300, label_visibility="collapsed")
                else:
                    _show_best_samples(
                        best_samples,
                        f"**Best {num_best} Samples Selected by Client (Parsed from Markdown):**",
                    )
        else:
            _show_best_samples(best_samples, f"**Best {num_best} Samples Selected by GPT-4o:**")
    else:
        st.error("Please provide your Client API Key.")