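# Streamlit demo: long-document (book) summarization.
# Pipeline: load an example book from testbook.json, split it into
# ~512-768-token chunks at semantically chosen sentence boundaries, then
# summarize each chunk with UNIST-Eunchan/bart-dnc-booksum.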
import json

import nltk
import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

nltk.download('punkt', quiet=True)  # sentence tokenizer used by chunking()

with open('testbook.json') as f:
    test_book = json.load(f)

MODEL_NAME = "UNIST-Eunchan/bart-dnc-booksum"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

@st.cache_resource  # keep the model loaded across Streamlit reruns
def load_model(model_name):
    return AutoModelForSeq2SeqLM.from_pretrained(model_name)

model = load_model(MODEL_NAME)

# The chunker compares sentence embeddings, which a Seq2SeqLM does not
# provide (the original called model.encode(), which does not exist on a
# BART Seq2SeqLM); a separate sentence-embedding model fills that role.
# The model chosen here is an assumption; any sentence encoder works.
embedder = SentenceTransformer('all-MiniLM-L6-v2')

def cos_similarity(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

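# Beam-sample decoding: beam search (num_beams=4) combined with sampling
# (do_sample=True plus temperature/top-k/top-p); no_repeat_ngram_size=2
# suppresses repeated bigrams in the summary.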
def infer(input_ids, max_length, temperature, top_k, top_p):

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        num_return_sequences=1,
        num_beams=4,
        no_repeat_ngram_size=2
    )
    
    return output_sequences


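# Chunking strategy: grow a segment sentence by sentence. Below 512
# tokens, always append; above 768, always cut. In the 512-768 band,
# attach the sentence to whichever side its embedding is closer to, the
# current segment or a pseudo-segment of the ~512 tokens that follow, so
# that cuts land on topic boundaries.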
def chunking(book_text):
    sentences = nltk.sent_tokenize(book_text)
    token_lens = [len(tokenizer.encode(s, add_special_tokens=False))
                  for s in sentences]

    segments = []
    current_segment = ""
    total_token_lens = 0
    for i in range(len(sentences)):
        if total_token_lens < 512:
            # Segment still short: keep appending.
            total_token_lens += token_lens[i]
            current_segment += (sentences[i] + " ")

        elif total_token_lens > 768:
            # Segment long enough: flush it and start a new one.
            segments.append(current_segment)
            current_segment = sentences[i] + " "
            total_token_lens = token_lens[i]

        else:
            # Build a pseudo-segment from the upcoming sentences
            # (at most 30 sentences / ~512 tokens).
            next_pseudo_segment = ""
            next_token_len = 0
            for t in range(30):
                if (i + t < len(sentences)) and (next_token_len + token_lens[i + t] < 512):
                    next_token_len += token_lens[i + t]
                    next_pseudo_segment += (sentences[i + t] + " ")

            embs = embedder.encode([current_segment, next_pseudo_segment, sentences[i]])  # current, next, sent
            if cos_similarity(embs[1], embs[2]) > cos_similarity(embs[0], embs[2]):
                segments.append(current_segment)
                current_segment = sentences[i] + " "
                total_token_lens = token_lens[i]
            else:
                total_token_lens += token_lens[i]
                current_segment += (sentences[i] + " ")

    # The original dropped the trailing segment; keep it.
    if current_segment.strip():
        segments.append(current_segment)
    return segments

# Streamlit UI
st.title("Book Summarization 📚")
st.write("Summarize long books with UNIST-Eunchan/bart-dnc-booksum, a BART model fine-tuned on the BookSum dataset. Pick an example book in the sidebar; it is split into semantically coherent chunks of roughly 512-768 tokens, and each chunk is summarized below.")

book_index = st.sidebar.slider("Select Book Example", value=0, min_value=0, max_value=4)

_book = test_book[book_index]['book']
chunked_segments = chunking(_book)

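# Sidebar decoding controls; values are passed straight to infer().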
sent = st.text_area("Text", _book, height=550)
max_length = st.sidebar.slider("Max Length", value=512, min_value=10, max_value=1024)
# Keep the minimum above zero: temperature=0 is invalid when sampling.
temperature = st.sidebar.slider("Temperature", value=1.0, min_value=0.05, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value=0, max_value=5, value=0)
top_p = st.sidebar.slider("Top-p", min_value=0.0, max_value=1.0, step=0.05, value=0.92)


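# Summarize each chunk independently and stream the results onto the page.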
for idx, segment in enumerate(chunked_segments):
    # Encode the chunk text. (The original iterated over range(len(...))
    # and passed the integer index to tokenizer.encode, which fails.)
    encoded_prompt = tokenizer.encode(segment, add_special_tokens=False, return_tensors="pt")
    if encoded_prompt.size()[-1] == 0:
        continue  # skip empty chunks

    output_sequences = infer(encoded_prompt, max_length, temperature, top_k, top_p)

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} (chunk {idx + 1}) ===")

        # BART is an encoder-decoder, so generate() returns the summary
        # itself; the GPT-2-style prompt stripping in the original does
        # not apply here.
        summary = tokenizer.decode(generated_sequence, skip_special_tokens=True,
                                   clean_up_tokenization_spaces=True)
        print(summary)
        st.write(summary)