File size: 4,287 Bytes
b092c58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
from dataclasses import asdict
import json
import os
import streamlit as st

from data_driven_characters.character import get_character_definition
from data_driven_characters.corpus import (
    get_corpus_summaries,
    load_docs,
)

from data_driven_characters.chatbots import (
    SummaryChatBot,
    RetrievalChatBot,
    SummaryRetrievalChatBot,
)
from data_driven_characters.interfaces import CommandLine, Streamlit

# Root directory under which create_chatbot builds its per-corpus output
# directories (summaries and character definitions).
OUTPUT_ROOT = "output"


def create_chatbot(corpus, character_name, chatbot_type, retrieval_docs, summary_type):
    """Build a character chatbot grounded in a text corpus.

    The corpus is loaded and summarized, a character definition is derived
    from the summaries, and the result is wrapped in the requested chatbot
    class. Summaries and character definitions are given cache directories
    under ``OUTPUT_ROOT/<corpus name>/summarytype_<summary_type>/``.

    Args:
        corpus: Path to the corpus file. Its base name without extension
            names the output subdirectory.
        character_name: Name of the character, passed to
            ``get_character_definition``.
        chatbot_type: One of ``"summary"``, ``"retrieval"`` or
            ``"summary_retrieval"``.
        retrieval_docs: ``"raw"`` to retrieve from small chunks of the corpus
            itself, ``"summarized"`` to retrieve from the corpus summaries.
        summary_type: Summarization strategy, passed to
            ``get_corpus_summaries`` and used in the output directory name.

    Returns:
        The constructed chatbot instance.

    Raises:
        ValueError: If ``retrieval_docs`` or ``chatbot_type`` is not one of
            the recognised values. Both are checked before any directory is
            created or any document is loaded.
    """
    # Fail fast: reject bad arguments before touching the filesystem or
    # running summarization. The retrieval-docs check comes first so that the
    # error reported when both arguments are invalid is unchanged.
    if retrieval_docs not in ("raw", "summarized"):
        raise ValueError(f"Unknown retrieval docs type: {retrieval_docs}")
    if chatbot_type not in ("summary", "retrieval", "summary_retrieval"):
        raise ValueError(f"Unknown chatbot type: {chatbot_type}")

    # logging
    corpus_name = os.path.splitext(os.path.basename(corpus))[0]
    output_dir = f"{OUTPUT_ROOT}/{corpus_name}/summarytype_{summary_type}"
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): summaries_dir is not created here; presumably
    # get_corpus_summaries creates it when caching -- confirm.
    summaries_dir = f"{output_dir}/summaries"
    character_definitions_dir = f"{output_dir}/character_definitions"
    os.makedirs(character_definitions_dir, exist_ok=True)

    # load docs (large chunks, used for summarization)
    docs = load_docs(corpus_path=corpus, chunk_size=2048, chunk_overlap=64)

    # generate summaries
    corpus_summaries = get_corpus_summaries(
        docs=docs, summary_type=summary_type, cache_dir=summaries_dir
    )

    # get character definition
    character_definition = get_character_definition(
        name=character_name,
        corpus_summaries=corpus_summaries,
        cache_dir=character_definitions_dir,
    )
    print(json.dumps(asdict(character_definition), indent=4))

    # construct retrieval documents
    if retrieval_docs == "raw":
        # Re-load the corpus with much smaller chunks for retrieval.
        documents = [
            doc.page_content
            for doc in load_docs(corpus_path=corpus, chunk_size=256, chunk_overlap=16)
        ]
    else:  # "summarized" -- validated above
        documents = corpus_summaries

    # initialize chatbot
    if chatbot_type == "summary":
        chatbot = SummaryChatBot(character_definition=character_definition)
    elif chatbot_type == "retrieval":
        chatbot = RetrievalChatBot(
            character_definition=character_definition,
            documents=documents,
        )
    else:  # "summary_retrieval" -- validated above
        chatbot = SummaryRetrievalChatBot(
            character_definition=character_definition,
            documents=documents,
        )
    return chatbot


def main():
    """Parse command-line options, build the chatbot and run the interface.

    With ``--interface cli`` the chatbot is served through ``CommandLine``.
    With ``--interface streamlit`` the chatbot factory is wrapped in
    ``st.cache_resource``, a short page header is rendered, and the chatbot
    is served through ``Streamlit``.

    Raises:
        ValueError: If the parsed interface is neither ``"cli"`` nor
            ``"streamlit"``.
    """
    # Every option is a string; only defaults and allowed choices differ.
    option_specs = (
        ("--corpus", {"default": "data/everything_everywhere_all_at_once.txt"}),
        ("--character_name", {"default": "Evelyn"}),
        (
            "--chatbot_type",
            {
                "default": "summary_retrieval",
                "choices": ["summary", "retrieval", "summary_retrieval"],
            },
        ),
        (
            "--summary_type",
            {"default": "map_reduce", "choices": ["map_reduce", "refine"]},
        ),
        (
            "--retrieval_docs",
            {"default": "summarized", "choices": ["raw", "summarized"]},
        ),
        ("--interface", {"default": "cli", "choices": ["cli", "streamlit"]}),
    )
    cli_parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli_parser.add_argument(flag, type=str, **spec)
    options = cli_parser.parse_args()

    if options.interface not in ("cli", "streamlit"):
        raise ValueError(f"Unknown interface: {options.interface}")

    # Positional arguments shared by both ways of invoking create_chatbot.
    factory_args = (
        options.corpus,
        options.character_name,
        options.chatbot_type,
        options.retrieval_docs,
        options.summary_type,
    )

    if options.interface == "cli":
        app = CommandLine(chatbot=create_chatbot(*factory_args))
    else:
        cached_factory = st.cache_resource(create_chatbot)
        bot = cached_factory(*factory_args)

        # Page header describing the current configuration.
        st.title("Data Driven Characters")
        st.write("Create your own character chatbots, grounded in existing corpora.")
        st.divider()
        st.markdown(f"**chatbot type**: *{options.chatbot_type}*")
        if "retrieval" in options.chatbot_type:
            st.markdown(f"**retrieving from**: *{options.retrieval_docs} corpus*")
        app = Streamlit(chatbot=bot)

    app.run()


# Run the chat application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()