rodrigomasini committed on
Commit
87961c6
·
1 Parent(s): 04e7d93

Create recurreentgpt.py

Browse files
Files changed (1) hide show
  1. recurreentgpt.py +135 -0
recurreentgpt.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import get_content_between_a_b, get_api_response
2
+ import torch
3
+
4
+ import random
5
+
6
+ from sentence_transformers import util
7
+
8
+
9
class RecurrentGPT:
    """Run one writer step of a recurrent, prompt-driven novel generator.

    A step builds a prompt from the short-term memory (a running summary),
    the previously written paragraph, the chosen instruction and the most
    similar paragraphs held in long-term memory, sends it to the language
    model through ``get_api_response`` and parses the reply back into a
    paragraph, an updated memory and three follow-up instructions.
    """

    def __init__(self, input, short_memory, long_memory, memory_index, embedder):
        """Store the writer state and make sure a memory index exists.

        Args:
            input: Mapping that provides ``"output_paragraph"`` and
                ``"output_instruction"`` (the previous step's result).
            short_memory: Summary text inserted into the prompt as
                "Input Memory".
            long_memory: List of previously written paragraphs.
            memory_index: Precomputed embeddings of ``long_memory``, or a
                missing value (``None`` / empty) to have them computed here.
            embedder: Object exposing
                ``encode(texts, convert_to_tensor=True)``.
        """
        self.input = input
        self.short_memory = short_memory
        self.long_memory = long_memory
        self.embedder = embedder
        # Always define the attribute: a caller-supplied index must be kept,
        # and an empty long-term memory must not leave it undefined.
        self.memory_index = memory_index
        # ``not tensor`` is ambiguous for multi-element tensors, so test for
        # "missing" explicitly instead of relying on truthiness.
        index_missing = memory_index is None or len(memory_index) == 0
        if self.long_memory and index_missing:
            self.memory_index = self.embedder.encode(
                self.long_memory, convert_to_tensor=True)
        self.output = {}

    def _retrieve_related_paragraphs(self, instruction, top_k):
        """Return the prompt section listing the paragraphs closest to *instruction*.

        Returns an empty string when there is no long-term memory to search.
        """
        if not self.long_memory or self.memory_index is None:
            return ""

        instruction_embedding = self.embedder.encode(
            instruction, convert_to_tensor=True)
        memory_scores = util.cos_sim(
            instruction_embedding, self.memory_index)[0]

        # torch.topk raises when k exceeds the number of candidates, which
        # happens during the first steps of a story.
        k = min(top_k, len(memory_scores))
        if k <= 0:
            return ""
        top_k_idx = torch.topk(memory_scores, k=k)[1].tolist()

        return '\n'.join(
            f"Related Paragraphs {i+1} :" + self.long_memory[idx]
            for i, idx in enumerate(top_k_idx))

    def prepare_input(self, new_character_prob=0.1, top_k=2):
        """Build the prompt for the next paragraph.

        Args:
            new_character_prob: Probability of appending the hint that allows
                the model to introduce a new character.
            top_k: Maximum number of related paragraphs retrieved from
                long-term memory.

        Returns:
            The complete prompt text.
        """
        input_paragraph = self.input["output_paragraph"]
        input_instruction = self.input["output_instruction"]

        # Up to ``top_k`` most similar paragraphs from long-term memory.
        input_long_term_memory = self._retrieve_related_paragraphs(
            input_instruction, top_k)

        # Randomly decide if a new character may be introduced.
        if random.random() < new_character_prob:
            new_character_prompt = "If it is reasonable, you can introduce a new character in the output paragrah and add it into the memory."
        else:
            new_character_prompt = ""

        # NOTE: the prompt is kept flush-left because any whitespace inside
        # the literal becomes part of the text sent to the model.
        input_text = f"""I need you to help me write a novel. Now I give you a memory (a brief summary) of 400 words, you should use it to store the key content of what has been written so that you can keep track of very long context. For each time, I will give you your current memory (a brief summary of previous stories. You should use it to store the key content of what has been written so that you can keep track of very long context), the previously written paragraph, and instructions on what to write in the next paragraph.
I need you to write:
1. Output Paragraph: the next paragraph of the novel. The output paragraph should contain around 20 sentences and should follow the input instructions.
2. Output Memory: The updated memory. You should first explain which sentences in the input memory are no longer necessary and why, and then explain what needs to be added into the memory and why. After that you should write the updated memory. The updated memory should be similar to the input memory except the parts you previously thought that should be deleted or added. The updated memory should only store key information. The updated memory should never exceed 20 sentences!
3. Output Instruction: instructions of what to write next (after what you have written). You should output 3 different instructions, each is a possible interesting continuation of the story. Each output instruction should contain around 5 sentences
Here are the inputs:

Input Memory:
{self.short_memory}

Input Paragraph:
{input_paragraph}

Input Instruction:
{input_instruction}

Input Related Paragraphs:
{input_long_term_memory}

Now start writing, organize your output by strictly following the output format as below:
Output Paragraph:
<string of output paragraph>, around 20 sentences.

Output Memory:
Rational: <string that explain how to update the memory>;
Updated Memory: <string of updated memory>, around 10 to 20 sentences

Output Instruction:
Instruction 1: <content for instruction 1>, around 5 sentences
Instruction 2: <content for instruction 2>, around 5 sentences
Instruction 3: <content for instruction 3>, around 5 sentences

Very important!! The updated memory should only store key information. The updated memory should never contain over 500 words!
Finally, remember that you are writing a novel. Write like a novelist and do not move too fast when writing the output instructions for the next paragraph. Remember that the chapter will contain over 10 paragraphs and the novel will contain over 100 chapters. And this is just the beginning. Just write some interesting staffs that will happen next. Also, think about what plot can be attractive for common readers when writing output instructions.

Very Important:
You should first explain which sentences in the input memory are no longer necessary and why, and then explain what needs to be added into the memory and why. After that, you start rewrite the input memory to get the updated memory.
{new_character_prompt}
"""
        return input_text

    @staticmethod
    def _extract_third_instruction(output):
        """Return the text of "Instruction 3" from the model reply.

        The content may sit on the marker line or on the lines after it, so
        everything following the last marker is taken. Without a marker the
        last non-empty line is used.

        Raises:
            ValueError: If no instruction text can be found.
        """
        _, marker, tail = output.rpartition('Instruction 3:')
        if marker:
            text = tail.strip()
        else:
            non_empty = [line.strip()
                         for line in output.splitlines() if line.strip()]
            text = non_empty[-1] if non_empty else ''
        if not text:
            raise ValueError("Instruction 3 is missing from the model output")
        return text

    def parse_output(self, output):
        """Parse the model reply into the step result.

        Args:
            output: Raw text returned by the language model.

        Returns:
            A dict with ``input_paragraph``, ``output_memory``,
            ``output_paragraph`` and ``output_instruction`` (list of three
            strings), or ``None`` when the reply does not follow the
            expected format. ``self.short_memory`` is only updated on
            success.
        """
        try:
            output_paragraph = get_content_between_a_b(
                'Output Paragraph:', 'Output Memory', output)
            output_memory_updated = get_content_between_a_b(
                'Updated Memory:', 'Output Instruction:', output)
            ins_1 = get_content_between_a_b(
                'Instruction 1:', 'Instruction 2', output)
            ins_2 = get_content_between_a_b(
                'Instruction 2:', 'Instruction 3', output)
            ins_3 = self._extract_third_instruction(output)
            output_instructions = [ins_1, ins_2, ins_3]

            result = {
                "input_paragraph": self.input["output_paragraph"],
                "output_memory": output_memory_updated,  # feed to human
                "output_paragraph": output_paragraph,
                "output_instruction": [instruction.strip() for instruction in output_instructions]
            }
        except Exception:
            # A malformed reply is reported as None so the caller can retry.
            # Exception (not a bare except) keeps Ctrl-C and SystemExit
            # working; the helper's own error types are not visible here.
            return None

        # Commit the new memory only once the whole reply parsed cleanly.
        self.short_memory = output_memory_updated
        return result

    def step(self, response_file=None, max_retries=None):
        """Generate the next paragraph and update the long-term memory.

        Args:
            response_file: Optional path; the accepted raw reply is appended
                to this file.
            max_retries: Optional upper bound on extra API calls made after
                an unparsable reply. ``None`` (default) retries until a
                reply parses.

        Raises:
            RuntimeError: If ``max_retries`` is set and every attempt
                returned an unparsable reply.
        """
        prompt = self.prepare_input()

        print(prompt+'\n'+'\n')

        failures = 0
        while True:
            response = get_api_response(prompt)
            self.output = self.parse_output(response)
            if self.output is not None:
                break
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise RuntimeError(
                    f"Model output could not be parsed after {failures} attempts")

        if response_file:
            with open(response_file, 'a', encoding='utf-8') as f:
                f.write(f"Writer's output here:\n{response}\n\n")

        self.long_memory.append(self.input["output_paragraph"])
        self.memory_index = self.embedder.encode(
            self.long_memory, convert_to_tensor=True)