File size: 4,288 Bytes
955c8a8
 
b48fa68
fc5a865
e07cf78
 
 
 
fc5a865
 
 
955c8a8
fc5a865
955c8a8
 
fc5a865
955c8a8
 
 
 
 
 
 
 
 
 
fc5a865
 
 
955c8a8
fc5a865
955c8a8
e07cf78
944d9b4
 
b48fa68
944d9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d872090
944d9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e07cf78
 
955c8a8
 
e07cf78
955c8a8
 
944d9b4
 
 
 
 
d872090
e07cf78
955c8a8
 
fc5a865
955c8a8
d872090
 
 
 
 
 
 
 
61aa0aa
 
d872090
 
 
 
 
 
 
 
 
 
 
 
 
955c8a8
944d9b4
 
d872090
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import transformers
import streamlit as st
from nltk import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import json

# Load the bundled sample books that populate the demo UI.
# Bug fix: open() previously used the platform-default encoding, which can
# mis-decode the JSON on non-UTF-8 locales (e.g. Windows cp1252).
# NOTE(review): assumes testbook.json is a list of dicts with a 'book' key —
# consistent with the test_book[book_index]['book'] access below.
with open('testbook.json', encoding='utf-8') as f:
    test_book = json.load(f)

# Shared tokenizer: used both for token counting during chunking and for
# encoding/decoding at generation time.
tokenizer = AutoTokenizer.from_pretrained("UNIST-Eunchan/bart-dnc-booksum")

def load_model(model_name):
    """Load a seq2seq summarization model from the Hugging Face Hub.

    Args:
        model_name: Hub id of the checkpoint to load.

    Returns:
        The loaded ``AutoModelForSeq2SeqLM`` instance.
    """
    # Bug fix: the function previously ignored ``model_name`` and always
    # loaded a hard-coded checkpoint, making the parameter dead.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model

# Module-level model shared by infer(), chunking() and generate_output().
model = load_model("UNIST-Eunchan/bart-dnc-booksum")

def infer(input_ids, max_length, temperature, top_k, top_p):
    """Generate one summary sequence for ``input_ids``.

    All decoding knobs are forwarded to the module-level ``model``'s
    ``generate`` method; sampling is combined with a 4-beam search and a
    bigram no-repeat constraint.
    """
    generation_kwargs = {
        "input_ids": input_ids,
        "max_length": max_length,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": True,
        "num_return_sequences": 1,
        "num_beams": 4,
        "no_repeat_ngram_size": 2,
    }
    return model.generate(**generation_kwargs)


def chunking(book_text):
    """Split *book_text* into roughly 512-768-token segments on sentence boundaries.

    Greedy strategy: grow the current segment until it reaches 512 tokens;
    between 512 and 768 tokens, decide whether the current sentence should
    start a new segment by comparing its similarity to the current segment
    vs. a pseudo-segment built from the next ~10 sentences; past 768 tokens,
    always start a new segment.

    Args:
        book_text: full text of the book to segment.

    Returns:
        list[str]: segments covering the whole input text.
    """
    sentences = sent_tokenize(book_text)
    # Bug fix: ``token_lens`` was referenced throughout the loop but never
    # computed (the line was commented out), raising NameError on first call.
    token_lens = [len(tokenizer.tokenize(s)) for s in sentences]

    def _cos_similarity(u, v):
        # Bug fix: ``cos_similarity`` was referenced but never defined.
        # Pure-Python cosine similarity; returns 0.0 for zero vectors.
        dot = sum(float(a) * float(b) for a, b in zip(u, v))
        norm_u = sum(float(a) ** 2 for a in u) ** 0.5
        norm_v = sum(float(b) ** 2 for b in v) ** 0.5
        if norm_u == 0.0 or norm_v == 0.0:
            return 0.0
        return dot / (norm_u * norm_v)

    segments = []
    current_segment = ""
    total_token_lens = 0
    for i in range(len(sentences)):

        if total_token_lens < 512:
            # Still below the soft target: keep accumulating.
            total_token_lens += token_lens[i]
            current_segment += (sentences[i] + " ")

        elif total_token_lens > 768:
            # Hard ceiling exceeded: flush and start a new segment.
            segments.append(current_segment)
            current_segment = sentences[i]
            total_token_lens = token_lens[i]

        else:
            # In the 512-768 band: build a pseudo-segment from up to the
            # next 10 sentences (capped at 512 tokens) to compare against.
            next_pseudo_segment = ""
            next_token_len = 0
            for t in range(10):
                if (i + t < len(sentences)) and (next_token_len + token_lens[i + t] < 512):
                    next_token_len += token_lens[i + t]
                    next_pseudo_segment += sentences[i + t]

            # NOTE(review): ``AutoModelForSeq2SeqLM`` exposes no ``.encode()``
            # method — this call needs a sentence-embedding model (e.g.
            # sentence-transformers) to actually work; left in place pending
            # that fix rather than guessing at a replacement.
            embs = model.encode([current_segment, next_pseudo_segment, sentences[i]])  # current, next, sent
            if _cos_similarity(embs[1], embs[2]) > _cos_similarity(embs[0], embs[2]):
                # Sentence is closer to the upcoming text: cut here.
                segments.append(current_segment)
                current_segment = sentences[i]
                total_token_lens = token_lens[i]
            else:
                total_token_lens += token_lens[i]
                current_segment += (sentences[i] + " ")

    # Bug fix: the final (possibly partial) segment was silently dropped,
    # losing the tail of every book.
    if current_segment:
        segments.append(current_segment)

    return segments


