Spaces:
Sleeping
Sleeping
File size: 4,287 Bytes
b092c58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import argparse
from dataclasses import asdict
import json
import os
import streamlit as st
from data_driven_characters.character import get_character_definition
from data_driven_characters.corpus import (
get_corpus_summaries,
load_docs,
)
from data_driven_characters.chatbots import (
SummaryChatBot,
RetrievalChatBot,
SummaryRetrievalChatBot,
)
from data_driven_characters.interfaces import CommandLine, Streamlit
# Root directory under which per-corpus caches (summaries, character
# definitions) are written by create_chatbot().
OUTPUT_ROOT = "output"
def create_chatbot(corpus, character_name, chatbot_type, retrieval_docs, summary_type):
    """Build a character chatbot grounded in the given corpus.

    Args:
        corpus: path to the corpus text file.
        character_name: name of the character to emulate.
        chatbot_type: one of "summary", "retrieval", "summary_retrieval".
        retrieval_docs: "raw" (re-chunked corpus) or "summarized"
            (use the generated corpus summaries) as the retrieval source.
        summary_type: summarization strategy passed to get_corpus_summaries.

    Returns:
        The constructed chatbot instance.

    Raises:
        ValueError: if retrieval_docs or chatbot_type is unrecognized.
    """
    # Cache layout: output/<corpus stem>/summarytype_<type>/{summaries,character_definitions}
    corpus_stem = os.path.splitext(os.path.basename(corpus))[0]
    cache_root = f"{OUTPUT_ROOT}/{corpus_stem}/summarytype_{summary_type}"
    os.makedirs(cache_root, exist_ok=True)
    summaries_dir = f"{cache_root}/summaries"
    character_definitions_dir = f"{cache_root}/character_definitions"
    os.makedirs(character_definitions_dir, exist_ok=True)

    # Chunk the corpus and summarize it (results are cached on disk).
    docs = load_docs(corpus_path=corpus, chunk_size=2048, chunk_overlap=64)
    corpus_summaries = get_corpus_summaries(
        docs=docs, summary_type=summary_type, cache_dir=summaries_dir
    )

    # Derive (or load cached) character definition from the summaries.
    character_definition = get_character_definition(
        name=character_name,
        corpus_summaries=corpus_summaries,
        cache_dir=character_definitions_dir,
    )
    print(json.dumps(asdict(character_definition), indent=4))

    # Pick the document set the retrieval chatbots will search over.
    # NOTE: validated even for the "summary" chatbot, matching prior behavior.
    if retrieval_docs == "summarized":
        documents = corpus_summaries
    elif retrieval_docs == "raw":
        # Finer-grained chunks than the summarization pass, for retrieval.
        documents = [
            chunk.page_content
            for chunk in load_docs(corpus_path=corpus, chunk_size=256, chunk_overlap=16)
        ]
    else:
        raise ValueError(f"Unknown retrieval docs type: {retrieval_docs}")

    # Dispatch table of chatbot constructors; thunks defer construction
    # until after chatbot_type is validated.
    builders = {
        "summary": lambda: SummaryChatBot(character_definition=character_definition),
        "retrieval": lambda: RetrievalChatBot(
            character_definition=character_definition,
            documents=documents,
        ),
        "summary_retrieval": lambda: SummaryRetrievalChatBot(
            character_definition=character_definition,
            documents=documents,
        ),
    }
    if chatbot_type not in builders:
        raise ValueError(f"Unknown chatbot type: {chatbot_type}")
    return builders[chatbot_type]()
def main():
    """Parse command-line flags, build the chatbot, and launch the chosen UI.

    Supports a plain CLI loop or a Streamlit app; in the Streamlit case the
    (expensive) chatbot construction is memoized with st.cache_resource so
    reruns of the script do not rebuild it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--corpus", type=str, default="data/everything_everywhere_all_at_once.txt"
    )
    parser.add_argument("--character_name", type=str, default="Evelyn")
    parser.add_argument(
        "--chatbot_type",
        type=str,
        default="summary_retrieval",
        choices=["summary", "retrieval", "summary_retrieval"],
    )
    parser.add_argument(
        "--summary_type",
        type=str,
        default="map_reduce",
        choices=["map_reduce", "refine"],
    )
    parser.add_argument(
        "--retrieval_docs",
        type=str,
        default="summarized",
        choices=["raw", "summarized"],
    )
    parser.add_argument(
        "--interface", type=str, default="cli", choices=["cli", "streamlit"]
    )
    args = parser.parse_args()

    # Single source of truth for create_chatbot's positional arguments,
    # shared by both interface branches.
    chatbot_args = (
        args.corpus,
        args.character_name,
        args.chatbot_type,
        args.retrieval_docs,
        args.summary_type,
    )

    if args.interface == "cli":
        app = CommandLine(chatbot=create_chatbot(*chatbot_args))
    elif args.interface == "streamlit":
        chatbot = st.cache_resource(create_chatbot)(*chatbot_args)
        st.title("Data Driven Characters")
        st.write("Create your own character chatbots, grounded in existing corpora.")
        st.divider()
        st.markdown(f"**chatbot type**: *{args.chatbot_type}*")
        if "retrieval" in args.chatbot_type:
            st.markdown(f"**retrieving from**: *{args.retrieval_docs} corpus*")
        app = Streamlit(chatbot=chatbot)
    else:
        # Unreachable via argparse choices; kept as a guard for direct calls.
        raise ValueError(f"Unknown interface: {args.interface}")
    app.run()


if __name__ == "__main__":
    main()
|