File size: 3,368 Bytes
b092c58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import json
import os

from langchain import PromptTemplate, LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter

from data_driven_characters.constants import VERBOSE


def generate_docs(corpus, chunk_size, chunk_overlap):
    """Split a corpus string into Document chunks.

    Uses a tiktoken-based recursive character splitter, so `chunk_size`
    and `chunk_overlap` are measured in tokens rather than raw characters.
    """
    splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.create_documents([corpus])


def load_docs(corpus_path, chunk_size, chunk_overlap):
    """Read the corpus file and split it into Document chunks.

    Args:
        corpus_path: Path to a plain-text corpus file.
        chunk_size: Maximum chunk size in tokens (see `generate_docs`).
        chunk_overlap: Token overlap between consecutive chunks.

    Returns:
        The list of Documents produced by `generate_docs`.
    """
    # Read as UTF-8 explicitly: the platform default encoding (e.g. cp1252
    # on Windows) can corrupt or reject non-ASCII corpora.
    with open(corpus_path, encoding="utf-8") as f:
        corpus = f.read()
    docs = generate_docs(corpus, chunk_size, chunk_overlap)
    return docs


def generate_corpus_summaries(docs, summary_type="map_reduce", verbose=True):
    """Generate per-chunk summaries of the story.

    Args:
        docs: Documents (one per corpus chunk) to summarize.
        summary_type: langchain summarize chain type (e.g. "map_reduce").
        verbose: Whether the chain logs its intermediate prompts/outputs.
            Previously hard-coded to True; parameterized so callers can
            silence chain logging. The default preserves old behavior.

    Returns:
        The list of intermediate (per-chunk) summaries. The chain's final
        combined summary is intentionally discarded.
    """
    llm = ChatOpenAI(model_name="gpt-3.5-turbo")
    chain = load_summarize_chain(
        llm,
        chain_type=summary_type,
        return_intermediate_steps=True,
        verbose=verbose,
    )
    summary = chain({"input_documents": docs}, return_only_outputs=True)
    intermediate_summaries = summary["intermediate_steps"]
    return intermediate_summaries


def get_corpus_summaries(docs, summary_type, cache_dir, force_refresh=False):
    """Load the corpus summaries from cache or generate and cache them.

    Args:
        docs: Documents passed through to `generate_corpus_summaries`.
        summary_type: Summarize chain type passed through as well.
        cache_dir: Directory where `summary_<i>.txt` files are cached.
        force_refresh: If True, regenerate even when the cache exists.

    Returns:
        A list of per-chunk summary strings, in chunk order.
    """
    if not os.path.exists(cache_dir) or force_refresh:
        os.makedirs(cache_dir, exist_ok=True)
        if VERBOSE:
            print("Summaries do not exist. Generating summaries.")
        intermediate_summaries = generate_corpus_summaries(docs, summary_type)
        for i, intermediate_summary in enumerate(intermediate_summaries):
            path = os.path.join(cache_dir, f"summary_{i}.txt")
            with open(path, "w", encoding="utf-8") as f:
                f.write(intermediate_summary)
    else:
        if VERBOSE:
            print("Summaries already exist. Loading summaries.")
        # Count only the summary files: the original used
        # len(os.listdir(cache_dir)), which overcounts (and then raises
        # FileNotFoundError) if the directory contains any unrelated file,
        # e.g. a characters.json cache written into the same directory.
        num_summaries = sum(
            1
            for name in os.listdir(cache_dir)
            if name.startswith("summary_") and name.endswith(".txt")
        )
        intermediate_summaries = []
        for i in range(num_summaries):
            path = os.path.join(cache_dir, f"summary_{i}.txt")
            with open(path, encoding="utf-8") as f:
                intermediate_summaries.append(f.read())
    return intermediate_summaries


def generate_characters(corpus_summaries, num_characters):
    """Extract the most important character names from corpus summaries.

    Args:
        corpus_summaries: Iterable of summary strings for the corpus.
        num_characters: Maximum number of names to return.

    Returns:
        A list of at most `num_characters` character names, ordered by
        importance as judged by the model.
    """
    # The original bound this to a name `GPT4` while actually requesting
    # gpt-3.5-turbo; renamed to avoid implying the wrong model.
    llm = ChatOpenAI(model_name="gpt-3.5-turbo")
    characters_prompt_template = """Consider the following corpus.
    ---
    {corpus_summaries}
    ---
    Give a line-separated list of all the characters, ordered by importance, without punctuation.
    """
    characters = LLMChain(
        llm=llm, prompt=PromptTemplate.from_template(characters_prompt_template)
    ).run(corpus_summaries="\n\n".join(corpus_summaries))
    # Strip the stray (, ), and " the model sometimes emits, and drop blank
    # lines so they don't count against num_characters. (The original had a
    # comment promising this cleanup but never performed it.)
    names = [line.strip().strip('()"').strip() for line in characters.split("\n")]
    return [name for name in names if name][:num_characters]


def get_characters(corpus_summaries, num_characters, cache_dir, force_refresh=False):
    """Load the character list from cache or generate and cache it.

    Args:
        corpus_summaries: Summaries passed through to `generate_characters`.
        num_characters: Maximum number of names to return.
        cache_dir: Directory holding the `characters.json` cache file.
        force_refresh: If True, regenerate even when a cache file exists.

    Returns:
        A list of character names.
    """
    # Ensure the cache directory exists before writing: the original raised
    # FileNotFoundError on a fresh cache_dir, unlike get_corpus_summaries
    # which creates the directory it caches into.
    os.makedirs(cache_dir, exist_ok=True)
    cache_file = os.path.join(cache_dir, "characters.json")
    if not os.path.exists(cache_file) or force_refresh:
        characters = generate_characters(corpus_summaries, num_characters)
        with open(cache_file, "w") as f:
            json.dump(characters, f)
    else:
        with open(cache_file, "r") as f:
            characters = json.load(f)
    return characters