import os
import requests
import tempfile

import streamlit as st
from functools import partial
from tempfile import NamedTemporaryFile
from typing import Callable, List, Literal, Optional

from streamlit.runtime.uploaded_file_manager import UploadedFile
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.utilities import BingSearchAPIWrapper
from langchain_community.document_loaders import (
    Docx2txtLoader,
    PyPDFLoader,
    TextLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.tools import Tool
from langchain.tools.retriever import create_retriever_tool
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_community.agent_toolkits.load_tools import load_tools
from langchain.pydantic_v1 import BaseModel, Field
from TTS.api import TTS

tts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
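# NOTE: the XTTS v2 model above is loaded once at import time and runs on
# CPU (gpu=False), which keeps the app usable on hardware without a GPU at
# the cost of slower synthesis. tts_to_file() expects ISO-style language
# codes ("en", "hi", "es", ...), so generate_speech() below maps the full
# language names used in the UI to those codes.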


def initialize_session_state_variables() -> None:
    """
    Initialize all the session state variables.
    """
    session_defaults = {
        "ready": False,
        "bing_subscription_validity": False,
        "model": "gpt-4o",
        "language": "English",
        "topic": "",
        "positive": "",
        "negative": "",
        "agent_descriptions": {},
        "new_debate": True,
        "conversations": [],
        "conversations4print": [],
        "simulator": None,
        "tools": [],
        "retriever_tool": None,
        "vector_store_message": "",
        "conclusions": "",
        "comments_key": 0,
        "specified_topic": "",
    }
    for key, value in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def initialize_api_keys() -> None:
    """
    Initialize API keys from Hugging Face secrets and warn if any are missing.
    """
    os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY", "")
    os.environ["BING_SUBSCRIPTION_KEY"] = os.getenv("BING_SUBSCRIPTION_KEY", "")

    if not os.environ["OPENAI_API_KEY"]:
        st.error("Missing OPENAI_API_KEY. Add it to Hugging Face Space secrets.")
    if not os.environ["BING_SUBSCRIPTION_KEY"]:
        st.warning("Missing BING_SUBSCRIPTION_KEY. Bing search may not work.")


initialize_api_keys()


def is_openai_api_key_valid() -> bool:
    """
    Validate the OpenAI API key from Hugging Face secrets.
    """
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if not openai_api_key:
        return False
    headers = {"Authorization": f"Bearer {openai_api_key}"}
    response = requests.get(
        "https://api.openai.com/v1/models", headers=headers, timeout=10
    )
    return response.status_code == 200


def is_bing_subscription_key_valid() -> bool:
    """
    Validate the Bing Subscription key from Hugging Face secrets.
    """
    bing_subscription_key = os.environ.get("BING_SUBSCRIPTION_KEY")
    if not bing_subscription_key:
        return False
    try:
        bing_search = BingSearchAPIWrapper(
            bing_subscription_key=bing_subscription_key,
            bing_search_url="https://api.bing.microsoft.com/v7.0/search",
            k=1,
        )
        bing_search.run("Test Query")
    except Exception:
        return False
    return True


def check_api_keys() -> None:
    """
    Unset the ready flag so that the API keys are validated again.
    """
    st.session_state.ready = False


def append_period(text: str) -> str:
    """
    Append a '.' to the input text
    if it is nonempty and does not end with '.' or '?'.
    """
    if text and text[-1] not in (".", "?"):
        text += "."
    return text


def get_vector_store(uploaded_files: List[UploadedFile]) -> Optional[FAISS]:
    """
    Take a list of UploadedFile objects as input,
    and return a FAISS vector store.
    """
    if not uploaded_files:
        return None

    documents = []
    filepaths = []
    loader_map = {
        ".pdf": PyPDFLoader,
        ".txt": TextLoader,
        ".docx": Docx2txtLoader,
    }
    try:
        for uploaded_file in uploaded_files:
            # Write the uploaded bytes to a temporary file, since the
            # document loaders expect a filesystem path.
            with NamedTemporaryFile(dir="files/", delete=False) as file:
                file.write(uploaded_file.getbuffer())
                filepath = file.name
                filepaths.append(filepath)

            file_ext = os.path.splitext(uploaded_file.name.lower())[1]
            loader_class = loader_map.get(file_ext)
            if not loader_class:
                st.error(f"Unsupported file type: {file_ext}", icon="🚨")
                for filepath in filepaths:
                    if os.path.exists(filepath):
                        os.remove(filepath)
                return None

            loader = loader_class(filepath)
            documents.extend(loader.load())

        with st.spinner("Vector DB in preparation..."):
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200,
            )
            split_docs = text_splitter.split_documents(documents)
            embeddings = OpenAIEmbeddings(
                model="text-embedding-3-large", dimensions=1536
            )
            vector_store = FAISS.from_documents(split_docs, embeddings)
    except Exception as e:
        vector_store = None
        st.error(f"An error occurred: {e}", icon="🚨")
    finally:
        # Remove the temporary copies of the uploaded files.
        for filepath in filepaths:
            if os.path.exists(filepath):
                os.remove(filepath)

    return vector_store
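# The chunking above (1000 characters with 200 overlap) and the
# 1536-dimension text-embedding-3-large embeddings are reasonable defaults
# rather than tuned values; the FAISS index lives in memory and is rebuilt
# each time documents are uploaded.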


def get_retriever() -> None:
    """
    Upload document(s), create a vector store, prepare a retriever tool,
    and save the tool to st.session_state.retriever_tool.
    """
    st.write("")
    st.write("##### Document(s) to ask about")
    uploaded_files = st.file_uploader(
        label="Upload an article",
        type=["txt", "pdf", "docx"],
        accept_multiple_files=True,
        label_visibility="collapsed",
    )

    left, right = st.columns(2)
    if left.button(label=r"$\:\!$Create a vector DB$\,$"):
        vector_store = get_vector_store(uploaded_files)
        if vector_store is not None:
            retriever = vector_store.as_retriever()
            st.session_state.retriever_tool = create_retriever_tool(
                retriever,
                name="retriever",
                description=(
                    "Search for information about the uploaded documents. "
                    "For any questions about the documents, you must use "
                    "this tool!"
                ),
            )
            st.session_state.vector_store_message = "Vector DB is ready!"

    if st.session_state.vector_store_message:
        right.write(f":blue[{st.session_state.vector_store_message}]")


class DialogueAgent:
    """
    Class for an individual agent participating in the debate.
    """

    def __init__(
        self,
        name: str,
        system_message: SystemMessage,
        llm: ChatOpenAI,
        tools: List[Tool],
    ) -> None:
        self.name = name
        self.system_message = system_message
        self.llm = llm
        self.prefix = f"{self.name}: "
        self.tools = tools
        self.reset()

    def reset(self) -> None:
        self.message_history = ["\nHere is the conversation so far.\n"]

    def send(self) -> str:
        """
        Apply the llm to the message history and return the message string.
        """
        chat_prompt_list = [
            ("system", "You are a helpful assistant."),
            ("human", "{input}"),
        ]
        agent_prompt_list = chat_prompt_list + [
            MessagesPlaceholder(variable_name="agent_scratchpad")
        ]
        chat_prompt = ChatPromptTemplate.from_messages(chat_prompt_list)
        agent_prompt = ChatPromptTemplate.from_messages(agent_prompt_list)

        # With tools, wrap the LLM in an OpenAI-tools agent so it can call
        # them; otherwise fall back to a plain prompt | llm chain.
        if self.tools:
            agent = create_openai_tools_agent(
                self.llm, self.tools, agent_prompt
            )
            agent_executor = AgentExecutor(
                agent=agent, tools=self.tools, verbose=False
            )
        else:
            agent_executor = chat_prompt | self.llm

        output = agent_executor.invoke(
            {
                "input": "\n".join(
                    [self.system_message.content]
                    + self.message_history
                    + [self.prefix]
                )
            }
        )
        message = output["output"] if self.tools else output.content
        return message

    def receive(self, name: str, message: str) -> None:
        """
        Concatenate {message} spoken by {name} into the message history.
        """
        self.message_history.append(f"{name}: {message}\n")
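# Illustrative sketch (not executed): outside Streamlit, an agent could be
# driven like this, assuming a valid OPENAI_API_KEY in the environment.
#
#     agent = DialogueAgent(
#         name="Optimist",
#         system_message=SystemMessage(content="Argue the positive side."),
#         llm=ChatOpenAI(model="gpt-4o-mini"),
#         tools=[],
#     )
#     agent.receive("Moderator", "Is automation good for employment?")
#     print(agent.send())
#
# The name, system message, and question above are hypothetical examples.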


class DialogueSimulator:
    """
    Class for simulating the debate.
    """

    def __init__(
        self,
        agents: List[DialogueAgent],
        selection_function: Callable[[int, List[DialogueAgent]], int],
    ) -> None:
        self.agents = agents
        self._step = 0
        self.select_next_speaker = selection_function

    def reset(self) -> None:
        for agent in self.agents:
            agent.reset()

    def inject(self, name: str, message: str) -> None:
        """
        Initiate the conversation with a {message} from {name}.
        """
        for agent in self.agents:
            agent.receive(name, message)

    def step(self) -> tuple[str, str]:
        # Pick the next speaker, let it generate a message, and broadcast
        # that message to every agent (including the speaker itself).
        speaker_idx = self.select_next_speaker(self._step, self.agents)
        speaker = self.agents[speaker_idx]

        try:
            with st.spinner(f"{speaker.name} is thinking..."):
                message = speaker.send()
        except Exception as e:
            st.error(f"An error occurred: {e}", icon="🚨")
            st.stop()

        for receiver in self.agents:
            receiver.receive(speaker.name, message)

        self._step += 1
        return speaker.name, message


def select_next_speaker(step: int, agents: List[DialogueAgent]) -> int:
    """
    Return 0, 1, ..., or (the number of agents - 1) corresponding
    to the next speaker.
    """
    return step % len(agents)
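# With the two debaters used here, this round-robin selection yields
# 0, 1, 0, 1, ..., so the positive side (agent 0) opens every round.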


def generate_speech(text, speaker_wav=None, language="English"):
    """
    Generate speech using xtts-v2. Use the default voice if no speaker WAV
    is provided.

    Args:
        text (str): Text to synthesize.
        speaker_wav (UploadedFile): Uploaded speaker WAV file for voice
            cloning (optional).
        language (str): Language of the text, e.g. "English".

    Returns:
        str: Path to the generated audio file, or None on failure.
    """
    # XTTS v2 expects ISO-style language codes, while the UI stores full
    # language names; unknown names fall back to English.
    language_codes = {
        "English": "en", "Hindi": "hi", "Spanish": "es", "French": "fr",
        "Chinese": "zh-cn", "Korean": "ko", "Japanese": "ja",
    }
    language_code = language_codes.get(language, "en")

    try:
        # An uploaded WAV only exists in memory; write it to disk so the
        # TTS model can read it as the reference voice.
        speaker_wav_path = None
        if speaker_wav is not None:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as ref:
                ref.write(speaker_wav.getbuffer())
                speaker_wav_path = ref.name

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            tts_model.tts_to_file(
                text=text,
                file_path=temp_audio.name,
                speaker_wav=speaker_wav_path,
                language=language_code,
            )
            return temp_audio.name
    except Exception as e:
        st.error(f"Error generating speech: {e}")
        return None
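# Illustrative sketch (not executed): generate_speech("Hello!", None, "English")
# would attempt to synthesize the text without a reference voice and return
# the path of a temporary WAV file, or None if synthesis fails.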


def run_simulator(no_of_rounds: int, simulator: DialogueSimulator) -> None:
    """
    Simulate a given number of rounds for the debate.
    Add TTS audio playback for each dialogue generated.
    """
    max_iters = 2 * no_of_rounds  # two speakers take one turn each per round
    turn = 0

    positive_speaker_wav = st.session_state.get("positive_speaker_wav", None)
    negative_speaker_wav = st.session_state.get("negative_speaker_wav", None)

    while turn < max_iters:
        name, message = simulator.step()
        color = "blue" if turn % 2 == 0 else "red"
        message4print = f"**:{color}[{name}]**: {message}"

        st.session_state.conversations.append(f"{name}: {message}")
        st.session_state.conversations4print.append(message4print)

        st.write(message4print)

        # Choose the reference voice that matches the current speaker.
        speaker_wav = (
            positive_speaker_wav
            if name == st.session_state.positive
            else negative_speaker_wav
        )
        audio_path = generate_speech(message, speaker_wav, st.session_state.language)

        if audio_path:
            st.audio(audio_path, format="audio/wav")

            with open(audio_path, "rb") as audio_file:
                st.download_button(
                    label=f"Download {name}'s Response Audio",
                    data=audio_file,
                    file_name=f"{name}_response.wav",
                    mime="audio/wav",
                    key=f"download_audio_{turn}",
                )

        turn += 1


def generate_agent_description(
    name: str,
    conversation_description: str,
    language: Literal[
        "English", "Hindi", "Spanish", "French", "Chinese", "Korean", "Japanese"
    ],
    word_limit: int,
) -> str:
    """
    Generate the description for a participant.
    """
    agent_specifier_prompt = [
        SystemMessage(
            content=(
                "You can add detail to the description of "
                "the conversation participant."
            )
        ),
        HumanMessage(
            content=(
                f"{conversation_description}\n"
                f"Please reply with a creative description of '{name}', "
                f"in {word_limit} words or less in {language}.\n"
                f"Speak directly to '{name}'.\n"
                "Give them a point of view.\n"
                "Do not add anything else."
            )
        ),
    ]
    agent_specifier_llm = ChatOpenAI(
        model=st.session_state.model, temperature=1.0
    )
    agent_description = agent_specifier_llm.invoke(agent_specifier_prompt)

    return agent_description.content


def generate_system_message(
    name: str,
    conversation_description: str,
    description: str,
    language: Literal[
        "English", "Hindi", "Spanish", "French", "Chinese", "Korean", "Japanese"
    ],
    word_limit: int,
) -> str:
    """
    Generate the system message for a participant.
    """
    if description:
        description_statement = (
            f"Your description is as follows: {description}\n\n"
        )
    else:
        description_statement = ""

    generated_system_message = (
        f"{conversation_description}\n\n"
        f"Your name is '{name}'.\n\n"
        f"{description_statement}"
        "Your goal is to persuade your conversation partner "
        "of your point of view.\n\n"
        "DO look up information with your tool "
        "to refute your partner's claims.\n"
        "DO cite your sources.\n\n"
        "DO NOT fabricate fake citations.\n"
        "DO NOT cite any source that you did not look up.\n\n"
        "DO NOT restate something that has been said in the past.\n"
        "Do not add anything else.\n\nStop speaking the moment "
        "you finish speaking from your perspective.\n\n"
        f"Answer in {word_limit} words or less in {language}."
    )
    return generated_system_message


def get_participant_names(topic: str) -> List[str]:
    """
    Get the names of the positive and negative for the debate.
    """
    participants = ["positive", "negative"]
    participant_names = []

    for participant in participants:
        ex = "AI alarmist" if participant == "negative" else "AI accelerationist"
        name_specifier_prompt = [
            SystemMessage(content="You are a helpful moderator for a debate."),
            HumanMessage(
                content=(
                    "Here is the topic of conversation: "
                    f"{append_period(topic)}\n"
                    f"For the {participant} perspective on the topic, "
                    "write a name in three words or less. Start the name "
                    "with a capital letter and do not use ':' .\n"
                    "For example, for the topic 'The current impact of "
                    "automation and artificial intelligence on employment', "
                    f"'{ex}' could serve as an appropriate name for "
                    f"the {participant} side.\n"
                    "Use a common noun instead of a proper noun, "
                    "as shown in the example."
                )
            ),
        ]
        name_specifier_llm = ChatOpenAI(
            model=st.session_state.model, temperature=1.0
        )
        participant_name = name_specifier_llm.invoke(
            name_specifier_prompt
        ).content
        participant_names.append(participant_name)

    return participant_names
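# Illustrative only: for the topic "The impact of AI on education", this
# might return something like ["EdTech optimist", "Cautious educator"];
# the actual names depend on the model output.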


def continue_debate() -> None:
    """
    Unset the new debate flag to signal that the debate has been set up.
    """
    st.session_state.new_debate = False


def reset_debate() -> None:
    """
    Reset all the session state variables.
    """
    st.session_state.topic = ""
    st.session_state.language = "English"
    st.session_state.positive = ""
    st.session_state.negative = ""
    st.session_state.agent_descriptions = {}
    st.session_state.specified_topic = ""
    st.session_state.new_debate = True
    st.session_state.conversations = []
    st.session_state.conversations4print = []
    st.session_state.simulator = None
    st.session_state.names = {}
    st.session_state.tools = []
    st.session_state.retriever_tool = None
    st.session_state.vector_store_message = ""
    st.session_state.conclusions = ""
    st.session_state.comments_key = 0


def set_tools() -> None:
    """
    Set the tools for the agents. Tools that can be selected are
    bing_search, arxiv, wikipedia, and retrieval.
    """

    class MySearchToolInput(BaseModel):
        query: str = Field(description="search query to look up")

    arxiv = load_tools(["arxiv"])[0]
    wikipedia = load_tools(["wikipedia"])[0]

    tool_options = ["ArXiv", "Wikipedia", "Retrieval"]
    tool_dictionary = {"ArXiv": arxiv, "Wikipedia": wikipedia}

    if st.session_state.bing_subscription_validity:
        search = BingSearchAPIWrapper()
        bing_search = Tool(
            name="bing_search",
            description=(
                "A search engine for comprehensive, accurate, and trusted results. "
                "Useful for when you need to answer questions about current events. "
                "Input should be a search query."
            ),
            func=partial(search.results, num_results=5),
            args_schema=MySearchToolInput,
        )
        tool_options.insert(0, "Search")
        tool_dictionary["Search"] = bing_search

    st.write("**Tools**")
    st.session_state.selected_tools = st.multiselect(
        label="agent tools",
        options=tool_options,
        label_visibility="collapsed",
    )
    if "Search" not in tool_options:
        st.write(
            "<small>To search the internet, obtain your Bing Subscription "
            "Key [here](https://portal.azure.com/) and add it to your "
            "Hugging Face Space secrets as BING_SUBSCRIPTION_KEY. Once set, "
            "'Search' will be displayed in the list of tools.</small>",
            unsafe_allow_html=True,
        )
    if "Retrieval" in st.session_state.selected_tools:
        # The retriever tool is only available once a vector DB has been built.
        get_retriever()
        if st.session_state.retriever_tool is not None:
            tool_dictionary["Retrieval"] = st.session_state.retriever_tool
        else:
            st.session_state.selected_tools.remove("Retrieval")

    st.session_state.tools = [
        tool_dictionary[key] for key in st.session_state.selected_tools
    ]


def set_debate() -> None:
    """
    Prepare the agents for the debate by setting the topic, names,
    descriptions of the participants, and the questions for the debate,
    uploading speaker WAVs, and allowing the use of default voices.
    """
    st.write("**Upload Speaker WAV Files** (Optional)")

    st.session_state.positive_speaker_wav = st.file_uploader(
        label="Upload WAV for Positive Debater (Optional)",
        type=["wav"],
        key="positive_speaker",
    )
    st.session_state.negative_speaker_wav = st.file_uploader(
        label="Upload WAV for Negative Debater (Optional)",
        type=["wav"],
        key="negative_speaker",
    )

    st.write("**Topic of the debate**")
    topic = st.text_input(
        label="topic of the debate",
        placeholder="Enter your topic",
        value=st.session_state.topic,
        label_visibility="collapsed",
    )
    st.session_state.topic = topic.strip()
    st.write(
        "**Language** "
        "<small>used by the debaters</small>",
        unsafe_allow_html=True
    )
    st.session_state.language = st.radio(
        label="language",
        options=("English", "Hindi", "Spanish", "French", "Chinese", "Korean", "Japanese"),
        label_visibility="collapsed",
        index=1,
        horizontal=True
    )
    st.write("**Model**")
    st.session_state.model = st.radio(
        label="Model",
        options=("gpt-4o-mini", "gpt-4o"),
        label_visibility="collapsed",
        horizontal=True,
        index=1,
    )

    set_tools()

    left, right = st.columns(2)
    left.write("**Word limit for question suggestions** (≥ 10)")
    description_word_limit = left.number_input(
        label="description_word_limit",
        min_value=10,
        max_value=500,
        value=20,
        step=10,
        label_visibility="collapsed"
    )

    right.write("**Word limit for each debate response** (≥ 50)")
    st.session_state.word_limit = right.number_input(
        label="answer_word_limit",
        min_value=50,
        max_value=2000,
        value=100,
        step=50,
        label_visibility="collapsed"
    )

    if st.button("Suggest names for the debaters"):
        st.session_state.positive, st.session_state.negative = (
            get_participant_names(topic)
        )

    left, right = st.columns(2)
    left.write("**Name for the positive**")
    positive = left.text_input(
        label="name of the positive",
        value=st.session_state.positive,
        label_visibility="collapsed"
    )
    st.session_state.positive = positive

    right.write("**Name for the negative**")
    negative = right.text_input(
        label="name of the negative",
        value=st.session_state.negative,
        label_visibility="collapsed"
    )
    st.session_state.negative = negative

    # Both debaters share the same set of tools.
    st.session_state.names = {
        positive: st.session_state.tools,
        negative: st.session_state.tools,
    }
    conversation_description = (
        "Here is the topic of conversation: "
        f"{append_period(topic)}\nThe participants are: "
        f"{' and '.join(st.session_state.names.keys())}."
    )

    agent_descriptions, agent_system_messages = {}, {}

    if positive and negative:
        if st.button("Suggest descriptions for the debaters"):
            for name in st.session_state.names.keys():
                st.session_state.agent_descriptions[name] = (
                    generate_agent_description(
                        name,
                        conversation_description,
                        st.session_state.language,
                        description_word_limit
                    )
                )

        for name in st.session_state.names.keys():
            st.write(f"**Description for {name}**")
            agent_descriptions[name] = st.text_area(
                label=f"description for {name}",
                value=st.session_state.agent_descriptions.get(name, ""),
                label_visibility="collapsed"
            )
            st.session_state.agent_descriptions[name] = agent_descriptions[name]

        keys_to_delete = [
            key for key in st.session_state.agent_descriptions
            if key not in st.session_state.names
        ]
        for key in keys_to_delete:
            del st.session_state.agent_descriptions[key]

        for name in st.session_state.names.keys():
            agent_system_messages[name] = generate_system_message(
                name,
                conversation_description,
                agent_descriptions[name],
                st.session_state.language,
                st.session_state.word_limit
            )

        if st.button("Suggest questions for the debaters"):
            topic_specifier_prompt = [
                SystemMessage(content="You can make a topic more specific."),
                HumanMessage(
                    content=(
                        "Here is the topic of conversation: "
                        f"{append_period(topic)}\n"
                        "You are the moderator.\n"
                        "Please make the topic more specific.\n"
                        "Please reply with the specified topic in "
                        f"{description_word_limit} words or less in "
                        f"{st.session_state.language}.\n"
                        "Speak directly to the participants: "
                        f"{*st.session_state.names,}.\n"
                        "Do not add anything else."
                    )
                ),
            ]
            topic_specifier_llm = ChatOpenAI(
                model=st.session_state.model, temperature=1.0
            )
            st.session_state.specified_topic = topic_specifier_llm.invoke(
                topic_specifier_prompt
            ).content

        st.write("**Questions for the debaters**")
        specified_topic = st.text_area(
            label="questions for the debaters",
            value=st.session_state.specified_topic,
            label_visibility="collapsed",
        )
        st.session_state.specified_topic = specified_topic

        if st.session_state.specified_topic:
            if st.button("Prepare the debate"):
                agent_llm = ChatOpenAI(
                    model=st.session_state.model, temperature=0.2
                )
                agents = [
                    DialogueAgent(
                        name=name,
                        system_message=SystemMessage(content=system_message),
                        llm=agent_llm,
                        tools=tools,
                    )
                    for (name, tools), system_message in zip(
                        st.session_state.names.items(),
                        agent_system_messages.values()
                    )
                ]
                st.session_state.simulator = DialogueSimulator(
                    agents=agents, selection_function=select_next_speaker
                )
                st.session_state.simulator.reset()
                st.session_state.simulator.inject("Moderator", specified_topic)
                st.session_state.new_debate = False
                st.rerun()


def print_topic_debaters_questions() -> str:
    """
    Print the topic, the names and descriptions of the participants,
    and the questions for the debate, and return them as a header string.
    """
    st.write("**Topic of the debate**")
    st.info(f"**{st.session_state.topic}**")
    st.write(
        r"**Name for the positive**$\:$: "
        f"$~$:blue[{st.session_state.positive}]"
    )
    st.write(
        "**Name for the negative**: "
        f"$~$:blue[{st.session_state.negative}]"
    )
    agent_descriptions = st.session_state.agent_descriptions
    for name in st.session_state.names.keys():
        st.write(f"**Description for {name}**")
        st.info(agent_descriptions[name])

    st.write("**Moderator**: Here are the questions for the debaters")
    st.info(st.session_state.specified_topic)

    headers = (
        "Topic of the debate: "
        f"{st.session_state.topic}\n\n"
        "Name for the positive: "
        f"{st.session_state.positive}\n"
        "Name for the negative: "
        f"{st.session_state.negative}\n\n"
    )

    if agent_descriptions[st.session_state.positive]:
        headers += (
            f"Description for {st.session_state.positive}:\n"
            f"{agent_descriptions[st.session_state.positive]}\n\n"
        )
    if agent_descriptions[st.session_state.negative]:
        headers += (
            f"Description for {st.session_state.negative}:\n"
            f"{agent_descriptions[st.session_state.negative]}\n\n"
        )

    headers += f"Moderator: {st.session_state.specified_topic}\n\n"
    return headers


def conclude_debate() -> None:
    """
    End the debate by providing a summary of the points raised by
    each participant and making a concluding remark. Add this conclusion
    to the list of conversations.
    """
    word_limit = 2 * st.session_state.word_limit
    moderator_prompt = [
        SystemMessage(
            content=(
                "You are the Moderator. "
                "Your goal is to provide a comprehensive summary "
                "highlighting the key points raised by each participant, "
                "and then to conclude the debate in a productive manner. "
                "If there is a clear standout in terms of being more "
                "persuasive or convincing, mention this in your conclusion."
            )
        ),
        HumanMessage(
            content=(
                f"Answer in {word_limit} words or less "
                f"in {st.session_state.language}.\n\n"
                "Here is the complete conversation.\n\n"
                f"{st.session_state.complete_conversations}\n\n"
                "Moderator: "
            )
        ),
    ]
    moderator_llm = ChatOpenAI(
        model=st.session_state.model, temperature=0.2
    )
    with st.spinner("Moderator is thinking..."):
        st.session_state.conclusions = moderator_llm.invoke(
            moderator_prompt
        ).content

    st.session_state.conversations.append(
        f"Moderator: {st.session_state.conclusions}"
    )
    st.session_state.conversations4print.append(
        f"**Moderator**: {st.session_state.conclusions}"
    )


def multi_agent_debate():
    """
    Let two agents, equipped with tools such as bing search, arxiv,
    and retriever, debate on a given topic. The debate can be concluded
    with a remark and be downloaded.
    """
    page_title = "Multi-lingual Multi-Agent Debate"
    page_icon = "📚"

    st.set_page_config(
        page_title=page_title,
        page_icon=page_icon,
        layout="centered"
    )

    st.write(rf"## {page_icon} $\,${page_title}")

    st.image("./files/image-3.png", caption=" ", use_container_width=True)
    st.info(
        """
        **Acknowledgment**: This project is inspired by [Twy's Work](https://github.com/twy80/Multi_Agent_Debate).
        """
    )

    initialize_session_state_variables()

    with st.sidebar:
        st.write("### API Key Status")
        openai_key_status = "✔️ Available" if os.getenv("OPENAI_API_KEY") else "❌ Missing"
        bing_key_status = "✔️ Available" if os.getenv("BING_SUBSCRIPTION_KEY") else "❌ Missing"

        st.write(f"**OpenAI Key**: {openai_key_status}")
        st.write(f"**Bing Key**: {bing_key_status}")

        if not os.getenv("OPENAI_API_KEY"):
            st.error("OpenAI Key is required for this app to function properly.")

    if st.session_state.new_debate:
        set_debate()
    else:
        with st.sidebar:
            st.write("")
            st.write(f"**Model**: :blue[{st.session_state.model}]")
            st.write(f"**Language**: :blue[{st.session_state.language}]")
            st.write(f"**Word limit**: :blue[{st.session_state.word_limit}]")

            if st.session_state.selected_tools:
                used_tools = (
                    f":blue[{', '.join(st.session_state.selected_tools)}]"
                )
                if len(st.session_state.selected_tools) == 1:
                    st.write(f"**Tool**: {used_tools}")
                else:
                    st.write(f"**Tools**: {used_tools}")
            else:
                st.write("**Tool**: :blue[None]")

        headers = print_topic_debaters_questions()
        st.session_state.complete_conversations = (
            headers + "\n\n".join(st.session_state.conversations)
        )

        if st.session_state.conversations:
            label_debate = r"$\,$Continue the debate$\,$"
            label_no_of_rounds = "Number of additional rounds"
            value_no_of_rounds = 1
        else:
            label_debate = r"$~~~\,$Start the debate$~~~\,$"
            label_no_of_rounds = "Number of rounds in this debate"
            value_no_of_rounds = 5

        st.write("")
        for message in st.session_state.conversations4print:
            st.write(message)

        if not st.session_state.conclusions:
            st.write(f"**{label_no_of_rounds}**")
            c1, _, _ = st.columns(3)
            no_of_rounds = c1.number_input(
                label=label_no_of_rounds,
                min_value=1,
                max_value=10,
                value=value_no_of_rounds,
                step=1,
                label_visibility="collapsed",
            )

        if st.session_state.conversations and not st.session_state.conclusions:
            st.write("**Facilitative comments by the (human) moderator** (Optional)")
            facilitative_comments = st.text_input(
                label="facilitative_comments",
                value="",
                key="comments" + str(st.session_state.comments_key),
                label_visibility="collapsed",
            )
            if facilitative_comments:
                st.session_state.simulator.inject("Moderator", facilitative_comments)
                st.session_state.conversations.append(
                    f"Moderator: {facilitative_comments}"
                )
                st.session_state.conversations4print.append(
                    f"**Moderator**: {facilitative_comments}"
                )
                st.session_state.comments_key += 1

        left, right = st.columns(2)

        if not st.session_state.conclusions:
            if left.button(label_debate):
                run_simulator(no_of_rounds, st.session_state.simulator)
                st.rerun()
            if st.session_state.conversations:
                if right.button(r"Conclude the debate$\,$"):
                    conclude_debate()
                    st.rerun()
        else:
            if right.button(r"$~\:$Back to the setting$~\:$"):
                st.session_state.new_debate = True
                st.rerun()

        left.download_button(
            label="Download the debate",
            data=st.session_state.complete_conversations,
            file_name="multi_agent_debate.txt",
            mime="text/plain"
        )
        right.button(
            label=r"$~~\,\:\!$Reset the debate$~~\,\,$",
            on_click=reset_debate
        )


if __name__ == "__main__":
    multi_agent_debate()
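# To run the app locally, launch it with Streamlit (substitute the actual
# file name of this script):
#   streamlit run <this_script>.py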