push
This view is limited to 50 files because it contains too many changes.
- app.py +161 -0
- model/paraphrase-MiniLM-L6-v2/.gitattributes +17 -0
- model/paraphrase-MiniLM-L6-v2/1_Pooling/config.json +7 -0
- model/paraphrase-MiniLM-L6-v2/README.md +108 -0
- model/paraphrase-MiniLM-L6-v2/config.json +24 -0
- model/paraphrase-MiniLM-L6-v2/config_sentence_transformers.json +7 -0
- model/paraphrase-MiniLM-L6-v2/model.safetensors +3 -0
- model/paraphrase-MiniLM-L6-v2/modules.json +14 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_O1.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_O2.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_O3.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_O4.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_arm64.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512_vnni.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/onnx/model_quint8_avx2.onnx +3 -0
- model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.bin +3 -0
- model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.xml +0 -0
- model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.bin +3 -0
- model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.xml +0 -0
- model/paraphrase-MiniLM-L6-v2/pytorch_model.bin +3 -0
- model/paraphrase-MiniLM-L6-v2/sentence_bert_config.json +4 -0
- model/paraphrase-MiniLM-L6-v2/special_tokens_map.json +1 -0
- model/paraphrase-MiniLM-L6-v2/tf_model.h5 +3 -0
- model/paraphrase-MiniLM-L6-v2/tokenizer.json +0 -0
- model/paraphrase-MiniLM-L6-v2/tokenizer_config.json +1 -0
- model/paraphrase-MiniLM-L6-v2/vocab.txt +0 -0
- requirements.txt +22 -0
- src/DeepThink/__pycache__/__init__.cpython-311.pyc +0 -0
- src/DeepThink/__pycache__/engine.cpython-311.pyc +0 -0
- src/DeepThink/engine.py +285 -0
- src/DeepThink/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/article_generation.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/article_generation.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/article_polish.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/article_polish.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/interface.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/interface.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/mindmap.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/mindmap.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/outline_generation.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/outline_generation.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/retriever.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/storm_dataclass.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/storm_dataclass.cpython-311.pyc +0 -0
- src/DeepThink/modules/__pycache__/utils.cpython-310.pyc +0 -0
- src/DeepThink/modules/__pycache__/utils.cpython-311.pyc +0 -0
- src/DeepThink/modules/article_generation.py +523 -0
- src/DeepThink/modules/article_polish.py +417 -0
app.py
ADDED
@@ -0,0 +1,161 @@
import time
import json

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from http import HTTPStatus

from src.lm import QwenModel
from src.rm import GoogleSearchAli_new
import sys
sys.path.append('./src/DeepThink/modules')
from mindmap import MindMap
from storm_dataclass import Article

from article_generation import ArticleGenerationModule
from article_polish import ArticlePolishingModule
from outline_generation import OutlineGenerationModule

import os

import subprocess
# Upgrade pip in a background process.
bash_command = "pip install --upgrade pip"
process = subprocess.Popen(bash_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

# Load environment variables and API keys
# load_dotenv()

openai_kwargs = {
    'api_key': os.getenv("OPENAI_API_KEY"),
    'api_provider': os.getenv('OPENAI_API_TYPE'),
    'temperature': 1.0,
    'top_p': 0.9,
    'api_base': os.getenv('AZURE_API_BASE'),
    'api_version': os.getenv('AZURE_API_VERSION'),
}

lm = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
lm4outline = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
lm4gensection = QwenModel(model='qwen-plus', max_tokens=2000, **openai_kwargs)
lm4polish = QwenModel(model='qwen-plus', max_tokens=4000, **openai_kwargs)

rm = GoogleSearchAli_new(k=5)

st.set_page_config(page_title='OmniThink', layout="wide")

st.warning("Announcement: Due to the recent high volume of visitors and search API quota limitations, you may encounter the error: "
           "'ValueError: Expected 2D array, got 1D array instead: array=[]. "
           "Reshape your data either using array.reshape(-1, 1) if your data has a single feature "
           "or array.reshape(1, -1) if it contains a single sample.' "
           "If this error occurs, please try again in a few hours.")

st.title('🤔 OmniThink')
st.markdown('_OmniThink is a tool that helps you think deeply about a topic, generate an outline, and write an article._')

# Sidebar for configuration and examples
with st.sidebar:
    st.header('Configuration')
    MAX_ROUNDS = st.number_input('Retrieval depth', min_value=0, max_value=10, value=2, step=1)
    models = ['Qwen-Plus', 'Coming Soon']
    selected_model = st.selectbox('LLM', models)
    searchers = ['GoogleSearch', 'Coming Soon']
    selected_searcher = st.selectbox('Search engine', searchers)

    n_max_doc = st.number_input('Number of web pages retrieved in a single search', min_value=1, max_value=50, value=10, step=5)
    st.header('Examples')
    examples = ['AlphaFold', '2024 Hualien City Earthquake', 'Taylor Swift', 'Yoon Seok-youl']
    selected_example = st.selectbox('Case', examples)
    status_placeholder = st.empty()

mind_map = MindMap(
    retriever=rm,
    gen_concept_lm=lm4outline,
    gen_concept_lm2=lm4outline,
    search_top_k=n_max_doc,
    depth=MAX_ROUNDS
)

def Think(input_topic):
    generator = mind_map.build_map(input_topic)

    st.markdown(f'Performing an in-depth search on the content related to {input_topic}...')

    for idx, layer in enumerate(generator):
        print(layer)
        st.markdown(f'Deep thinking retrieval at level {idx + 1}...')
        status_placeholder.text(f"Currently conducting level {idx + 1} of deep thinking retrieval, estimated to take {(idx + 1) * 3} minutes.")
        for node in layer:
            category = node.category
            print(f'category: {category}')
            with st.expander(f'{category}'):
                st.markdown(f'### The concept of {node.category}')
                print(node.concept)
                for concept in node.concept:
                    st.markdown(f'* {concept}')
                st.markdown(f'### The web pages of {node.category}')
                for info_idx, info in enumerate(node.info):
                    st.markdown(f'{info_idx + 1}. {info["title"]} \n {info["snippets"]}')

    st.markdown(f'Constructing an index table for the {mind_map.get_web_number()} retrieved web pages...')
    mind_map.prepare_table_for_retrieval()
    return '__finish__', '__finish__'

def GenOutline(input_topic):
    status_placeholder.text("Outline writing is in progress and is expected to take 1 minute.")
    ogm = OutlineGenerationModule(lm)
    outline = ogm.generate_outline(topic=input_topic, mindmap=mind_map)

    return outline

def GenArticle(input_topic, outline):
    status_placeholder.text("Article writing is in progress and is expected to take 3 minutes.")

    article_with_outline = Article.from_outline_str(topic=input_topic, outline_str=outline)
    ag = ArticleGenerationModule(retriever=rm, article_gen_lm=lm, retrieve_top_k=3, max_thread_num=10)
    article = ag.generate_article(topic=input_topic, mindmap=mind_map, article_with_outline=article_with_outline)
    ap = ArticlePolishingModule(article_gen_lm=lm, article_polish_lm=lm)
    article = ap.polish_article(topic=input_topic, draft_article=article)
    return article.to_string()

with st.form('my_form'):
    topic = st.text_input('Please enter the topic you are interested in.', value=selected_example, placeholder='Please enter the topic you are interested in.')
    submit_button = st.form_submit_button('Generate!')

if submit_button:
    if topic:
        st.markdown('### Thought process')
        summary, news_timeline = Think(topic)
        st.session_state.summary = summary
        st.session_state.news_timeline = news_timeline

        st.markdown('### Outline generation')
        with st.expander("Outline generation", expanded=True):
            outline = GenOutline(topic)
            st.text(outline)

        st.markdown('### Article generation')
        with st.expander("Article generation", expanded=True):
            article = GenArticle(topic, outline)
            st.markdown(article)
    else:
        st.error('Please enter a topic.')
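A minimal headless sketch of the same Think → GenOutline → GenArticle flow, for smoke-testing outside Streamlit. It only reuses names and call signatures that appear in `app.py` above and assumes the same environment variables are set; it is not part of the committed code.

```python
# Hypothetical smoke test mirroring app.py without the Streamlit UI.
# Assumes OPENAI_API_KEY / OPENAI_API_TYPE / AZURE_API_BASE / AZURE_API_VERSION are set.
import os
import sys

sys.path.append('./src/DeepThink/modules')

from src.lm import QwenModel
from src.rm import GoogleSearchAli_new
from mindmap import MindMap
from storm_dataclass import Article
from article_generation import ArticleGenerationModule
from article_polish import ArticlePolishingModule
from outline_generation import OutlineGenerationModule

openai_kwargs = {
    'api_key': os.getenv("OPENAI_API_KEY"),
    'api_provider': os.getenv('OPENAI_API_TYPE'),
    'temperature': 1.0,
    'top_p': 0.9,
    'api_base': os.getenv('AZURE_API_BASE'),
    'api_version': os.getenv('AZURE_API_VERSION'),
}
lm = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
rm = GoogleSearchAli_new(k=5)

mind_map = MindMap(retriever=rm, gen_concept_lm=lm, gen_concept_lm2=lm,
                   search_top_k=10, depth=2)

topic = 'AlphaFold'
for layer in mind_map.build_map(topic):   # deep-thinking retrieval, level by level
    pass
mind_map.prepare_table_for_retrieval()

outline = OutlineGenerationModule(lm).generate_outline(topic=topic, mindmap=mind_map)
draft = ArticleGenerationModule(retriever=rm, article_gen_lm=lm,
                                retrieve_top_k=3, max_thread_num=10).generate_article(
    topic=topic, mindmap=mind_map,
    article_with_outline=Article.from_outline_str(topic=topic, outline_str=outline))
article = ArticlePolishingModule(article_gen_lm=lm, article_polish_lm=lm).polish_article(
    topic=topic, draft_article=draft)
print(article.to_string())
```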
model/paraphrase-MiniLM-L6-v2/.gitattributes
ADDED
@@ -0,0 +1,17 @@
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tar.gz filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
model.safetensors filter=lfs diff=lfs merge=lfs -text
model/paraphrase-MiniLM-L6-v2/1_Pooling/config.json
ADDED
@@ -0,0 +1,7 @@
{
  "word_embedding_dimension": 384,
  "pooling_mode_cls_token": false,
  "pooling_mode_mean_tokens": true,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false
}
model/paraphrase-MiniLM-L6-v2/README.md
ADDED
@@ -0,0 +1,108 @@
---
license: apache-2.0
library_name: sentence-transformers
tags:
- sentence-transformers
- feature-extraction
- sentence-similarity
- transformers
pipeline_tag: sentence-similarity
---

# sentence-transformers/paraphrase-MiniLM-L6-v2

This is a [sentence-transformers](https://www.SBERT.net) model: It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search.

## Usage (Sentence-Transformers)

Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:

```
pip install -U sentence-transformers
```

Then you can use the model like this:

```python
from sentence_transformers import SentenceTransformer
sentences = ["This is an example sentence", "Each sentence is converted"]

model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
embeddings = model.encode(sentences)
print(embeddings)
```
+
|
38 |
+
|
39 |
+
## Usage (HuggingFace Transformers)
|
40 |
+
Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: First, you pass your input through the transformer model, then you have to apply the right pooling-operation on-top of the contextualized word embeddings.
|
41 |
+
|
42 |
+
```python
|
43 |
+
from transformers import AutoTokenizer, AutoModel
|
44 |
+
import torch
|
45 |
+
|
46 |
+
|
47 |
+
#Mean Pooling - Take attention mask into account for correct averaging
|
48 |
+
def mean_pooling(model_output, attention_mask):
|
49 |
+
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
|
50 |
+
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
51 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
52 |
+
|
53 |
+
|
54 |
+
# Sentences we want sentence embeddings for
|
55 |
+
sentences = ['This is an example sentence', 'Each sentence is converted']
|
56 |
+
|
57 |
+
# Load model from HuggingFace Hub
|
58 |
+
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-MiniLM-L6-v2')
|
59 |
+
model = AutoModel.from_pretrained('sentence-transformers/paraphrase-MiniLM-L6-v2')
|
60 |
+
|
61 |
+
# Tokenize sentences
|
62 |
+
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
|
63 |
+
|
64 |
+
# Compute token embeddings
|
65 |
+
with torch.no_grad():
|
66 |
+
model_output = model(**encoded_input)
|
67 |
+
|
68 |
+
# Perform pooling. In this case, max pooling.
|
69 |
+
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
|
70 |
+
|
71 |
+
print("Sentence embeddings:")
|
72 |
+
print(sentence_embeddings)
|
73 |
+
```
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
## Evaluation Results
|
78 |
+
|
79 |
+
|
80 |
+
|
81 |
+
For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/paraphrase-MiniLM-L6-v2)
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
## Full Model Architecture
|
86 |
+
```
|
87 |
+
SentenceTransformer(
|
88 |
+
(0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel
|
89 |
+
(1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
|
90 |
+
)
|
91 |
+
```
|
92 |
+
|
93 |
+
## Citing & Authors
|
94 |
+
|
95 |
+
This model was trained by [sentence-transformers](https://www.sbert.net/).
|
96 |
+
|
97 |
+
If you find this model helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084):
|
98 |
+
```bibtex
|
99 |
+
@inproceedings{reimers-2019-sentence-bert,
|
100 |
+
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
|
101 |
+
author = "Reimers, Nils and Gurevych, Iryna",
|
102 |
+
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
|
103 |
+
month = "11",
|
104 |
+
year = "2019",
|
105 |
+
publisher = "Association for Computational Linguistics",
|
106 |
+
url = "http://arxiv.org/abs/1908.10084",
|
107 |
+
}
|
108 |
+
```
|
model/paraphrase-MiniLM-L6-v2/config.json
ADDED
@@ -0,0 +1,24 @@
{
  "_name_or_path": "old_models/paraphrase-MiniLM-L6-v2/0_Transformer",
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1536,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.7.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
model/paraphrase-MiniLM-L6-v2/config_sentence_transformers.json
ADDED
@@ -0,0 +1,7 @@
{
  "__version__": {
    "sentence_transformers": "2.0.0",
    "transformers": "4.7.0",
    "pytorch": "1.9.0+cu102"
  }
}
model/paraphrase-MiniLM-L6-v2/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2ce4480dc3b2f8edeee50c43765c72768e79fc0113d3f73773dded4887cca298
size 90868373
model/paraphrase-MiniLM-L6-v2/modules.json
ADDED
@@ -0,0 +1,14 @@
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  }
]
model/paraphrase-MiniLM-L6-v2/onnx/model.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:441a5dc61ff3b889892feeb7aa0400518cc9908603209c45861ba3abef3006bc
size 90405214
model/paraphrase-MiniLM-L6-v2/onnx/model_O1.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4b50dcc09ca71accf34b7c7d843ad157499bb2e2c7f7a9b9bc1bbb720147ce6
size 90360328
model/paraphrase-MiniLM-L6-v2/onnx/model_O2.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d99151ccdfb700fb278610e060f3debf615b59da275ba5784385d49c8b8e8e9c
size 90326566
model/paraphrase-MiniLM-L6-v2/onnx/model_O3.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:053f244528154eef2db50c676536cdf0ab1e9cba20693ad8c9d83cb592126072
size 90326497
model/paraphrase-MiniLM-L6-v2/onnx/model_O4.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:891a315753191d6dfb32c193615187456e33ee52d6425e0ad8dac2d086350f81
size 45212349
model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_arm64.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ccc4bb68331e8410226d021ed709d6f2db3b0b25a43504828fa1d54fc6f7b3b3
size 23026053
model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ccc4bb68331e8410226d021ed709d6f2db3b0b25a43504828fa1d54fc6f7b3b3
size 23026053
model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512_vnni.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ccc4bb68331e8410226d021ed709d6f2db3b0b25a43504828fa1d54fc6f7b3b3
size 23026053
model/paraphrase-MiniLM-L6-v2/onnx/model_quint8_avx2.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7d2e9f4180455601b4ebb64602ba667f551c87f16e791550479346851b6e4787
size 23046789
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6005063ac5c88df685065089e887719f43956959a2080c7b9467bc17924645d
size 90265744
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.xml
ADDED
The diff for this file is too large to render.
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f036c75118e1df8040b4be3d5b7589ae1f1bb0c1f0f5d666b9bd317a2c8014d5
size 22933664
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.xml
ADDED
The diff for this file is too large to render.
model/paraphrase-MiniLM-L6-v2/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5d716de760acbdc09e79a11e718c5606e0812b6aeb76c6664cba876d174e3ecd
size 90895153
model/paraphrase-MiniLM-L6-v2/sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
{
  "max_seq_length": 128,
  "do_lower_case": false
}
model/paraphrase-MiniLM-L6-v2/special_tokens_map.json
ADDED
@@ -0,0 +1 @@
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
model/paraphrase-MiniLM-L6-v2/tf_model.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ee09134d6f68fddf22606d3bc296855df94e527bbfe1555151e4b9613564a218
size 91005696
model/paraphrase-MiniLM-L6-v2/tokenizer.json
ADDED
The diff for this file is too large to render.
model/paraphrase-MiniLM-L6-v2/tokenizer_config.json
ADDED
@@ -0,0 +1 @@
{"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "name_or_path": "nreimers/MiniLM-L6-H384-uncased", "do_basic_tokenize": true, "never_split": null, "model_max_length": 512}
model/paraphrase-MiniLM-L6-v2/vocab.txt
ADDED
The diff for this file is too large to render.
requirements.txt
ADDED
@@ -0,0 +1,22 @@
dspy_ai==2.4.9
wikipedia==1.4.0
sentence-transformers
toml
langchain-text-splitters
trafilatura
langchain-huggingface
qdrant-client
langchain-qdrant
numpy==1.26.4
dashscope
beautifulsoup4
streamlit==1.37.1
python-dotenv
streamlit-vis-timeline==0.3.0
tilse
jsonlines
rank-bm25
transformers
litellm
lxml
lxml_html_clean
src/DeepThink/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (223 Bytes).
src/DeepThink/__pycache__/engine.cpython-311.pyc
ADDED
Binary file (18.2 kB).
src/DeepThink/engine.py
ADDED
@@ -0,0 +1,285 @@
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Union, Literal, Optional

import dspy
from interface import Engine, LMConfigs
from lm import OpenAIModel

class LMConfigs():
    """Configurations for LLM used in different parts of STORM.

    Given that different parts in STORM framework have different complexity, we use different LLM configurations
    to achieve a balance between quality and efficiency. If no specific configuration is provided, we use the default
    setup in the paper.
    """

    def __init__(self):
        self.conv_simulator_lm = None  # LLM used in conversation simulator except for question asking.
        self.question_asker_lm = None  # LLM used in question asking.
        self.outline_gen_lm = None  # LLM used in outline generation.
        self.article_gen_lm = None  # LLM used in article generation.
        self.article_polish_lm = None  # LLM used in article polishing.

    def set_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.lm = model

    def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.conv_simulator_lm = model

    def set_question_asker_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.question_asker_lm = model

    def set_outline_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.outline_gen_lm = model

    def set_article_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.article_gen_lm = model

    def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.article_polish_lm = model


@dataclass
class RunnerArguments:
    """Arguments for controlling the STORM Wiki pipeline."""
    output_dir: str = field(
        metadata={"help": "Output directory for the results."},
    )
    max_conv_turn: int = field(
        default=3,
        metadata={"help": "Maximum number of questions in conversational question asking."},
    )
    max_perspective: int = field(
        default=3,
        metadata={"help": "Maximum number of perspectives to consider in perspective-guided question asking."},
    )
    max_search_queries_per_turn: int = field(
        default=3,
        metadata={"help": "Maximum number of search queries to consider in each turn."},
    )
    disable_perspective: bool = field(
        default=False,
        metadata={"help": "If True, disable perspective-guided question asking."},
    )
    search_top_k: int = field(
        default=3,
        metadata={"help": "Top k search results to consider for each search query."},
    )
    retrieve_top_k: int = field(
        default=3,
        metadata={"help": "Top k collected references for each section title."},
    )
    max_thread_num: int = field(
        default=10,
        metadata={"help": "Maximum number of threads to use. "
                          "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."},
    )

class Runner():
    """STORM Wiki pipeline runner."""

    def __init__(self,
                 args: RunnerArguments,
                 lm_configs: LMConfigs,
                 rm):
        super().__init__(lm_configs=lm_configs)
        self.args = args
        self.lm_configs = lm_configs

        self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k)
        storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm)
        self.storm_knowledge_curation_module = StormKnowledgeCurationModule(
            retriever=self.retriever,
            persona_generator=storm_persona_generator,
            conv_simulator_lm=self.lm_configs.conv_simulator_lm,
            question_asker_lm=self.lm_configs.question_asker_lm,
            max_search_queries_per_turn=self.args.max_search_queries_per_turn,
            search_top_k=self.args.search_top_k,
            max_conv_turn=self.args.max_conv_turn,
            max_thread_num=self.args.max_thread_num
        )

        self.storm_outline_generation_module = StormOutlineGenerationModule(
            outline_gen_lm=self.lm_configs.outline_gen_lm
        )

        self.storm_article_generation = StormArticleGenerationModule(
            article_gen_lm=self.lm_configs.article_gen_lm,
            retrieve_top_k=self.args.retrieve_top_k,
            max_thread_num=self.args.max_thread_num,
            retriever=self.retriever
        )

        self.storm_article_polishing_module = StormArticlePolishingModule(
            article_gen_lm=self.lm_configs.article_gen_lm,
            article_polish_lm=self.lm_configs.article_polish_lm
        )

        self.lm_configs.init_check()
        self.apply_decorators()

    def run_knowledge_curation_module(self,
                                      ground_truth_url: str = "None",
                                      ) -> StormInformationTable:
        # First entry point; the topic here is still the original topic. information_table holds both
        # all the conversation dialogues and the dict mapping every URL to its snippets.
        information_table, conversation_log = self.storm_knowledge_curation_module.research(
            topic=self.topic,
            ground_truth_url=ground_truth_url,
            callback_handler=callback_handler,
            max_perspective=self.args.max_perspective,
            disable_perspective=False,
            return_conversation_log=True
        )

        FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json'))
        information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json'))

        return information_table

    def run_outline_generation_module(self,
                                      information_table: StormInformationTable,
                                      callback_handler: BaseCallbackHandler = None) -> StormArticle:

        outline, draft_outline = self.storm_outline_generation_module.generate_outline(
            topic=self.topic,
            information_table=information_table,
            return_draft_outline=True,
            callback_handler=callback_handler
        )
        outline.dump_outline_to_file(os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
        draft_outline.dump_outline_to_file(os.path.join(self.article_output_dir, "direct_gen_outline.txt"))
        return outline

    def run_article_generation_module(self,
                                      outline: StormArticle,
                                      information_table=StormInformationTable,
                                      callback_handler: BaseCallbackHandler = None) -> StormArticle:

        draft_article = self.storm_article_generation.generate_article(
            topic=self.topic,
            information_table=information_table,
            article_with_outline=outline,
            callback_handler=callback_handler
        )
        draft_article.dump_article_as_plain_text(os.path.join(self.article_output_dir, 'storm_gen_article.txt'))
        draft_article.dump_reference_to_file(os.path.join(self.article_output_dir, 'url_to_info.json'))
        return draft_article

    def run_article_polishing_module(self,
                                     draft_article: StormArticle,
                                     remove_duplicate: bool = False) -> StormArticle:

        polished_article = self.storm_article_polishing_module.polish_article(
            topic=self.topic,
            draft_article=draft_article,
            remove_duplicate=remove_duplicate
        )
        FileIOHelper.write_str(polished_article.to_string(),
                               os.path.join(self.article_output_dir, 'storm_gen_article_polished.txt'))
        return polished_article

    def post_run(self):
        """
        Post-run operations, including:
        1. Dumping the run configuration.
        2. Dumping the LLM call history.
        """
        config_log = self.lm_configs.log()
        FileIOHelper.dump_json(config_log, os.path.join(self.article_output_dir, 'run_config.json'))

        llm_call_history = self.lm_configs.collect_and_reset_lm_history()
        with open(os.path.join(self.article_output_dir, 'llm_call_history.jsonl'), 'w') as f:
            for call in llm_call_history:
                if 'kwargs' in call:
                    call.pop('kwargs')  # All kwargs are dumped together to run_config.json.
                f.write(json.dumps(call) + '\n')

    def _load_information_table_from_local_fs(self, information_table_local_path):
        assert os.path.exists(information_table_local_path), makeStringRed(f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic.")
        return StormInformationTable.from_conversation_log_file(information_table_local_path)

    def _load_outline_from_local_fs(self, topic, outline_local_path):
        assert os.path.exists(outline_local_path), makeStringRed(f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.")
        return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)

    def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path):
        assert os.path.exists(draft_article_path), makeStringRed(f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.")
        assert os.path.exists(url_to_info_path), makeStringRed(f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.")
        article_text = FileIOHelper.load_str(draft_article_path)
        references = FileIOHelper.load_json(url_to_info_path)
        return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references)

    def run(self,
            topic: str,
            ground_truth_url: str = '',
            do_research: bool = True,
            do_generate_outline: bool = True,
            do_generate_article: bool = True,
            do_polish_article: bool = True,
            remove_duplicate: bool = False,
            callback_handler: BaseCallbackHandler = BaseCallbackHandler()):
        """
        Run the STORM pipeline.

        Args:
            topic: The topic to research.
            ground_truth_url: A ground truth URL including a curated article about the topic. The URL will be excluded.
            do_research: If True, research the topic through information-seeking conversation;
                if False, expect conversation_log.json and raw_search_results.json to exist in the output directory.
            do_generate_outline: If True, generate an outline for the topic;
                if False, expect storm_gen_outline.txt to exist in the output directory.
            do_generate_article: If True, generate a curated article for the topic;
                if False, expect storm_gen_article.txt to exist in the output directory.
            do_polish_article: If True, polish the article by adding a summarization section and (optionally) removing
                duplicated content.
            remove_duplicate: If True, remove duplicated content.
            callback_handler: A callback handler to handle the intermediate results.
        """
        assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
            makeStringRed("No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article")

        self.topic = topic
        self.article_dir_name = topic.replace(' ', '_').replace('/', '_')
        self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name)
        os.makedirs(self.article_output_dir, exist_ok=True)

        # Research module: first generate queries to collect some URLs, then read the URLs to create
        # several different personas, and converse with each persona to gather useful information.
        information_table: StormInformationTable = None
        if do_research:
            information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url,
                                                                   callback_handler=callback_handler)

        # Outline generation module: generate an outline; optionally condition on the earlier
        # conversations, which yields a more detailed outline.
        outline: StormArticle = None
        if do_generate_outline:
            # load information table if it's not initialized
            if information_table is None:
                information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, '.json'))
            outline = self.run_outline_generation_module(information_table=information_table,
                                                         callback_handler=callback_handler)

        # article generation module
        draft_article: StormArticle = None
        if do_generate_article:
            if information_table is None:
                information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, 'conversation_log.json'))
            if outline is None:
                outline = self._load_outline_from_local_fs(topic=topic, outline_local_path=os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))

            draft_article = self.run_article_generation_module(outline=outline,
                                                               information_table=information_table,
                                                               callback_handler=callback_handler)

        # article polishing module
        if do_polish_article:
            if draft_article is None:
                draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt')
                url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json')
                draft_article = self._load_draft_article_from_local_fs(topic=topic, draft_article_path=draft_article_path, url_to_info_path=url_to_info_path)
            self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate)

        post_polish(self.article_output_dir)
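A hypothetical usage sketch for the Runner defined above. RunnerArguments and the LMConfigs setters come from this file; the `lm` and `rm` objects, and the STORM* modules the constructor wires together, are assumed to be provided by the surrounding package.

```python
# Hypothetical driver, assuming dspy-compatible `lm` and `rm` objects are available.
lm_configs = LMConfigs()
lm_configs.set_conv_simulator_lm(lm)
lm_configs.set_question_asker_lm(lm)
lm_configs.set_outline_gen_lm(lm)
lm_configs.set_article_gen_lm(lm)
lm_configs.set_article_polish_lm(lm)

args = RunnerArguments(output_dir='./results', max_conv_turn=3, max_perspective=3,
                       search_top_k=3, retrieve_top_k=3, max_thread_num=10)
runner = Runner(args, lm_configs, rm)

# Runs research -> outline -> article -> polish and writes the intermediate files
# (conversation_log.json, storm_gen_outline.txt, ...) into ./results/<topic>.
runner.run(topic='AlphaFold', do_research=True, do_generate_outline=True,
           do_generate_article=True, do_polish_article=True)
runner.post_run()
```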
src/DeepThink/modules/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (242 Bytes).
src/DeepThink/modules/__pycache__/article_generation.cpython-310.pyc
ADDED
Binary file (6.43 kB).
src/DeepThink/modules/__pycache__/article_generation.cpython-311.pyc
ADDED
Binary file (10.5 kB).
src/DeepThink/modules/__pycache__/article_polish.cpython-310.pyc
ADDED
Binary file (3.4 kB).
src/DeepThink/modules/__pycache__/article_polish.cpython-311.pyc
ADDED
Binary file (5.13 kB).
src/DeepThink/modules/__pycache__/interface.cpython-310.pyc
ADDED
Binary file (17.2 kB).
src/DeepThink/modules/__pycache__/interface.cpython-311.pyc
ADDED
Binary file (24.2 kB).
src/DeepThink/modules/__pycache__/mindmap.cpython-310.pyc
ADDED
Binary file (14.1 kB).
src/DeepThink/modules/__pycache__/mindmap.cpython-311.pyc
ADDED
Binary file (25.4 kB).
src/DeepThink/modules/__pycache__/outline_generation.cpython-310.pyc
ADDED
Binary file (4.84 kB).
src/DeepThink/modules/__pycache__/outline_generation.cpython-311.pyc
ADDED
Binary file (7.47 kB).
src/DeepThink/modules/__pycache__/retriever.cpython-311.pyc
ADDED
Binary file (3.81 kB).
src/DeepThink/modules/__pycache__/storm_dataclass.cpython-310.pyc
ADDED
Binary file (12.6 kB).
src/DeepThink/modules/__pycache__/storm_dataclass.cpython-311.pyc
ADDED
Binary file (21.6 kB).
src/DeepThink/modules/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (12.5 kB).
src/DeepThink/modules/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (20.4 kB).
src/DeepThink/modules/article_generation.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import concurrent.futures
|
2 |
+
import copy
|
3 |
+
import logging
|
4 |
+
from concurrent.futures import as_completed
|
5 |
+
from typing import List, Union
|
6 |
+
import random
|
7 |
+
import dspy
|
8 |
+
import sys
|
9 |
+
|
10 |
+
|
11 |
+
import concurrent.futures
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import pickle
|
15 |
+
import re
|
16 |
+
import sys
|
17 |
+
from typing import List, Dict
|
18 |
+
|
19 |
+
import httpx
|
20 |
+
import toml
|
21 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
22 |
+
from trafilatura import extract
|
23 |
+
|
24 |
+
|
25 |
+
class ArticleTextProcessing:
|
26 |
+
@staticmethod
|
27 |
+
def limit_word_count_preserve_newline(input_string, max_word_count):
|
28 |
+
"""
|
29 |
+
Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines.
|
30 |
+
|
31 |
+
The function truncates the input string at the nearest word that does not exceed the maximum word count,
|
32 |
+
ensuring that no partial lines are included in the output. Words are defined as text separated by spaces,
|
33 |
+
and lines are defined as text separated by newline characters.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
input_string (str): The string to be truncated. This string may contain multiple lines.
|
37 |
+
max_word_count (int): The maximum number of words allowed in the truncated string.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
str: The truncated string with word count limited to `max_word_count`, preserving complete lines.
|
41 |
+
"""
|
42 |
+
|
43 |
+
word_count = 0
|
44 |
+
limited_string = ''
|
45 |
+
|
46 |
+
for word in input_string.split('\n'):
|
47 |
+
line_words = word.split()
|
48 |
+
for lw in line_words:
|
49 |
+
if word_count < max_word_count:
|
50 |
+
limited_string += lw + ' '
|
51 |
+
word_count += 1
|
52 |
+
else:
|
53 |
+
break
|
54 |
+
if word_count >= max_word_count:
|
55 |
+
break
|
56 |
+
limited_string = limited_string.strip() + '\n'
|
57 |
+
|
58 |
+
return limited_string.strip()
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def remove_citations(s):
|
62 |
+
"""
|
63 |
+
Removes all citations from a given string. Citations are assumed to be in the format
|
64 |
+
of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches
|
65 |
+
for all occurrences of such patterns and removes them, returning the cleaned string.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
s (str): The string from which citations are to be removed.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
str: The string with all citation patterns removed.
|
72 |
+
"""
|
73 |
+
|
74 |
+
return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s)
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def get_first_section_dict_and_list(s):
|
78 |
+
"""
|
79 |
+
"""
|
80 |
+
text = s
|
81 |
+
sections = text.strip().split('\n# ')
|
82 |
+
titles = []
|
83 |
+
content_dict = {}
|
84 |
+
|
85 |
+
for section in sections:
|
86 |
+
if section:
|
87 |
+
lines = section.split('\n', 1)
|
88 |
+
title = lines[0].strip()
|
89 |
+
content = lines[1].strip() if len(lines) > 1 else ""
|
90 |
+
|
91 |
+
titles.append(title)
|
92 |
+
content_dict[title] = content
|
93 |
+
return content_dict, titles
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def parse_citation_indices(s):
|
97 |
+
"""
|
98 |
+
Extracts citation indexes from the provided content string and returns them as a list of integers.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
content (str): The content string containing citations in the format [number].
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
List[int]: A list of unique citation indexes extracted from the content, in the order they appear.
|
105 |
+
"""
|
106 |
+
matches = re.findall(r'\[\d+\]', s)
|
107 |
+
return [int(index[1:-1]) for index in matches]
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
def remove_uncompleted_sentences_with_citations(text):
|
111 |
+
"""
|
112 |
+
Removes uncompleted sentences and standalone citations from the input text. Sentences are identified
|
113 |
+
by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., "[1]").
|
114 |
+
Grouped citations (e.g., "[1, 2]") are split into individual ones (e.g., "[1] [2]"). Only text up to
|
115 |
+
and including the last complete sentence and its citation is retained.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
text (str): The input text from which uncompleted sentences and their citations are to be removed.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
str: The processed string with uncompleted sentences and standalone citations removed, leaving only
|
122 |
+
complete sentences and their associated citations if present.
|
123 |
+
"""
|
124 |
+
|
125 |
+
# Convert citations like [1, 2, 3] to [1][2][3].
|
126 |
+
def replace_with_individual_brackets(match):
|
127 |
+
numbers = match.group(1).split(', ')
|
128 |
+
return ' '.join(f'[{n}]' for n in numbers)
|
129 |
+
|
130 |
+
# Deduplicate and sort individual groups of citations.
|
131 |
+
def deduplicate_group(match):
|
132 |
+
citations = match.group(0)
|
133 |
+
unique_citations = list(set(re.findall(r'\[\d+\]', citations)))
|
134 |
+
sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]')))
|
135 |
+
# Return the sorted unique citations as a string
|
136 |
+
return ''.join(sorted_citations)
|
137 |
+
|
138 |
+
text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text)
|
139 |
+
text = re.sub(r'(\[\d+\])+', deduplicate_group, text)
|
140 |
+
|
141 |
+
# Deprecated: Remove sentence without proper ending punctuation and citations.
|
142 |
+
# Split the text into sentences (including citations).
|
143 |
+
# sentences_with_trailing = re.findall(r'([^.!?]*[.!?].*?)(?=[^.!?]*[.!?]|$)', text)
|
144 |
+
|
145 |
+
# Filter sentences to ensure they end with a punctuation mark and properly formatted citations
|
146 |
+
# complete_sentences = []
|
147 |
+
# for sentence in sentences_with_trailing:
|
148 |
+
# # Check if the sentence ends with properly formatted citations
|
149 |
+
# if re.search(r'[.!?]( \[\d+\])*$|^[^.!?]*[.!?]$', sentence.strip()):
|
150 |
+
# complete_sentences.append(sentence.strip())
|
151 |
+
|
152 |
+
# combined_sentences = ' '.join(complete_sentences)
|
153 |
+
|
154 |
+
# Check for and append any complete citations that follow the last sentence
|
155 |
+
# trailing_citations = re.findall(r'(\[\d+\]) ', text[text.rfind(combined_sentences) + len(combined_sentences):])
|
156 |
+
# if trailing_citations:
|
157 |
+
# combined_sentences += ' '.join(trailing_citations)
|
158 |
+
|
159 |
+
# Regex pattern to match sentence endings, including optional citation markers.
|
160 |
+
eos_pattern = r'([.!?])\s*(\[\d+\])?\s*'
|
161 |
+
matches = list(re.finditer(eos_pattern, text))
|
162 |
+
if matches:
|
163 |
+
last_match = matches[-1]
|
164 |
+
text = text[:last_match.end()].strip()
|
165 |
+
|
166 |
+
return text
|
167 |
+
|
168 |
+
@staticmethod
|
169 |
+
def clean_up_citation(conv):
|
170 |
+
for turn in conv.dlg_history:
|
171 |
+
turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')]
|
172 |
+
turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')]
|
173 |
+
turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip()
|
174 |
+
try:
|
175 |
+
max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)])
|
176 |
+
except Exception as e:
|
177 |
+
max_ref_num = 0
|
178 |
+
if max_ref_num > len(turn.search_results):
|
179 |
+
for i in range(len(turn.search_results), max_ref_num + 1):
|
180 |
+
turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '')
|
181 |
+
turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
|
182 |
+
turn.agent_utterance)
|
183 |
+
|
184 |
+
return conv
|
185 |
+
|
186 |
+
@staticmethod
|
187 |
+
def clean_up_outline(outline, topic=""):
|
188 |
+
output_lines = []
|
189 |
+
current_level = 0 # To track the current section level
|
190 |
+
|
191 |
+
for line in outline.split('\n'):
|
192 |
+
stripped_line = line.strip()
|
193 |
+
|
194 |
+
if topic != "" and f"# {topic.lower()}" in stripped_line.lower():
|
195 |
+
output_lines = []
|
196 |
+
|
197 |
+
# Check if the line is a section header
|
198 |
+
if stripped_line.startswith('#') and stripped_line != '#':
|
199 |
+
current_level = stripped_line.count('#')
|
200 |
+
output_lines.append(stripped_line)
|
201 |
+
# Check if the line is a bullet point
|
202 |
+
# elif stripped_line.startswith('-'):
|
203 |
+
# subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip()
|
204 |
+
# output_lines.append(subsection_header)
|
205 |
+
# Preserve lines with @
|
206 |
+
elif stripped_line.startswith('@'):
|
207 |
+
output_lines.append(stripped_line)
|
208 |
+
|
209 |
+
outline = '\n'.join(output_lines)
|
210 |
+
|
211 |
+
# Remove references.
|
212 |
+
outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
213 |
+
outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
214 |
+
outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
215 |
+
outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
216 |
+
outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
217 |
+
outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
218 |
+
outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
219 |
+
outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL)
|
220 |
+
outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL)
|
221 |
+
outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
222 |
+
outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
223 |
+
outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL)
|
224 |
+
|
225 |
+
return outline
|
226 |
+
|
227 |
+
|
228 |
+
@staticmethod
|
229 |
+
def clean_up_section(text):
|
230 |
+
"""Clean up a section:
|
231 |
+
1. Remove uncompleted sentences (usually due to output token limitation).
|
232 |
+
2. Deduplicate individual groups of citations.
|
233 |
+
3. Remove unnecessary summary."""
|
234 |
+
|
235 |
+
paragraphs = text.split('\n')
|
236 |
+
output_paragraphs = []
|
237 |
+
summary_sec_flag = False
|
238 |
+
for p in paragraphs:
|
239 |
+
p = p.strip()
|
240 |
+
if len(p) == 0:
|
241 |
+
continue
|
242 |
+
if not p.startswith('#'):
|
243 |
+
p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)
|
244 |
+
if summary_sec_flag:
|
245 |
+
if p.startswith('#'):
|
246 |
+
summary_sec_flag = False
|
247 |
+
else:
|
248 |
+
continue
|
249 |
+
if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'):
|
250 |
+
continue
|
251 |
+
if "# Summary" in p or '# Conclusion' in p:
|
252 |
+
summary_sec_flag = True
|
253 |
+
continue
|
254 |
+
output_paragraphs.append(p)
|
255 |
+
|
256 |
+
return '\n\n'.join(output_paragraphs) # Join with '\n\n' for markdown format.
|
257 |
+
|
258 |
+
@staticmethod
|
259 |
+
def update_citation_index(s, citation_map):
|
260 |
+
"""Update citation index in the string based on the citation map."""
|
261 |
+
for original_citation in citation_map:
|
262 |
+
s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__")
|
263 |
+
for original_citation, unify_citation in citation_map.items():
|
264 |
+
s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]")
|
265 |
+
|
266 |
+
return s
|
267 |
+
|
268 |
+
    @staticmethod
    def parse_article_into_dict(input_string):
        """
        Parses a structured text into a nested dictionary. The structure of the text
        is defined by markdown-like headers (using '#' symbols) to denote sections
        and subsections. Each section can contain content and further nested subsections.

        The resulting dictionary captures the hierarchical structure of sections, where
        each section is represented as a key (the section's title) mapping to a value
        that is another dictionary. This dictionary contains two keys:
        - 'content': the text content of the section
        - 'subsections': a dictionary of nested subsections keyed by their titles, each
          following the same structure.

        Args:
            input_string (str): A string containing the structured text to parse.

        Returns:
            A dictionary mapping each top-level section title to another dictionary with
            the 'content' and 'subsections' keys described above.
        """
        lines = input_string.split('\n')
        lines = [line for line in lines if line.strip()]
        root = {'content': '', 'subsections': {}}
        current_path = [(root, -1)]  # (current_dict, level)

        for line in lines:
            if line.startswith('#'):
                level = line.count('#')
                title = line.strip('# ').strip()
                new_section = {'content': '', 'subsections': {}}

                # Pop from the stack until the parent level is found.
                while current_path and current_path[-1][1] >= level:
                    current_path.pop()

                # Attach the new section to the nearest upper level's subsections.
                current_path[-1][0]['subsections'][title] = new_section
                current_path.append((new_section, level))
            else:
                current_path[-1][0]['content'] += line + '\n'

        return root['subsections']
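
    # Illustrative sketch: headers become nested keys and body lines accumulate under 'content', e.g.
    #
    #   >>> ArticleTextProcessing.parse_article_into_dict("# Career\nHe drummed for years.\n## Foo Fighters\nJoined in 1997.")
    #   {'Career': {'content': 'He drummed for years.\n',
    #               'subsections': {'Foo Fighters': {'content': 'Joined in 1997.\n', 'subsections': {}}}}}
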
class FileIOHelper:
    @staticmethod
    def dump_json(obj, file_name, encoding="utf-8"):
        with open(file_name, 'w', encoding=encoding) as fw:
            json.dump(obj, fw, default=FileIOHelper.handle_non_serializable, ensure_ascii=False)

    @staticmethod
    def handle_non_serializable(obj):
        return "non-serializable contents"  # mark the non-serializable part

    @staticmethod
    def load_json(file_name, encoding="utf-8"):
        with open(file_name, 'r', encoding=encoding) as fr:
            return json.load(fr)

    @staticmethod
    def write_str(s, path):
        with open(path, 'w') as f:
            f.write(s)

    @staticmethod
    def load_str(path):
        with open(path, 'r') as f:
            # f.read() avoids the doubled newlines that '\n'.join(f.readlines()) would produce.
            return f.read()

    @staticmethod
    def dump_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    @staticmethod
    def load_pickle(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
class ArticleGenerationModule():
    """
    The interface for the article generation stage. Given the topic, the information
    collected during the knowledge curation stage, and the outline produced by the
    outline generation stage, this module writes the full article section by section.
    """

    def __init__(self,
                 retriever,
                 article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 retrieve_top_k: int = 10,
                 max_thread_num: int = 10,
                 ):
        super().__init__()
        self.retrieve_top_k = retrieve_top_k
        self.article_gen_lm = article_gen_lm
        self.max_thread_num = max_thread_num
        self.retriever = retriever
        self.section_gen = ConvToSection(engine=self.article_gen_lm)

    def generate_section(self, topic, section_name, mindmap, section_query, section_outline):
        collected_info = mindmap.retrieve_information(queries=section_query,
                                                      search_top_k=self.retrieve_top_k)
        output = self.section_gen(
            topic=topic,
            outline=section_outline,
            section=section_name,
            collected_info=collected_info,
        )

        return {"section_name": section_name, "section_content": output.section, "collected_info": collected_info}

    def generate_article(self,
                         topic: str,
                         mindmap,
                         article_with_outline,
                         ):
        """
        Generate the article for the topic based on the information table and article outline.
        """
        mindmap.prepare_table_for_retrieval()

        sections_to_write = article_with_outline.get_first_level_section_names()
        section_output_dict_collection = []

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor:
            future_to_sec_title = {}
            for section_title in sections_to_write:
                section_query = article_with_outline.get_outline_as_list(
                    root_section_name=section_title, add_hashtags=False
                )
                queries_with_hashtags = article_with_outline.get_outline_as_list(
                    root_section_name=section_title, add_hashtags=True
                )
                section_outline = "\n".join(queries_with_hashtags)

                future_to_sec_title[
                    executor.submit(self.generate_section,
                                    topic, section_title, mindmap, section_query, section_outline)
                ] = section_title

            for future in concurrent.futures.as_completed(future_to_sec_title):
                section_output_dict_collection.append(future.result())

        article = copy.deepcopy(article_with_outline)
        for section_output_dict in section_output_dict_collection:
            article.update_section(parent_section_name=topic,
                                   current_section_content=section_output_dict["section_content"],
                                   current_section_info_list=section_output_dict["collected_info"],
                                   )

        article.post_processing()

        return article
class ConvToSection(dspy.Module):
    """Use the information collected from the information-seeking conversation to write a section."""
    # Note: the snippets passed in cover all sections' URLs, but the goal here is to generate one
    # section at a time; this still needs refinement because the outline is not fully used yet.

    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        super().__init__()
        self.write_section = dspy.Predict(WriteSection)
        self.engine = engine

    def forward(self, topic: str, outline: str, section: str, collected_info: List):
        all_info = ''
        for idx, info in enumerate(collected_info):
            all_info += f'[{idx + 1}]\n' + '\n'.join(info['snippets'])
            all_info += '\n\n'

        all_info = ArticleTextProcessing.limit_word_count_preserve_newline(all_info, 1500)

        with dspy.settings.context(lm=self.engine):
            section = ArticleTextProcessing.clean_up_section(
                self.write_section(topic=topic, info=all_info, section=section).output)

        section = section.replace('\\[', '[').replace('\\]', ']')
        return dspy.Prediction(section=section)
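
# Illustrative sketch of the prompt context built in ConvToSection.forward, assuming each
# retrieved item exposes a 'snippets' list as the loop above expects:
#
#   collected_info = [{'snippets': ['snippet a', 'snippet b']}, {'snippets': ['snippet c']}]
#   # all_info (before the word-count limit) becomes:
#   # [1]
#   # snippet a
#   # snippet b
#   #
#   # [2]
#   # snippet c
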
class WriteSection(dspy.Signature):
    """Write a Wikipedia section based on the collected information.

    Here is the format of your writing:
    1. Use "# Title" to indicate section title, "## Title" to indicate subsection title, "### Title" to indicate subsubsection title, and so on.
    2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
    3. The language style should resemble that of Wikipedia: concise yet informative, formal yet accessible.
    """
    # Alternative prompt kept for reference:
    # """
    # Write a detailed, Wikipedia-style report section based on the collected information.
    #
    # Here is the format of your writing:
    # 1. Use "# Title" to indicate section title, "## Title" to indicate subsection title, "### Title" to indicate subsubsection title, and so on.
    # 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
    # 3. The language style should resemble that of Wikipedia: concise yet informative, formal yet accessible.
    # """

    info = dspy.InputField(prefix="The Collected information:\n", format=str)
    topic = dspy.InputField(prefix="The topic of the page: ", format=str)
    section = dspy.InputField(prefix="The section you need to write: ", format=str)
    output = dspy.OutputField(
        prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n",
        format=str)
if __name__ == "__main__":
|
477 |
+
import sys
|
478 |
+
from mindmap import MindMap
|
479 |
+
from outline_generation import OutlineGenerationModule
|
480 |
+
sys.path.append('/mnt/nas-alinlp/xizekun/project/DeepThink/src')
|
481 |
+
from storm_dataclass import Article
|
482 |
+
|
483 |
+
from lm import OpenAIModel, OpenAIModel_New
|
484 |
+
from rm import BingSearch, BingSearchAli
|
485 |
+
from utils import load_api_key
|
486 |
+
import os
|
487 |
+
load_api_key(toml_file_path='/mnt/nas-alinlp/xizekun/project/DeepThink/secrets.toml')
|
488 |
+
openai_kwargs = {
|
489 |
+
'api_key': os.getenv("OPENAI_API_KEY"),
|
490 |
+
'api_provider': os.getenv('OPENAI_API_TYPE'),
|
491 |
+
'temperature': 1.0,
|
492 |
+
'top_p': 0.9,
|
493 |
+
'api_base': os.getenv('AZURE_API_BASE'),
|
494 |
+
'api_version': os.getenv('AZURE_API_VERSION'),
|
495 |
+
}
|
496 |
+
|
497 |
+
lm = OpenAIModel(model='gpt-4-1106-preview', max_tokens=5000, **openai_kwargs)
|
498 |
+
rm = BingSearchAli(ydc_api_key=os.getenv('BING_SEARCH_ALI_API_KEY'), k=3)
|
499 |
+
|
500 |
+
retriever = rm
|
501 |
+
gen_concept_lm = lm
|
502 |
+
|
503 |
+
mind_map = MindMap(
|
504 |
+
retriever=retriever,
|
505 |
+
gen_concept_lm=lm,
|
506 |
+
search_top_k=3,
|
507 |
+
deepth = 3
|
508 |
+
)
|
509 |
+
a = mind_map.load_map('/mnt/nas-alinlp/xizekun/project/DeepThink/src/DeepThink/modules/Taylor.json')
|
510 |
+
ag = ArticleGenerationModule(
|
511 |
+
retriever = retriever,
|
512 |
+
article_gen_lm = lm,
|
513 |
+
retrieve_top_k = 5,
|
514 |
+
max_thread_num = 10)
|
515 |
+
|
516 |
+
module = OutlineGenerationModule(lm)
|
517 |
+
outline = module.generate_outline(topic= 'Taylor Hawkins',mindmap = mind_map)
|
518 |
+
print(outline)
|
519 |
+
print('~~~~~~')
|
520 |
+
|
521 |
+
article_with_outline = Article.from_outline_str(topic='Taylor Hawkins', outline_str=outline)
|
522 |
+
a = ag.generate_article(topic = 'Taylor Hawkins', mindmap = mind_map, article_with_outline = article_with_outline)
|
523 |
+
print(a.to_string())
|
src/DeepThink/modules/article_polish.py
ADDED
@@ -0,0 +1,417 @@
import copy
from typing import Union

import dspy
# from storm_wiki.modules.storm_dataclass import StormArticle
import concurrent.futures
import json
import os
import pickle
import re
import sys
from typing import List, Dict

import httpx
import toml
from langchain_text_splitters import RecursiveCharacterTextSplitter
from trafilatura import extract
class ArticleTextProcessing:
    @staticmethod
    def limit_word_count_preserve_newline(input_string, max_word_count):
        """
        Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines.

        The function truncates the input string at the nearest word that does not exceed the maximum word count,
        ensuring that no partial lines are included in the output. Words are defined as text separated by spaces,
        and lines are defined as text separated by newline characters.

        Args:
            input_string (str): The string to be truncated. This string may contain multiple lines.
            max_word_count (int): The maximum number of words allowed in the truncated string.

        Returns:
            str: The truncated string with word count limited to `max_word_count`, preserving complete lines.
        """

        word_count = 0
        limited_string = ''

        for word in input_string.split('\n'):
            line_words = word.split()
            for lw in line_words:
                if word_count < max_word_count:
                    limited_string += lw + ' '
                    word_count += 1
                else:
                    break
            if word_count >= max_word_count:
                break
            limited_string = limited_string.strip() + '\n'

        return limited_string.strip()
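
    # Illustrative sketch: truncation happens at the word limit and a newline is re-inserted
    # after each fully processed line, e.g.
    #
    #   >>> ArticleTextProcessing.limit_word_count_preserve_newline("one two three\nfour five", 4)
    #   'one two three\nfour'
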
    @staticmethod
    def remove_citations(s):
        """
        Removes all citations from a given string. Citations are assumed to be in the format
        of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches
        for all occurrences of such patterns and removes them, returning the cleaned string.

        Args:
            s (str): The string from which citations are to be removed.

        Returns:
            str: The string with all citation patterns removed.
        """

        return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s)
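
    # Illustrative sketch:
    #
    #   >>> ArticleTextProcessing.remove_citations("Washington, D.C. is the capital.[1][2, 5]")
    #   'Washington, D.C. is the capital.'
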
    @staticmethod
    def get_first_section_dict_and_list(s):
        """
        Split an article into its top-level ('# ') sections, returning a dict that maps each
        section title to its content and a list of the titles in their original order.
        Note that the text is split on a newline followed by '# ', so the very first title
        keeps its leading '# ' prefix.
        """
        text = s
        sections = text.strip().split('\n# ')
        titles = []
        content_dict = {}

        for section in sections:
            if section:
                lines = section.split('\n', 1)
                title = lines[0].strip()
                content = lines[1].strip() if len(lines) > 1 else ""

                titles.append(title)
                content_dict[title] = content
        return content_dict, titles
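
    # Illustrative sketch (note the first title keeps its '# ' prefix, see docstring above):
    #
    #   >>> ArticleTextProcessing.get_first_section_dict_and_list("# Intro\ntext a\n# Career\ntext b")
    #   ({'# Intro': 'text a', 'Career': 'text b'}, ['# Intro', 'Career'])
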
    @staticmethod
    def parse_citation_indices(s):
        """
        Extracts citation indexes from the provided content string and returns them as a list of integers.

        Args:
            s (str): The content string containing citations in the format [number].

        Returns:
            List[int]: A list of citation indexes extracted from the content, in the order they appear.
        """
        matches = re.findall(r'\[\d+\]', s)
        return [int(index[1:-1]) for index in matches]
    @staticmethod
    def remove_uncompleted_sentences_with_citations(text):
        """
        Removes uncompleted sentences and standalone citations from the input text. Sentences are identified
        by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., "[1]").
        Grouped citations (e.g., "[1, 2]") are split into individual ones (e.g., "[1] [2]"). Only text up to
        and including the last complete sentence and its citation is retained.

        Args:
            text (str): The input text from which uncompleted sentences and their citations are to be removed.

        Returns:
            str: The processed string with uncompleted sentences and standalone citations removed, leaving only
                 complete sentences and their associated citations if present.
        """

        # Convert citations like [1, 2, 3] to [1][2][3].
        def replace_with_individual_brackets(match):
            numbers = match.group(1).split(', ')
            return ' '.join(f'[{n}]' for n in numbers)

        # Deduplicate and sort individual groups of citations.
        def deduplicate_group(match):
            citations = match.group(0)
            unique_citations = list(set(re.findall(r'\[\d+\]', citations)))
            sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]')))
            # Return the sorted unique citations as a string
            return ''.join(sorted_citations)

        text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text)
        text = re.sub(r'(\[\d+\])+', deduplicate_group, text)

        # Deprecated: Remove sentence without proper ending punctuation and citations.
        # Split the text into sentences (including citations).
        # sentences_with_trailing = re.findall(r'([^.!?]*[.!?].*?)(?=[^.!?]*[.!?]|$)', text)

        # Filter sentences to ensure they end with a punctuation mark and properly formatted citations
        # complete_sentences = []
        # for sentence in sentences_with_trailing:
        #     # Check if the sentence ends with properly formatted citations
        #     if re.search(r'[.!?]( \[\d+\])*$|^[^.!?]*[.!?]$', sentence.strip()):
        #         complete_sentences.append(sentence.strip())

        # combined_sentences = ' '.join(complete_sentences)

        # Check for and append any complete citations that follow the last sentence
        # trailing_citations = re.findall(r'(\[\d+\]) ', text[text.rfind(combined_sentences) + len(combined_sentences):])
        # if trailing_citations:
        #     combined_sentences += ' '.join(trailing_citations)

        # Regex pattern to match sentence endings, including optional citation markers.
        eos_pattern = r'([.!?])\s*(\[\d+\])?\s*'
        matches = list(re.finditer(eos_pattern, text))
        if matches:
            last_match = matches[-1]
            text = text[:last_match.end()].strip()

        return text
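
    # Illustrative sketch: grouped citations are split, duplicates collapsed, and the
    # unfinished trailing clause dropped, e.g.
    #
    #   >>> ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
    #   ...     "Fact one [1, 2]. Fact two [2][2][1]. Incomplete trailing")
    #   'Fact one [1] [2]. Fact two [1][2].'
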
    @staticmethod
    def clean_up_citation(conv):
        for turn in conv.dlg_history:
            # Only truncate when the marker is present (str.find returns -1 otherwise, which
            # would silently chop the last character).
            if 'References:' in turn.agent_utterance:
                turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')]
            if 'Sources:' in turn.agent_utterance:
                turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')]
            turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip()
            try:
                max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)])
            except Exception as e:
                max_ref_num = 0
            if max_ref_num > len(turn.search_results):
                for i in range(len(turn.search_results), max_ref_num + 1):
                    turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '')
            turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
                turn.agent_utterance)

        return conv
    @staticmethod
    def clean_up_outline(outline, topic=""):
        output_lines = []
        current_level = 0  # To track the current section level

        for line in outline.split('\n'):
            stripped_line = line.strip()

            if topic != "" and f"# {topic.lower()}" in stripped_line.lower():
                output_lines = []

            # Check if the line is a section header
            if stripped_line.startswith('#') and stripped_line != '#':
                current_level = stripped_line.count('#')
                output_lines.append(stripped_line)
            # Check if the line is a bullet point
            # elif stripped_line.startswith('-'):
            #     subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip()
            #     output_lines.append(subsection_header)
            # Preserve lines with @
            elif stripped_line.startswith('@'):
                output_lines.append(stripped_line)

        outline = '\n'.join(output_lines)

        # Remove references.
        outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Further reading.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Further Reading.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL)
        outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL)

        return outline
    @staticmethod
    def clean_up_section(text):
        """Clean up a section:
        1. Remove uncompleted sentences (usually due to output token limitation).
        2. Deduplicate individual groups of citations.
        3. Remove unnecessary summary."""

        paragraphs = text.split('\n')
        output_paragraphs = []
        summary_sec_flag = False
        for p in paragraphs:
            p = p.strip()
            if len(p) == 0:
                continue
            if not p.startswith('#'):
                p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)
            if summary_sec_flag:
                if p.startswith('#'):
                    summary_sec_flag = False
                else:
                    continue
            if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'):
                continue
            if "# Summary" in p or '# Conclusion' in p:
                summary_sec_flag = True
                continue
            output_paragraphs.append(p)

        return '\n\n'.join(output_paragraphs)  # Join with '\n\n' for markdown format.
    @staticmethod
    def update_citation_index(s, citation_map):
        """Update citation index in the string based on the citation map."""
        for original_citation in citation_map:
            s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__")
        for original_citation, unify_citation in citation_map.items():
            s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]")

        return s
    @staticmethod
    def parse_article_into_dict(input_string):
        """
        Parses a structured text into a nested dictionary. The structure of the text
        is defined by markdown-like headers (using '#' symbols) to denote sections
        and subsections. Each section can contain content and further nested subsections.

        The resulting dictionary captures the hierarchical structure of sections, where
        each section is represented as a key (the section's title) mapping to a value
        that is another dictionary. This dictionary contains two keys:
        - 'content': the text content of the section
        - 'subsections': a dictionary of nested subsections keyed by their titles, each
          following the same structure.

        Args:
            input_string (str): A string containing the structured text to parse.

        Returns:
            A dictionary mapping each top-level section title to another dictionary with
            the 'content' and 'subsections' keys described above.
        """
        lines = input_string.split('\n')
        lines = [line for line in lines if line.strip()]
        root = {'content': '', 'subsections': {}}
        current_path = [(root, -1)]  # (current_dict, level)

        for line in lines:
            if line.startswith('#'):
                level = line.count('#')
                title = line.strip('# ').strip()
                new_section = {'content': '', 'subsections': {}}

                # Pop from the stack until the parent level is found.
                while current_path and current_path[-1][1] >= level:
                    current_path.pop()

                # Attach the new section to the nearest upper level's subsections.
                current_path[-1][0]['subsections'][title] = new_section
                current_path.append((new_section, level))
            else:
                current_path[-1][0]['content'] += line + '\n'

        return root['subsections']
class FileIOHelper:
    @staticmethod
    def dump_json(obj, file_name, encoding="utf-8"):
        with open(file_name, 'w', encoding=encoding) as fw:
            json.dump(obj, fw, default=FileIOHelper.handle_non_serializable, ensure_ascii=False)

    @staticmethod
    def handle_non_serializable(obj):
        return "non-serializable contents"  # mark the non-serializable part

    @staticmethod
    def load_json(file_name, encoding="utf-8"):
        with open(file_name, 'r', encoding=encoding) as fr:
            return json.load(fr)

    @staticmethod
    def write_str(s, path):
        with open(path, 'w') as f:
            f.write(s)

    @staticmethod
    def load_str(path):
        with open(path, 'r') as f:
            # f.read() avoids the doubled newlines that '\n'.join(f.readlines()) would produce.
            return f.read()

    @staticmethod
    def dump_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    @staticmethod
    def load_pickle(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
class ArticlePolishingModule():
    """
    The interface for the article polishing stage. Given the topic and the draft article
    produced by the article generation stage, this module removes duplicated content and
    post-processes the article.
    """

    def __init__(self,
                 article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        self.article_gen_lm = article_gen_lm
        self.article_polish_lm = article_polish_lm

        self.polish_page = PolishPageModule(
            write_lead_engine=self.article_gen_lm,
            polish_engine=self.article_polish_lm
        )

    def polish_article(self,
                       topic: str,
                       draft_article,
                       remove_duplicate: bool = False):
        """
        Polish article.

        Args:
            topic (str): The topic of the article.
            draft_article (StormArticle): The draft article.
            remove_duplicate (bool): Whether to use one additional LM call to remove duplicates from the article.
        """

        article_text = draft_article.to_string()
        polish_result = self.polish_page(topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate)

        polished_article = polish_result.page

        polished_article_dict = ArticleTextProcessing.parse_article_into_dict(polished_article)
        polished_article = copy.deepcopy(draft_article)
        polished_article.insert_or_create_section(article_dict=polished_article_dict)
        polished_article.post_processing()
        return polished_article
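
# Illustrative usage sketch, assuming `lm` is a configured dspy LM and `draft_article` is the
# article returned by ArticleGenerationModule.generate_article (both names are placeholders):
#
#   polisher = ArticlePolishingModule(article_gen_lm=lm, article_polish_lm=lm)
#   polished = polisher.polish_article(topic='Taylor Hawkins', draft_article=draft_article)
#   print(polished.to_string())
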
class PolishPage(dspy.Signature):
    """
    You are a faithful text editor that is good at finding repeated information in the article and deleting it to make sure there is no repetition in the article.
    You won't delete any non-repeated part in the article.
    You will keep the inline citations and article structure (indicated by "#", "##", etc.) appropriately.
    In the article, do not include references.
    Do your job for the following article.
    """

    article = dspy.InputField(prefix="The article you need to polish:\n", format=str)
    page = dspy.OutputField(
        prefix="Your revised article:\n",
        format=str)
class PolishPageModule(dspy.Module):
    def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
                 polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
        super().__init__()
        self.write_lead_engine = write_lead_engine
        self.polish_engine = polish_engine
        self.polish_page = dspy.Predict(PolishPage)

    def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True):
        with dspy.settings.context(lm=self.polish_engine):
            page = self.polish_page(article=draft_page).page

        return dspy.Prediction(page=page)
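
# Illustrative sketch: PolishPageModule can also be driven directly with a raw markdown draft;
# `polish_lm` and `draft_text` are placeholders for any dspy-compatible LM and a draft string.
#
#   polish = PolishPageModule(write_lead_engine=polish_lm, polish_engine=polish_lm)
#   result = polish(topic='Taylor Hawkins', draft_page=draft_text, polish_whole_page=True)
#   print(result.page)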