'''
'''

# --- Streamlit UI: title, description, book picker and decoding sliders ---
# Bug fix: the title emoji was mojibake ("πŸ“š") from a mis-encoded save.
st.title("Book Summarization 📚")
# Bug fix: the previous description was boilerplate about GPT-2 and did not
# describe this app, which summarizes books with a BART BookSum checkpoint.
st.write("Summarize long-form books with UNIST-Eunchan/bart-dnc-booksum, a BART model fine-tuned on the BookSum dataset. Pick a sample book in the sidebar, tune the decoding parameters below, and the app splits the text into sentence-aligned chunks and summarizes each chunk.")

book_index = st.sidebar.slider("Select Book Example", value = 0,min_value = 0, max_value=4)

_book = test_book[book_index]['book']
# Bug fix: a ``chunking(_book)`` call here was removed — its result was
# unconditionally overwritten by the chunking() call at the bottom of the
# script, so it only wasted a full pass over the book on every rerun.

sent = st.text_area("Text", _book[:512], height = 550)
max_length = st.sidebar.slider("Max Length", value = 512,min_value = 10, max_value=1024)
temperature = st.sidebar.slider("Temperature", value = 1.0, min_value = 0.0, max_value=1.0, step=0.05)
top_k = st.sidebar.slider("Top-k", min_value = 0, max_value=5, value = 0)
top_p = st.sidebar.slider("Top-p", min_value = 0.0, max_value=1.0, step = 0.05, value = 0.92)

def generate_output(test_samples):
    """Tokenize *test_samples* and generate a beam-search summary.

    Args:
        test_samples: a string (or list of strings) to summarize.

    Returns:
        tuple: (generated token-id tensor, list of decoded summary strings).
    """
    inputs = tokenizer(
        test_samples,
        # Bug fix: ``padding`` was previously given the integer value of the
        # ``max_length`` sidebar slider; the tokenizer expects a strategy
        # (True/False/"longest"/"max_length") and raises on an int.
        padding="max_length",
        truncation=True,
        max_length=1024,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids
    attention_mask = inputs.attention_mask
    outputs = model.generate(input_ids,
                             max_length=256,
                             min_length=32,
                             top_p=0.92,
                             num_beams=5,
                             no_repeat_ngram_size=2,
                             attention_mask=attention_mask)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


# Chunk the first sample book and summarize each chunk in turn.
chunked_segments = chunking(test_book[0]['book'])

for segment_text in chunked_segments:
    # Bug fix: the loop previously iterated over range(len(...)) and passed
    # the integer index to generate_output() instead of the segment text,
    # so the tokenizer received an int rather than a string.
    summaries = generate_output(segment_text)
    # generate_output returns (token_ids, decoded_strings); display the text.
    st.write(summaries[-1])