ZekunXi committed on
Commit 80a598c · 1 Parent(s): d1dbd31
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +161 -0
  2. model/paraphrase-MiniLM-L6-v2/.gitattributes +17 -0
  3. model/paraphrase-MiniLM-L6-v2/1_Pooling/config.json +7 -0
  4. model/paraphrase-MiniLM-L6-v2/README.md +108 -0
  5. model/paraphrase-MiniLM-L6-v2/config.json +24 -0
  6. model/paraphrase-MiniLM-L6-v2/config_sentence_transformers.json +7 -0
  7. model/paraphrase-MiniLM-L6-v2/model.safetensors +3 -0
  8. model/paraphrase-MiniLM-L6-v2/modules.json +14 -0
  9. model/paraphrase-MiniLM-L6-v2/onnx/model.onnx +3 -0
  10. model/paraphrase-MiniLM-L6-v2/onnx/model_O1.onnx +3 -0
  11. model/paraphrase-MiniLM-L6-v2/onnx/model_O2.onnx +3 -0
  12. model/paraphrase-MiniLM-L6-v2/onnx/model_O3.onnx +3 -0
  13. model/paraphrase-MiniLM-L6-v2/onnx/model_O4.onnx +3 -0
  14. model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_arm64.onnx +3 -0
  15. model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512.onnx +3 -0
  16. model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512_vnni.onnx +3 -0
  17. model/paraphrase-MiniLM-L6-v2/onnx/model_quint8_avx2.onnx +3 -0
  18. model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.bin +3 -0
  19. model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.xml +0 -0
  20. model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.bin +3 -0
  21. model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.xml +0 -0
  22. model/paraphrase-MiniLM-L6-v2/pytorch_model.bin +3 -0
  23. model/paraphrase-MiniLM-L6-v2/sentence_bert_config.json +4 -0
  24. model/paraphrase-MiniLM-L6-v2/special_tokens_map.json +1 -0
  25. model/paraphrase-MiniLM-L6-v2/tf_model.h5 +3 -0
  26. model/paraphrase-MiniLM-L6-v2/tokenizer.json +0 -0
  27. model/paraphrase-MiniLM-L6-v2/tokenizer_config.json +1 -0
  28. model/paraphrase-MiniLM-L6-v2/vocab.txt +0 -0
  29. requirements.txt +22 -0
  30. src/DeepThink/__pycache__/__init__.cpython-311.pyc +0 -0
  31. src/DeepThink/__pycache__/engine.cpython-311.pyc +0 -0
  32. src/DeepThink/engine.py +285 -0
  33. src/DeepThink/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  34. src/DeepThink/modules/__pycache__/article_generation.cpython-310.pyc +0 -0
  35. src/DeepThink/modules/__pycache__/article_generation.cpython-311.pyc +0 -0
  36. src/DeepThink/modules/__pycache__/article_polish.cpython-310.pyc +0 -0
  37. src/DeepThink/modules/__pycache__/article_polish.cpython-311.pyc +0 -0
  38. src/DeepThink/modules/__pycache__/interface.cpython-310.pyc +0 -0
  39. src/DeepThink/modules/__pycache__/interface.cpython-311.pyc +0 -0
  40. src/DeepThink/modules/__pycache__/mindmap.cpython-310.pyc +0 -0
  41. src/DeepThink/modules/__pycache__/mindmap.cpython-311.pyc +0 -0
  42. src/DeepThink/modules/__pycache__/outline_generation.cpython-310.pyc +0 -0
  43. src/DeepThink/modules/__pycache__/outline_generation.cpython-311.pyc +0 -0
  44. src/DeepThink/modules/__pycache__/retriever.cpython-311.pyc +0 -0
  45. src/DeepThink/modules/__pycache__/storm_dataclass.cpython-310.pyc +0 -0
  46. src/DeepThink/modules/__pycache__/storm_dataclass.cpython-311.pyc +0 -0
  47. src/DeepThink/modules/__pycache__/utils.cpython-310.pyc +0 -0
  48. src/DeepThink/modules/__pycache__/utils.cpython-311.pyc +0 -0
  49. src/DeepThink/modules/article_generation.py +523 -0
  50. src/DeepThink/modules/article_polish.py +417 -0
app.py ADDED
@@ -0,0 +1,161 @@
1
+ import time
2
+ import json
3
+
4
+ import pandas as pd
5
+ import streamlit as st
6
+ from dotenv import load_dotenv
7
+ from http import HTTPStatus
8
+
9
+ from src.lm import QwenModel
10
+ from src.rm import GoogleSearchAli_new
11
+ import sys
12
+ sys.path.append('./src/DeepThink/modules')
13
+ from mindmap import MindMap
14
+ from storm_dataclass import Article
15
+
16
+ from article_generation import ArticleGenerationModule
17
+ from article_polish import ArticlePolishingModule
18
+ from outline_generation import OutlineGenerationModule
19
+
20
+ import os
21
+
22
+ import subprocess
23
+ bash_command = "pip install --upgrade pip"
24
+ process = subprocess.Popen(bash_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
25
+
26
+ # Load environment variables and API keys
27
+ # load_dotenv()
28
+
29
+ openai_kwargs = {
30
+ 'api_key': os.getenv("OPENAI_API_KEY"),
31
+ 'api_provider': os.getenv('OPENAI_API_TYPE'),
32
+ 'temperature': 1.0,
33
+ 'top_p': 0.9,
34
+ 'api_base': os.getenv('AZURE_API_BASE'),
35
+ 'api_version': os.getenv('AZURE_API_VERSION'),
36
+ }
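The dictionary above is populated entirely from environment variables (the `load_dotenv()` call is commented out), so these keys must already be set in the host environment. A minimal sketch with placeholder values, using only the names the `os.getenv` calls above read:

```python
import os

# Placeholder values only; the key names mirror the os.getenv calls above.
os.environ.setdefault("OPENAI_API_KEY", "<your-api-key>")
os.environ.setdefault("OPENAI_API_TYPE", "<api-provider>")    # consumed as 'api_provider'
os.environ.setdefault("AZURE_API_BASE", "<api-base-url>")
os.environ.setdefault("AZURE_API_VERSION", "<api-version>")
```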
37
+
38
+
39
+ lm = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
40
+ lm4outline = QwenModel(model='qwen-plus', max_tokens=1000, **openai_kwargs)
41
+ lm4gensection = QwenModel(model='qwen-plus', max_tokens=2000, **openai_kwargs)
42
+ lm4polish = QwenModel(model='qwen-plus', max_tokens=4000, **openai_kwargs)
43
+
44
+
45
+ rm = GoogleSearchAli_new(k=5)
46
+
47
+
48
+
49
+
50
+ st.set_page_config(page_title='OmniThink', layout="wide")
51
+
52
+
53
+
54
+
55
+
56
+
57
+ st.warning("Announcement: Due to the recent high volume of visitors, search API quota limitations, you may encounter an error: "
58
+ "'ValueError: Expected 2D array, got 1D array instead: array=[]. "
59
+ "Reshape your data either using array.reshape(-1, 1) if your data has a single feature "
60
+ "or array.reshape(1, -1) if it contains a single sample.' "
61
+ "If this error occurs, please try again in a few hours.")
62
+
63
+ st.title('🤔 OmniThink')
64
+ st.markdown('_OmniThink is a tool that helps you think deeply about a topic, generate an outline, and write an article._')
65
+
66
+
67
+
68
+
69
+
70
+ # Sidebar for configuration and examples
71
+ with st.sidebar:
72
+ st.header('Configuration')
73
+ MAX_ROUNDS = st.number_input('Retrieval Depth', min_value=0, max_value=10, value=2, step=1)
74
+ models = ['Qwen-Plus', 'Coming Soon']
75
+ selected_example = st.selectbox('LLM:', models)
76
+ searchers = ['GoogleSearch', 'Coming Soon']
77
+ selected_example = st.selectbox('Search engine', searchers)
78
+
79
+ n_max_doc = st.number_input('Number of web pages retrieved in a single search', min_value=1, max_value=50, value=10, step=5)
80
+ st.header('Examples')
81
+ examples = ['AlphaFold', '2024 Hualien City Earthquake', 'Taylor Swift', 'Yoon Seok-youl']
82
+ selected_example = st.selectbox('case', examples)
83
+ status_placeholder = st.empty()
84
+
85
+ mind_map = MindMap(
86
+ retriever=rm,
87
+ gen_concept_lm = lm4outline,
88
+ gen_concept_lm2 = lm4outline,
89
+ search_top_k = n_max_doc,
90
+ depth= MAX_ROUNDS
91
+ )
92
+
93
+ def Think(input_topic):
94
+
95
+ generator = mind_map.build_map(input_topic)
96
+
97
+ st.markdown(f'Performing an in-depth search on the content related to {input_topic}...')
98
+
99
+ for idx, layer in enumerate(generator):
100
+ print(layer)
101
+ print('layer!!!')
102
+ st.markdown(f'Deep Thinking Retrieval at Level {idx + 1}...')
103
+ status_placeholder.text(f"Currently conducting the {idx + 1}th level deep thinking retrieval, estimated to take {(idx+1)*3} minutes.")
104
+ for node in layer:
105
+ category = node.category
106
+
107
+ print(f'category: {category}')
108
+ with st.expander(f'{category}'):
109
+ st.markdown(f'### The concept of {node.category}')
110
+ print(node.concept)
111
+ for concept in node.concept:
112
+ st.markdown(f'* {concept}')
113
+ st.markdown(f'### The web of {node.category}')
114
+ for idx, info in enumerate(node.info):
115
+ st.markdown(f'{idx + 1}. {info["title"]} \n {info["snippets"]}')
116
+
117
+ st.markdown(f'Constructing an index table for the {mind_map.get_web_number()} retrieved web pages...')
118
+ mind_map.prepare_table_for_retrieval()
119
+ return '__finish__', '__finish__'
120
+
121
+ def GenOutline(input_topic):
122
+ status_placeholder.text("The outline writing is in progress and is expected to take 1 minute.")
123
+ ogm = OutlineGenerationModule(lm)
124
+ outline = ogm.generate_outline(topic= input_topic, mindmap = mind_map)
125
+
126
+ return outline
127
+
128
+ def GenArticle(input_topic, outline):
129
+ status_placeholder.text("The article writing is in progress and is expected to take 3 minutes.")
130
+
131
+ article_with_outline = Article.from_outline_str(topic=input_topic, outline_str=outline)
132
+ ag = ArticleGenerationModule(retriever = rm, article_gen_lm = lm, retrieve_top_k = 3, max_thread_num = 10)
133
+ article = ag.generate_article(topic = input_topic, mindmap = mind_map, article_with_outline = article_with_outline)
134
+ ap = ArticlePolishingModule(article_gen_lm = lm, article_polish_lm = lm)
135
+ article = ap.polish_article(topic = input_topic, draft_article = article)
136
+ return article.to_string()
137
+
138
+ with st.form('my_form'):
139
+ topic = st.text_input('Please enter the topic you are interested in.', value=selected_example, placeholder='Please enter the topic you are interested in.')
140
+ submit_button = st.form_submit_button('Generate!')
141
+
142
+ if submit_button:
143
+ if topic:
144
+ st.markdown('### Thought process')
145
+ summary, news_timeline = Think(topic)
146
+ st.session_state.summary = summary
147
+ st.session_state.news_timeline = news_timeline
148
+
149
+ st.markdown('### Outline generation')
150
+ with st.expander("Outline generation", expanded=True):
151
+ outline = GenOutline(topic)
152
+ st.text(outline)
153
+
154
+ st.markdown('### Article generation')
155
+ with st.expander("article generation", expanded=True):
156
+ article = GenArticle(topic, outline)
157
+ st.markdown(article)
158
+ else:
159
+ st.error('Please enter a topic.')
160
+
161
+
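`Think`, `GenOutline`, and `GenArticle` above chain the full pipeline behind the Streamlit form. A minimal headless sketch of the same flow, assuming the imports and the `lm*`/`rm` objects configured at the top of app.py:

```python
# Headless sketch of the app.py pipeline (no Streamlit UI); assumes the
# QwenModel / GoogleSearchAli_new setup from the top of app.py.
topic = "AlphaFold"

mind_map = MindMap(retriever=rm, gen_concept_lm=lm4outline, gen_concept_lm2=lm4outline,
                   search_top_k=10, depth=2)
for layer in mind_map.build_map(topic):      # deep-thinking retrieval, level by level
    pass
mind_map.prepare_table_for_retrieval()       # index the retrieved web pages

outline = OutlineGenerationModule(lm).generate_outline(topic=topic, mindmap=mind_map)

article_with_outline = Article.from_outline_str(topic=topic, outline_str=outline)
ag = ArticleGenerationModule(retriever=rm, article_gen_lm=lm, retrieve_top_k=3, max_thread_num=10)
draft = ag.generate_article(topic=topic, mindmap=mind_map, article_with_outline=article_with_outline)

ap = ArticlePolishingModule(article_gen_lm=lm, article_polish_lm=lm)
print(ap.polish_article(topic=topic, draft_article=draft).to_string())
```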
model/paraphrase-MiniLM-L6-v2/.gitattributes ADDED
@@ -0,0 +1,17 @@
1
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
2
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.h5 filter=lfs diff=lfs merge=lfs -text
5
+ *.tflite filter=lfs diff=lfs merge=lfs -text
6
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.ot filter=lfs diff=lfs merge=lfs -text
8
+ *.onnx filter=lfs diff=lfs merge=lfs -text
9
+ *.arrow filter=lfs diff=lfs merge=lfs -text
10
+ *.ftz filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.pt filter=lfs diff=lfs merge=lfs -text
16
+ *.pth filter=lfs diff=lfs merge=lfs -text
17
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
model/paraphrase-MiniLM-L6-v2/1_Pooling/config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "word_embedding_dimension": 384,
3
+ "pooling_mode_cls_token": false,
4
+ "pooling_mode_mean_tokens": true,
5
+ "pooling_mode_max_tokens": false,
6
+ "pooling_mode_mean_sqrt_len_tokens": false
7
+ }
model/paraphrase-MiniLM-L6-v2/README.md ADDED
@@ -0,0 +1,108 @@
1
+ ---
2
+ license: apache-2.0
3
+ library_name: sentence-transformers
4
+ tags:
5
+ - sentence-transformers
6
+ - feature-extraction
7
+ - sentence-similarity
8
+ - transformers
9
+ pipeline_tag: sentence-similarity
10
+ ---
11
+
12
+ # sentence-transformers/paraphrase-MiniLM-L6-v2
13
+
14
+ This is a [sentence-transformers](https://www.SBERT.net) model: it maps sentences & paragraphs to a 384-dimensional dense vector space and can be used for tasks like clustering or semantic search.
15
+
16
+
17
+
18
+ ## Usage (Sentence-Transformers)
19
+
20
+ Using this model becomes easy when you have [sentence-transformers](https://www.SBERT.net) installed:
21
+
22
+ ```
23
+ pip install -U sentence-transformers
24
+ ```
25
+
26
+ Then you can use the model like this:
27
+
28
+ ```python
29
+ from sentence_transformers import SentenceTransformer
30
+ sentences = ["This is an example sentence", "Each sentence is converted"]
31
+
32
+ model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
33
+ embeddings = model.encode(sentences)
34
+ print(embeddings)
35
+ ```
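As a small illustration of the semantic-search use case mentioned above, the embeddings can be compared with plain cosine similarity (the sentences below are illustrative):

```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')

corpus = ["The cat sits on the mat", "Paris is the capital of France"]
query = "Where is the French capital?"

# L2-normalize so the dot product equals cosine similarity.
corpus_emb = model.encode(corpus)
query_emb = model.encode([query])[0]
corpus_emb = corpus_emb / np.linalg.norm(corpus_emb, axis=1, keepdims=True)
query_emb = query_emb / np.linalg.norm(query_emb)

scores = corpus_emb @ query_emb
print(corpus[int(np.argmax(scores))])  # closest corpus sentence to the query
```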
36
+
37
+
38
+
39
+ ## Usage (HuggingFace Transformers)
40
+ Without [sentence-transformers](https://www.SBERT.net), you can use the model like this: first, pass your input through the transformer model, then apply the right pooling operation on top of the contextualized word embeddings.
41
+
42
+ ```python
43
+ from transformers import AutoTokenizer, AutoModel
44
+ import torch
45
+
46
+
47
+ #Mean Pooling - Take attention mask into account for correct averaging
48
+ def mean_pooling(model_output, attention_mask):
49
+ token_embeddings = model_output[0] #First element of model_output contains all token embeddings
50
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
51
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
52
+
53
+
54
+ # Sentences we want sentence embeddings for
55
+ sentences = ['This is an example sentence', 'Each sentence is converted']
56
+
57
+ # Load model from HuggingFace Hub
58
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-MiniLM-L6-v2')
59
+ model = AutoModel.from_pretrained('sentence-transformers/paraphrase-MiniLM-L6-v2')
60
+
61
+ # Tokenize sentences
62
+ encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
63
+
64
+ # Compute token embeddings
65
+ with torch.no_grad():
66
+ model_output = model(**encoded_input)
67
+
68
+ # Perform pooling. In this case, max pooling.
69
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
70
+
71
+ print("Sentence embeddings:")
72
+ print(sentence_embeddings)
73
+ ```
74
+
75
+
76
+
77
+ ## Evaluation Results
78
+
79
+
80
+
81
+ For an automated evaluation of this model, see the *Sentence Embeddings Benchmark*: [https://seb.sbert.net](https://seb.sbert.net?model_name=sentence-transformers/paraphrase-MiniLM-L6-v2)
82
+
83
+
84
+
85
+ ## Full Model Architecture
86
+ ```
87
+ SentenceTransformer(
88
+ (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel
89
+ (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
90
+ )
91
+ ```
92
+
93
+ ## Citing & Authors
94
+
95
+ This model was trained by [sentence-transformers](https://www.sbert.net/).
96
+
97
+ If you find this model helpful, feel free to cite our publication [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084):
98
+ ```bibtex
99
+ @inproceedings{reimers-2019-sentence-bert,
100
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
101
+ author = "Reimers, Nils and Gurevych, Iryna",
102
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
103
+ month = "11",
104
+ year = "2019",
105
+ publisher = "Association for Computational Linguistics",
106
+ url = "http://arxiv.org/abs/1908.10084",
107
+ }
108
+ ```
model/paraphrase-MiniLM-L6-v2/config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "_name_or_path": "old_models/paraphrase-MiniLM-L6-v2/0_Transformer",
3
+ "architectures": [
4
+ "BertModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 384,
11
+ "initializer_range": 0.02,
12
+ "intermediate_size": 1536,
13
+ "layer_norm_eps": 1e-12,
14
+ "max_position_embeddings": 512,
15
+ "model_type": "bert",
16
+ "num_attention_heads": 12,
17
+ "num_hidden_layers": 6,
18
+ "pad_token_id": 0,
19
+ "position_embedding_type": "absolute",
20
+ "transformers_version": "4.7.0",
21
+ "type_vocab_size": 2,
22
+ "use_cache": true,
23
+ "vocab_size": 30522
24
+ }
model/paraphrase-MiniLM-L6-v2/config_sentence_transformers.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.0.0",
4
+ "transformers": "4.7.0",
5
+ "pytorch": "1.9.0+cu102"
6
+ }
7
+ }
model/paraphrase-MiniLM-L6-v2/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ce4480dc3b2f8edeee50c43765c72768e79fc0113d3f73773dded4887cca298
3
+ size 90868373
model/paraphrase-MiniLM-L6-v2/modules.json ADDED
@@ -0,0 +1,14 @@
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.models.Transformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_Pooling",
12
+ "type": "sentence_transformers.models.Pooling"
13
+ }
14
+ ]
model/paraphrase-MiniLM-L6-v2/onnx/model.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:441a5dc61ff3b889892feeb7aa0400518cc9908603209c45861ba3abef3006bc
3
+ size 90405214
model/paraphrase-MiniLM-L6-v2/onnx/model_O1.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4b50dcc09ca71accf34b7c7d843ad157499bb2e2c7f7a9b9bc1bbb720147ce6
3
+ size 90360328
model/paraphrase-MiniLM-L6-v2/onnx/model_O2.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d99151ccdfb700fb278610e060f3debf615b59da275ba5784385d49c8b8e8e9c
3
+ size 90326566
model/paraphrase-MiniLM-L6-v2/onnx/model_O3.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:053f244528154eef2db50c676536cdf0ab1e9cba20693ad8c9d83cb592126072
3
+ size 90326497
model/paraphrase-MiniLM-L6-v2/onnx/model_O4.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:891a315753191d6dfb32c193615187456e33ee52d6425e0ad8dac2d086350f81
3
+ size 45212349
model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_arm64.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccc4bb68331e8410226d021ed709d6f2db3b0b25a43504828fa1d54fc6f7b3b3
3
+ size 23026053
model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccc4bb68331e8410226d021ed709d6f2db3b0b25a43504828fa1d54fc6f7b3b3
3
+ size 23026053
model/paraphrase-MiniLM-L6-v2/onnx/model_qint8_avx512_vnni.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ccc4bb68331e8410226d021ed709d6f2db3b0b25a43504828fa1d54fc6f7b3b3
3
+ size 23026053
model/paraphrase-MiniLM-L6-v2/onnx/model_quint8_avx2.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d2e9f4180455601b4ebb64602ba667f551c87f16e791550479346851b6e4787
3
+ size 23046789
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6005063ac5c88df685065089e887719f43956959a2080c7b9467bc17924645d
3
+ size 90265744
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model.xml ADDED
The diff for this file is too large to render. See raw diff
 
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f036c75118e1df8040b4be3d5b7589ae1f1bb0c1f0f5d666b9bd317a2c8014d5
3
+ size 22933664
model/paraphrase-MiniLM-L6-v2/openvino/openvino_model_qint8_quantized.xml ADDED
The diff for this file is too large to render. See raw diff
 
model/paraphrase-MiniLM-L6-v2/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d716de760acbdc09e79a11e718c5606e0812b6aeb76c6664cba876d174e3ecd
3
+ size 90895153
model/paraphrase-MiniLM-L6-v2/sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "max_seq_length": 128,
3
+ "do_lower_case": false
4
+ }
model/paraphrase-MiniLM-L6-v2/special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
model/paraphrase-MiniLM-L6-v2/tf_model.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee09134d6f68fddf22606d3bc296855df94e527bbfe1555151e4b9613564a218
3
+ size 91005696
model/paraphrase-MiniLM-L6-v2/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/paraphrase-MiniLM-L6-v2/tokenizer_config.json ADDED
@@ -0,0 +1 @@
1
+ {"do_lower_case": true, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "name_or_path": "nreimers/MiniLM-L6-H384-uncased", "do_basic_tokenize": true, "never_split": null, "model_max_length": 512}
model/paraphrase-MiniLM-L6-v2/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,22 @@
1
+ dspy_ai==2.4.9
2
+ wikipedia==1.4.0
3
+ sentence-transformers
4
+ toml
5
+ langchain-text-splitters
6
+ trafilatura
7
+ langchain-huggingface
8
+ qdrant-client
9
+ langchain-qdrant
10
+ numpy==1.26.4
11
+ dashscope
12
+ beautifulsoup4
13
+ streamlit==1.37.1
14
+ python-dotenv
15
+ streamlit-vis-timeline==0.3.0
16
+ tilse
17
+ jsonlines
18
+ rank-bm25
19
+ transformers
20
+ litellm
21
+ lxml
22
+ lxml_html_clean
src/DeepThink/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (223 Bytes). View file
 
src/DeepThink/__pycache__/engine.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
 
src/DeepThink/engine.py ADDED
@@ -0,0 +1,285 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from dataclasses import dataclass, field
5
+ from typing import Union, Literal, Optional
6
+
7
+ import dspy
8
+ from interface import Engine, LMConfigs
9
+ from lm import OpenAIModel
10
+
11
+ class LMConfigs():
12
+ """Configurations for LLM used in different parts of STORM.
13
+
14
+ Given that different parts in STORM framework have different complexity, we use different LLM configurations
15
+ to achieve a balance between quality and efficiency. If no specific configuration is provided, we use the default
16
+ setup in the paper.
17
+ """
18
+
19
+ def __init__(self):
20
+ self.conv_simulator_lm = None # LLM used in conversation simulator except for question asking.
21
+ self.question_asker_lm = None # LLM used in question asking.
22
+ self.outline_gen_lm = None # LLM used in outline generation.
23
+ self.article_gen_lm = None # LLM used in article generation.
24
+ self.article_polish_lm = None # LLM used in article polishing.
25
+
26
+ def set_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
27
+ self.lm = model
28
+
29
+ def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
30
+ self.conv_simulator_lm = model
31
+
32
+ def set_question_asker_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
33
+ self.question_asker_lm = model
34
+
35
+ def set_outline_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
36
+ self.outline_gen_lm = model
37
+
38
+ def set_article_gen_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
39
+ self.article_gen_lm = model
40
+
41
+ def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
42
+ self.article_polish_lm = model
43
+
44
+
45
+ @dataclass
46
+ class RunnerArguments:
47
+ """Arguments for controlling the STORM Wiki pipeline."""
48
+ output_dir: str = field(
49
+ metadata={"help": "Output directory for the results."},
50
+ )
51
+ max_conv_turn: int = field(
52
+ default=3,
53
+ metadata={"help": "Maximum number of questions in conversational question asking."},
54
+ )
55
+ max_perspective: int = field(
56
+ default=3,
57
+ metadata={"help": "Maximum number of perspectives to consider in perspective-guided question asking."},
58
+ )
59
+ max_search_queries_per_turn: int = field(
60
+ default=3,
61
+ metadata={"help": "Maximum number of search queries to consider in each turn."},
62
+ )
63
+ disable_perspective: bool = field(
64
+ default=False,
65
+ metadata={"help": "If True, disable perspective-guided question asking."},
66
+ )
67
+ search_top_k: int = field(
68
+ default=3,
69
+ metadata={"help": "Top k search results to consider for each search query."},
70
+ )
71
+ retrieve_top_k: int = field(
72
+ default=3,
73
+ metadata={"help": "Top k collected references for each section title."},
74
+ )
75
+ max_thread_num: int = field(
76
+ default=10,
77
+ metadata={"help": "Maximum number of threads to use. "
78
+ "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."},
79
+ )
80
+
81
+ class Runner():
82
+ """STORM Wiki pipeline runner."""
83
+
84
+ def __init__(self,
85
+ args: RunnerArguments,
86
+ lm_configs: LMConfigs,
87
+ rm):
88
+ super().__init__(lm_configs=lm_configs)
89
+ self.args = args
90
+ self.lm_configs = lm_configs
91
+
92
+ self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k)
93
+ storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm)
94
+ self.storm_knowledge_curation_module = StormKnowledgeCurationModule(
95
+ retriever=self.retriever,
96
+ persona_generator=storm_persona_generator,
97
+ conv_simulator_lm=self.lm_configs.conv_simulator_lm,
98
+ question_asker_lm=self.lm_configs.question_asker_lm,
99
+ max_search_queries_per_turn=self.args.max_search_queries_per_turn,
100
+ search_top_k=self.args.search_top_k,
101
+ max_conv_turn=self.args.max_conv_turn,
102
+ max_thread_num=self.args.max_thread_num
103
+ )
104
+
105
+ self.storm_outline_generation_module = StormOutlineGenerationModule(
106
+ outline_gen_lm=self.lm_configs.outline_gen_lm
107
+ )
108
+
109
+ self.storm_article_generation = StormArticleGenerationModule(
110
+ article_gen_lm=self.lm_configs.article_gen_lm,
111
+ retrieve_top_k=self.args.retrieve_top_k,
112
+ max_thread_num=self.args.max_thread_num,
113
+ retriever =self.retriever
114
+ )
115
+
116
+ self.storm_article_polishing_module = StormArticlePolishingModule(
117
+ article_gen_lm=self.lm_configs.article_gen_lm,
118
+ article_polish_lm=self.lm_configs.article_polish_lm
119
+ )
120
+
121
+ self.lm_configs.init_check()
122
+ self.apply_decorators()
123
+
124
+ def run_knowledge_curation_module(self,
125
+ ground_truth_url: str = "None",
126
+ ) -> StormInformationTable:
127
+ # First entry point: the topic here is still the original topic. information_table holds both all of the conversation dialogues and a dict mapping each url to its snippets.
128
+ information_table, conversation_log = self.storm_knowledge_curation_module.research(
129
+ topic=self.topic,
130
+ ground_truth_url=ground_truth_url,
131
+ callback_handler=callback_handler,
132
+ max_perspective=self.args.max_perspective,
133
+ disable_perspective=False,
134
+ return_conversation_log=True
135
+ )
136
+
137
+ FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json'))
138
+ information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json'))
139
+
140
+ return information_table
141
+
142
+ def run_outline_generation_module(self,
143
+ information_table: StormInformationTable,
144
+ callback_handler: BaseCallbackHandler = None) -> StormArticle:
145
+
146
+ outline, draft_outline = self.storm_outline_generation_module.generate_outline(
147
+ topic=self.topic,
148
+ information_table=information_table,
149
+ return_draft_outline=True,
150
+ callback_handler=callback_handler
151
+ )
152
+ outline.dump_outline_to_file(os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
153
+ draft_outline.dump_outline_to_file(os.path.join(self.article_output_dir, "direct_gen_outline.txt"))
154
+ return outline
155
+
156
+ def run_article_generation_module(self,
157
+ outline: StormArticle,
158
+ information_table=StormInformationTable,
159
+ callback_handler: BaseCallbackHandler = None) -> StormArticle:
160
+
161
+ draft_article = self.storm_article_generation.generate_article(
162
+ topic=self.topic,
163
+ information_table=information_table,
164
+ article_with_outline=outline,
165
+ callback_handler=callback_handler
166
+ )
167
+ draft_article.dump_article_as_plain_text(os.path.join(self.article_output_dir, 'storm_gen_article.txt'))
168
+ draft_article.dump_reference_to_file(os.path.join(self.article_output_dir, 'url_to_info.json'))
169
+ return draft_article
170
+
171
+ def run_article_polishing_module(self,
172
+ draft_article: StormArticle,
173
+ remove_duplicate: bool = False) -> StormArticle:
174
+
175
+ polished_article = self.storm_article_polishing_module.polish_article(
176
+ topic=self.topic,
177
+ draft_article=draft_article,
178
+ remove_duplicate=remove_duplicate
179
+ )
180
+ FileIOHelper.write_str(polished_article.to_string(),
181
+ os.path.join(self.article_output_dir, 'storm_gen_article_polished.txt'))
182
+ return polished_article
183
+
184
+ def post_run(self):
185
+ """
186
+ Post-run operations, including:
187
+ 1. Dumping the run configuration.
188
+ 2. Dumping the LLM call history.
189
+ """
190
+ config_log = self.lm_configs.log()
191
+ FileIOHelper.dump_json(config_log, os.path.join(self.article_output_dir, 'run_config.json'))
192
+
193
+ llm_call_history = self.lm_configs.collect_and_reset_lm_history()
194
+ with open(os.path.join(self.article_output_dir, 'llm_call_history.jsonl'), 'w') as f:
195
+ for call in llm_call_history:
196
+ if 'kwargs' in call:
197
+ call.pop('kwargs') # All kwargs are dumped together to run_config.json.
198
+ f.write(json.dumps(call) + '\n')
199
+
200
+ def _load_information_table_from_local_fs(self, information_table_local_path):
201
+ assert os.path.exists(information_table_local_path), makeStringRed(f"{information_table_local_path} does not exist. Please set the --do-research argument to prepare the conversation_log.json for this topic.")
202
+ return StormInformationTable.from_conversation_log_file(information_table_local_path)
203
+
204
+ def _load_outline_from_local_fs(self, topic, outline_local_path):
205
+ assert os.path.exists(outline_local_path), makeStringRed(f"{outline_local_path} does not exist. Please set the --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.")
206
+ return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)
207
+
208
+ def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path):
209
+ assert os.path.exists(draft_article_path), makeStringRed(f"{draft_article_path} does not exist. Please set the --do-generate-article argument to prepare the storm_gen_article.txt for this topic.")
210
+ assert os.path.exists(url_to_info_path), makeStringRed(f"{url_to_info_path} does not exist. Please set the --do-generate-article argument to prepare the url_to_info.json for this topic.")
211
+ article_text = FileIOHelper.load_str(draft_article_path)
212
+ references = FileIOHelper.load_json(url_to_info_path)
213
+ return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references)
214
+
215
+ def run(self,
216
+ topic: str,
217
+ ground_truth_url: str = '',
218
+ do_research: bool = True,
219
+ do_generate_outline: bool = True,
220
+ do_generate_article: bool = True,
221
+ do_polish_article: bool = True,
222
+ remove_duplicate: bool = False,
223
+ callback_handler: BaseCallbackHandler = BaseCallbackHandler()):
224
+ """
225
+ Run the STORM pipeline.
226
+
227
+ Args:
228
+ topic: The topic to research.
229
+ ground_truth_url: A ground truth URL including a curated article about the topic. The URL will be excluded.
230
+ do_research: If True, research the topic through information-seeking conversation;
231
+ if False, expect conversation_log.json and raw_search_results.json to exist in the output directory.
232
+ do_generate_outline: If True, generate an outline for the topic;
233
+ if False, expect storm_gen_outline.txt to exist in the output directory.
234
+ do_generate_article: If True, generate a curated article for the topic;
235
+ if False, expect storm_gen_article.txt to exist in the output directory.
236
+ do_polish_article: If True, polish the article by adding a summarization section and (optionally) removing
237
+ duplicated content.
238
+ remove_duplicate: If True, remove duplicated content.
239
+ callback_handler: A callback handler to handle the intermediate results.
240
+ """
241
+ assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
242
+ makeStringRed("No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article")
243
+
244
+ self.topic = topic
245
+ self.article_dir_name = topic.replace(' ', '_').replace('/', '_')
246
+ self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name)
247
+ os.makedirs(self.article_output_dir, exist_ok=True)
248
+
249
+ # Research module: first generate some links to obtain URLs, then read the URLs to create several different personas, and then hold conversations with those personas to gather useful information.
250
+ information_table: StormInformationTable = None
251
+ if do_research:
252
+ information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url,
253
+ callback_handler=callback_handler)
254
+
255
+ # Outline generation module: generates the outline; optionally grounding it in the earlier conversations produces a more detailed outline.
256
+ outline: StormArticle = None
257
+ if do_generate_outline:
258
+ # load information table if it's not initialized
259
+ if information_table is None:
260
+ information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, '.json'))
261
+ outline = self.run_outline_generation_module(information_table=information_table,
262
+ callback_handler=callback_handler)
263
+
264
+
265
+ # article generation module
266
+ draft_article: StormArticle = None
267
+ if do_generate_article:
268
+ if information_table is None:
269
+ information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, 'conversation_log.json'))
270
+ if outline is None:
271
+ outline = self._load_outline_from_local_fs(topic=topic, outline_local_path=os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
272
+
273
+ draft_article = self.run_article_generation_module(outline=outline,
274
+ information_table=information_table,
275
+ callback_handler=callback_handler)
276
+
277
+ # article polishing module
278
+ if do_polish_article:
279
+ if draft_article is None:
280
+ draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt')
281
+ url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json')
282
+ draft_article = self._load_draft_article_from_local_fs(topic=topic, draft_article_path=draft_article_path, url_to_info_path=url_to_info_path)
283
+ self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate)
284
+
285
+ post_polish(self.article_output_dir)
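A sketch of how `Runner` and the flags documented in `run()` are meant to be driven, assuming the STORM components referenced above (retriever, curation/generation modules, LM instances) are importable and configured:

```python
# Illustrative driver only; `lm` and `rm` are assumed to be configured elsewhere.
args = RunnerArguments(output_dir='./results', max_conv_turn=3, search_top_k=3)

lm_configs = LMConfigs()
lm_configs.set_conv_simulator_lm(lm)
lm_configs.set_question_asker_lm(lm)
lm_configs.set_outline_gen_lm(lm)
lm_configs.set_article_gen_lm(lm)        # larger token budgets can go to later stages
lm_configs.set_article_polish_lm(lm)

runner = Runner(args, lm_configs, rm)
runner.run(topic='AlphaFold',
           do_research=True,
           do_generate_outline=True,
           do_generate_article=True,
           do_polish_article=True)
runner.post_run()
```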
src/DeepThink/modules/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (242 Bytes). View file
 
src/DeepThink/modules/__pycache__/article_generation.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
src/DeepThink/modules/__pycache__/article_generation.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
src/DeepThink/modules/__pycache__/article_polish.cpython-310.pyc ADDED
Binary file (3.4 kB). View file
 
src/DeepThink/modules/__pycache__/article_polish.cpython-311.pyc ADDED
Binary file (5.13 kB). View file
 
src/DeepThink/modules/__pycache__/interface.cpython-310.pyc ADDED
Binary file (17.2 kB). View file
 
src/DeepThink/modules/__pycache__/interface.cpython-311.pyc ADDED
Binary file (24.2 kB). View file
 
src/DeepThink/modules/__pycache__/mindmap.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
src/DeepThink/modules/__pycache__/mindmap.cpython-311.pyc ADDED
Binary file (25.4 kB). View file
 
src/DeepThink/modules/__pycache__/outline_generation.cpython-310.pyc ADDED
Binary file (4.84 kB). View file
 
src/DeepThink/modules/__pycache__/outline_generation.cpython-311.pyc ADDED
Binary file (7.47 kB). View file
 
src/DeepThink/modules/__pycache__/retriever.cpython-311.pyc ADDED
Binary file (3.81 kB). View file
 
src/DeepThink/modules/__pycache__/storm_dataclass.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
src/DeepThink/modules/__pycache__/storm_dataclass.cpython-311.pyc ADDED
Binary file (21.6 kB). View file
 
src/DeepThink/modules/__pycache__/utils.cpython-310.pyc ADDED
Binary file (12.5 kB). View file
 
src/DeepThink/modules/__pycache__/utils.cpython-311.pyc ADDED
Binary file (20.4 kB). View file
 
src/DeepThink/modules/article_generation.py ADDED
@@ -0,0 +1,523 @@
1
+ import concurrent.futures
2
+ import copy
3
+ import logging
4
+ from concurrent.futures import as_completed
5
+ from typing import List, Union
6
+ import random
7
+ import dspy
8
+ import sys
9
+
10
+
11
+ import concurrent.futures
12
+ import json
13
+ import os
14
+ import pickle
15
+ import re
16
+ import sys
17
+ from typing import List, Dict
18
+
19
+ import httpx
20
+ import toml
21
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
22
+ from trafilatura import extract
23
+
24
+
25
+ class ArticleTextProcessing:
26
+ @staticmethod
27
+ def limit_word_count_preserve_newline(input_string, max_word_count):
28
+ """
29
+ Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines.
30
+
31
+ The function truncates the input string at the nearest word that does not exceed the maximum word count,
32
+ ensuring that no partial lines are included in the output. Words are defined as text separated by spaces,
33
+ and lines are defined as text separated by newline characters.
34
+
35
+ Args:
36
+ input_string (str): The string to be truncated. This string may contain multiple lines.
37
+ max_word_count (int): The maximum number of words allowed in the truncated string.
38
+
39
+ Returns:
40
+ str: The truncated string with word count limited to `max_word_count`, preserving complete lines.
41
+ """
42
+
43
+ word_count = 0
44
+ limited_string = ''
45
+
46
+ for word in input_string.split('\n'):
47
+ line_words = word.split()
48
+ for lw in line_words:
49
+ if word_count < max_word_count:
50
+ limited_string += lw + ' '
51
+ word_count += 1
52
+ else:
53
+ break
54
+ if word_count >= max_word_count:
55
+ break
56
+ limited_string = limited_string.strip() + '\n'
57
+
58
+ return limited_string.strip()
59
+
60
+ @staticmethod
61
+ def remove_citations(s):
62
+ """
63
+ Removes all citations from a given string. Citations are assumed to be in the format
64
+ of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches
65
+ for all occurrences of such patterns and removes them, returning the cleaned string.
66
+
67
+ Args:
68
+ s (str): The string from which citations are to be removed.
69
+
70
+ Returns:
71
+ str: The string with all citation patterns removed.
72
+ """
73
+
74
+ return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s)
75
+
76
+ @staticmethod
77
+ def get_first_section_dict_and_list(s):
78
+ """
79
+ """
80
+ text = s
81
+ sections = text.strip().split('\n# ')
82
+ titles = []
83
+ content_dict = {}
84
+
85
+ for section in sections:
86
+ if section:
87
+ lines = section.split('\n', 1)
88
+ title = lines[0].strip()
89
+ content = lines[1].strip() if len(lines) > 1 else ""
90
+
91
+ titles.append(title)
92
+ content_dict[title] = content
93
+ return content_dict, titles
94
+
95
+ @staticmethod
96
+ def parse_citation_indices(s):
97
+ """
98
+ Extracts citation indexes from the provided content string and returns them as a list of integers.
99
+
100
+ Args:
101
+ content (str): The content string containing citations in the format [number].
102
+
103
+ Returns:
104
+ List[int]: A list of unique citation indexes extracted from the content, in the order they appear.
105
+ """
106
+ matches = re.findall(r'\[\d+\]', s)
107
+ return [int(index[1:-1]) for index in matches]
108
+
109
+ @staticmethod
110
+ def remove_uncompleted_sentences_with_citations(text):
111
+ """
112
+ Removes uncompleted sentences and standalone citations from the input text. Sentences are identified
113
+ by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., "[1]").
114
+ Grouped citations (e.g., "[1, 2]") are split into individual ones (e.g., "[1] [2]"). Only text up to
115
+ and including the last complete sentence and its citation is retained.
116
+
117
+ Args:
118
+ text (str): The input text from which uncompleted sentences and their citations are to be removed.
119
+
120
+ Returns:
121
+ str: The processed string with uncompleted sentences and standalone citations removed, leaving only
122
+ complete sentences and their associated citations if present.
123
+ """
124
+
125
+ # Convert citations like [1, 2, 3] to [1][2][3].
126
+ def replace_with_individual_brackets(match):
127
+ numbers = match.group(1).split(', ')
128
+ return ' '.join(f'[{n}]' for n in numbers)
129
+
130
+ # Deduplicate and sort individual groups of citations.
131
+ def deduplicate_group(match):
132
+ citations = match.group(0)
133
+ unique_citations = list(set(re.findall(r'\[\d+\]', citations)))
134
+ sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]')))
135
+ # Return the sorted unique citations as a string
136
+ return ''.join(sorted_citations)
137
+
138
+ text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text)
139
+ text = re.sub(r'(\[\d+\])+', deduplicate_group, text)
140
+
141
+ # Deprecated: Remove sentence without proper ending punctuation and citations.
142
+ # Split the text into sentences (including citations).
143
+ # sentences_with_trailing = re.findall(r'([^.!?]*[.!?].*?)(?=[^.!?]*[.!?]|$)', text)
144
+
145
+ # Filter sentences to ensure they end with a punctuation mark and properly formatted citations
146
+ # complete_sentences = []
147
+ # for sentence in sentences_with_trailing:
148
+ # # Check if the sentence ends with properly formatted citations
149
+ # if re.search(r'[.!?]( \[\d+\])*$|^[^.!?]*[.!?]$', sentence.strip()):
150
+ # complete_sentences.append(sentence.strip())
151
+
152
+ # combined_sentences = ' '.join(complete_sentences)
153
+
154
+ # Check for and append any complete citations that follow the last sentence
155
+ # trailing_citations = re.findall(r'(\[\d+\]) ', text[text.rfind(combined_sentences) + len(combined_sentences):])
156
+ # if trailing_citations:
157
+ # combined_sentences += ' '.join(trailing_citations)
158
+
159
+ # Regex pattern to match sentence endings, including optional citation markers.
160
+ eos_pattern = r'([.!?])\s*(\[\d+\])?\s*'
161
+ matches = list(re.finditer(eos_pattern, text))
162
+ if matches:
163
+ last_match = matches[-1]
164
+ text = text[:last_match.end()].strip()
165
+
166
+ return text
167
+
168
+ @staticmethod
169
+ def clean_up_citation(conv):
170
+ for turn in conv.dlg_history:
171
+ turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')]
172
+ turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')]
173
+ turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip()
174
+ try:
175
+ max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)])
176
+ except Exception as e:
177
+ max_ref_num = 0
178
+ if max_ref_num > len(turn.search_results):
179
+ for i in range(len(turn.search_results), max_ref_num + 1):
180
+ turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '')
181
+ turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
182
+ turn.agent_utterance)
183
+
184
+ return conv
185
+
186
+ @staticmethod
187
+ def clean_up_outline(outline, topic=""):
188
+ output_lines = []
189
+ current_level = 0 # To track the current section level
190
+
191
+ for line in outline.split('\n'):
192
+ stripped_line = line.strip()
193
+
194
+ if topic != "" and f"# {topic.lower()}" in stripped_line.lower():
195
+ output_lines = []
196
+
197
+ # Check if the line is a section header
198
+ if stripped_line.startswith('#') and stripped_line != '#':
199
+ current_level = stripped_line.count('#')
200
+ output_lines.append(stripped_line)
201
+ # Check if the line is a bullet point
202
+ # elif stripped_line.startswith('-'):
203
+ # subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip()
204
+ # output_lines.append(subsection_header)
205
+ # Preserve lines with @
206
+ elif stripped_line.startswith('@'):
207
+ output_lines.append(stripped_line)
208
+
209
+ outline = '\n'.join(output_lines)
210
+
211
+ # Remove references.
212
+ outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL)
213
+ outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL)
214
+ outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL)
215
+ outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL)
216
+ outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL)
217
+ outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL)
218
+ outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL)
219
+ outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL)
220
+ outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL)
221
+ outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL)
222
+ outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL)
223
+ outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL)
224
+
225
+ return outline
226
+
227
+
228
+ @staticmethod
229
+ def clean_up_section(text):
230
+ """Clean up a section:
231
+ 1. Remove uncompleted sentences (usually due to output token limitation).
232
+ 2. Deduplicate individual groups of citations.
233
+ 3. Remove unnecessary summary."""
234
+
235
+ paragraphs = text.split('\n')
236
+ output_paragraphs = []
237
+ summary_sec_flag = False
238
+ for p in paragraphs:
239
+ p = p.strip()
240
+ if len(p) == 0:
241
+ continue
242
+ if not p.startswith('#'):
243
+ p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)
244
+ if summary_sec_flag:
245
+ if p.startswith('#'):
246
+ summary_sec_flag = False
247
+ else:
248
+ continue
249
+ if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'):
250
+ continue
251
+ if "# Summary" in p or '# Conclusion' in p:
252
+ summary_sec_flag = True
253
+ continue
254
+ output_paragraphs.append(p)
255
+
256
+ return '\n\n'.join(output_paragraphs) # Join with '\n\n' for markdown format.
257
+
258
+ @staticmethod
259
+ def update_citation_index(s, citation_map):
260
+ """Update citation index in the string based on the citation map."""
261
+ for original_citation in citation_map:
262
+ s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__")
263
+ for original_citation, unify_citation in citation_map.items():
264
+ s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]")
265
+
266
+ return s
267
+
268
+ @staticmethod
269
+ def parse_article_into_dict(input_string):
270
+ """
271
+ Parses a structured text into a nested dictionary. The structure of the text
272
+ is defined by markdown-like headers (using '#' symbols) to denote sections
273
+ and subsections. Each section can contain content and further nested subsections.
274
+
275
+ The resulting dictionary captures the hierarchical structure of sections, where
276
+ each section is represented as a key (the section's title) mapping to a value
277
+ that is another dictionary. This dictionary contains two keys:
278
+ - 'content': content of the section
279
+ - 'subsections': a list of dictionaries, each representing a nested subsection
280
+ following the same structure.
281
+
282
+ Args:
283
+ input_string (str): A string containing the structured text to parse.
284
+
285
+ Returns:
286
+ A dictionary representing contains the section title as the key, and another dictionary
287
+ as the value, which includes the 'content' and 'subsections' keys as described above.
288
+ """
289
+ lines = input_string.split('\n')
290
+ lines = [line for line in lines if line.strip()]
291
+ root = {'content': '', 'subsections': {}}
292
+ current_path = [(root, -1)] # (current_dict, level)
293
+
294
+ for line in lines:
295
+ if line.startswith('#'):
296
+ level = line.count('#')
297
+ title = line.strip('# ').strip()
298
+ new_section = {'content': '', 'subsections': {}}
299
+
300
+ # Pop from stack until find the parent level
301
+ while current_path and current_path[-1][1] >= level:
302
+ current_path.pop()
303
+
304
+ # Append new section to the nearest upper level's subsections
305
+ current_path[-1][0]['subsections'][title] = new_section
306
+ current_path.append((new_section, level))
307
+ else:
308
+ current_path[-1][0]['content'] += line + '\n'
309
+
310
+ return root['subsections']
311
+
312
+
313
+ class FileIOHelper:
314
+ @staticmethod
315
+ def dump_json(obj, file_name, encoding="utf-8"):
316
+ with open(file_name, 'w', encoding=encoding) as fw:
317
+ json.dump(obj, fw, default=FileIOHelper.handle_non_serializable, ensure_ascii=False)
318
+
319
+ @staticmethod
320
+ def handle_non_serializable(obj):
321
+ return "non-serializable contents" # mark the non-serializable part
322
+
323
+ @staticmethod
324
+ def load_json(file_name, encoding="utf-8"):
325
+ with open(file_name, 'r', encoding=encoding) as fr:
326
+ return json.load(fr)
327
+
328
+ @staticmethod
329
+ def write_str(s, path):
330
+ with open(path, 'w') as f:
331
+ f.write(s)
332
+
333
+ @staticmethod
334
+ def load_str(path):
335
+ with open(path, 'r') as f:
336
+ return '\n'.join(f.readlines())
337
+
338
+ @staticmethod
339
+ def dump_pickle(obj, path):
340
+ with open(path, 'wb') as f:
341
+ pickle.dump(obj, f)
342
+
343
+ @staticmethod
344
+ def load_pickle(path):
345
+ with open(path, 'rb') as f:
346
+ return pickle.load(f)
347
+
348
+
349
+ class ArticleGenerationModule():
350
+ """
351
+ The interface for article generation stage. Given topic, collected information from
352
+ knowledge curation stage, generated outline from outline generation stage,
353
+ """
354
+
355
+ def __init__(self,
356
+ retriever,
357
+ article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel],
358
+ retrieve_top_k: int = 10,
359
+ max_thread_num: int = 10,
360
+ ):
361
+ super().__init__()
362
+ self.retrieve_top_k = retrieve_top_k
363
+ self.article_gen_lm = article_gen_lm
364
+ self.max_thread_num = max_thread_num
365
+ self.retriever = retriever
366
+ self.section_gen = ConvToSection(engine=self.article_gen_lm)
367
+
368
+ def generate_section(self, topic, section_name, mindmap, section_query, section_outline):
369
+ collected_info = mindmap.retrieve_information(queries=section_query,
370
+ search_top_k=self.retrieve_top_k)
371
+ output = self.section_gen(
372
+ topic=topic,
373
+ outline=section_outline,
374
+ section=section_name,
375
+ collected_info=collected_info,
376
+ )
377
+
378
+ return {"section_name": section_name, "section_content": output.section, "collected_info": collected_info}
379
+
380
+ def generate_article(self,
381
+ topic: str,
382
+ mindmap,
383
+ article_with_outline,
384
+ ):
385
+ """
386
+ Generate article for the topic based on the information table and article outline.
387
+ """
388
+ mindmap.prepare_table_for_retrieval()
389
+
390
+ sections_to_write = article_with_outline.get_first_level_section_names()
391
+ section_output_dict_collection = []
392
+
393
+ with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor:
394
+ future_to_sec_title = {}
395
+ for section_title in sections_to_write:
396
+ section_query = article_with_outline.get_outline_as_list(
397
+ root_section_name=section_title, add_hashtags=False
398
+ )
399
+ queries_with_hashtags = article_with_outline.get_outline_as_list(
400
+ root_section_name=section_title, add_hashtags=True
401
+ )
402
+ section_outline = "\n".join(queries_with_hashtags)
403
+
404
+ future_to_sec_title[
405
+ executor.submit(self.generate_section,
406
+ topic, section_title, mindmap, section_query,section_outline)
407
+ ] = section_title
408
+
409
+ for future in concurrent.futures.as_completed(future_to_sec_title):
410
+ section_output_dict_collection.append(future.result())
411
+
412
+ article = copy.deepcopy(article_with_outline)
413
+ for section_output_dict in section_output_dict_collection:
414
+ article.update_section(parent_section_name=topic,
415
+ current_section_content=section_output_dict["section_content"],
416
+ current_section_info_list=section_output_dict["collected_info"],
417
+ )
418
+
419
+ article.post_processing()
420
+
421
+
422
+
423
+ return article
424
+
425
+ class ConvToSection(dspy.Module):
426
+ """Use the information collected from the information-seeking conversation to write a section."""
427
+ # The caller passes in the URLs for all sections, but the goal here is to generate from a single one; this still needs work, because the outline argument is not used.
428
+ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
429
+ super().__init__()
430
+ self.write_section = dspy.Predict(WriteSection)
431
+ self.engine = engine
432
+
433
+ def forward(self, topic: str, outline:str, section: str, collected_info: List):
434
+ all_info = ''
435
+ for idx, info in enumerate(collected_info):
436
+ all_info += f'[{idx + 1}]\n' + '\n'.join(info['snippets'])
437
+ all_info += '\n\n'
438
+
439
+ all_info = ArticleTextProcessing.limit_word_count_preserve_newline(all_info, 1500)
440
+
441
+ with dspy.settings.context(lm=self.engine):
442
+ section = ArticleTextProcessing.clean_up_section(
443
+ self.write_section(topic=topic, info=all_info, section=section).output)
444
+
445
+ section = section.replace('\\[', '[').replace('\\]', ']')
446
+ return dspy.Prediction(section=section)
447
+
448
+
449
+
450
+ class WriteSection(dspy.Signature):
451
+ """Write a Wikipedia section based on the collected information.
452
+
453
+ Here is the format of your writing:
454
+ 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on.
455
+ 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
456
+ 3. The language style should resemble that of Wikipedia: concise yet informative, formal yet accessible.
457
+ """
458
+ # """
459
+ # Write a detailed, Wikipedia-style report section based on the collected information.
460
+
461
+ # Here is the format of your writing:
462
+ # 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on.
463
+ # 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end.
464
+ # 3. The language style should resemble that of Wikipedia: concise yet informative, formal yet accessible.
465
+ # """
466
+
467
+ info = dspy.InputField(prefix="The Collected information:\n", format=str)
468
+ topic = dspy.InputField(prefix="The topic of the page: ", format=str)
469
+ section = dspy.InputField(prefix="The section you need to write: ", format=str)
470
+ output = dspy.OutputField(
471
+ prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n",
472
+ format=str)
473
+
474
+
475
+
476
+ if __name__ == "__main__":
477
+ import sys
478
+ from mindmap import MindMap
479
+ from outline_generation import OutlineGenerationModule
480
+ sys.path.append('/mnt/nas-alinlp/xizekun/project/DeepThink/src')
481
+ from storm_dataclass import Article
482
+
483
+ from lm import OpenAIModel, OpenAIModel_New
484
+ from rm import BingSearch, BingSearchAli
485
+ from utils import load_api_key
486
+ import os
487
+ load_api_key(toml_file_path='/mnt/nas-alinlp/xizekun/project/DeepThink/secrets.toml')
488
+ openai_kwargs = {
489
+ 'api_key': os.getenv("OPENAI_API_KEY"),
490
+ 'api_provider': os.getenv('OPENAI_API_TYPE'),
491
+ 'temperature': 1.0,
492
+ 'top_p': 0.9,
493
+ 'api_base': os.getenv('AZURE_API_BASE'),
494
+ 'api_version': os.getenv('AZURE_API_VERSION'),
495
+ }
496
+
497
+ lm = OpenAIModel(model='gpt-4-1106-preview', max_tokens=5000, **openai_kwargs)
498
+ rm = BingSearchAli(ydc_api_key=os.getenv('BING_SEARCH_ALI_API_KEY'), k=3)
499
+
500
+ retriever = rm
501
+ gen_concept_lm = lm
502
+
503
+ mind_map = MindMap(
504
+ retriever=retriever,
505
+ gen_concept_lm=lm,
506
+ search_top_k=3,
507
+ deepth = 3
508
+ )
509
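+     # Load a previously saved mind map (the knowledge curation output) instead of building one from scratch.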
+ a = mind_map.load_map('/mnt/nas-alinlp/xizekun/project/DeepThink/src/DeepThink/modules/Taylor.json')
510
+ ag = ArticleGenerationModule(
511
+ retriever = retriever,
512
+ article_gen_lm = lm,
513
+ retrieve_top_k = 5,
514
+ max_thread_num = 10)
515
+
516
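+     # Pipeline order: derive an outline from the mind map, then expand each first-level outline section into article text.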
+ module = OutlineGenerationModule(lm)
517
+ outline = module.generate_outline(topic= 'Taylor Hawkins',mindmap = mind_map)
518
+ print(outline)
519
+ print('~~~~~~')
520
+
521
+ article_with_outline = Article.from_outline_str(topic='Taylor Hawkins', outline_str=outline)
522
+ a = ag.generate_article(topic = 'Taylor Hawkins', mindmap = mind_map, article_with_outline = article_with_outline)
523
+ print(a.to_string())
src/DeepThink/modules/article_polish.py ADDED
@@ -0,0 +1,417 @@
1
+ import copy
2
+ from typing import Union
3
+
4
+ import dspy
5
+ # from storm_wiki.modules.storm_dataclass import StormArticle
6
+ import concurrent.futures
7
+ import json
8
+ import os
9
+ import pickle
10
+ import re
11
+ import sys
12
+ from typing import List, Dict
13
+
14
+ import httpx
15
+ import toml
16
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
17
+ from trafilatura import extract
18
+
19
+
20
+ class ArticleTextProcessing:
21
+ @staticmethod
22
+ def limit_word_count_preserve_newline(input_string, max_word_count):
23
+ """
24
+ Limit the word count of an input string to a specified maximum, while preserving the integrity of complete lines.
25
+
26
+ The function truncates the input string at the nearest word that does not exceed the maximum word count,
27
+ ensuring that no partial lines are included in the output. Words are defined as text separated by spaces,
28
+ and lines are defined as text separated by newline characters.
29
+
30
+ Args:
31
+ input_string (str): The string to be truncated. This string may contain multiple lines.
32
+ max_word_count (int): The maximum number of words allowed in the truncated string.
33
+
34
+ Returns:
35
+ str: The truncated string with word count limited to `max_word_count`, preserving complete lines.
36
+ """
37
+
38
+ word_count = 0
39
+ limited_string = ''
40
+
41
+         for line in input_string.split('\n'):
42
+             line_words = line.split()
43
+ for lw in line_words:
44
+ if word_count < max_word_count:
45
+ limited_string += lw + ' '
46
+ word_count += 1
47
+ else:
48
+ break
49
+ if word_count >= max_word_count:
50
+ break
51
+ limited_string = limited_string.strip() + '\n'
52
+
53
+ return limited_string.strip()
54
+
55
+ @staticmethod
56
+ def remove_citations(s):
57
+ """
58
+ Removes all citations from a given string. Citations are assumed to be in the format
59
+ of numbers enclosed in square brackets, such as [1], [2], or [1, 2], etc. This function searches
60
+ for all occurrences of such patterns and removes them, returning the cleaned string.
61
+
62
+ Args:
63
+ s (str): The string from which citations are to be removed.
64
+
65
+ Returns:
66
+ str: The string with all citation patterns removed.
67
+ """
68
+
69
+ return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s)
70
+
71
+ @staticmethod
72
+ def get_first_section_dict_and_list(s):
73
+         """
74
+         Split the article text on first-level '# ' headers and return a tuple of
+         ({section title: section content}, [section titles in order of appearance]).
+         """
75
+ text = s
76
+ sections = text.strip().split('\n# ')
77
+ titles = []
78
+ content_dict = {}
79
+
80
+ for section in sections:
81
+ if section:
82
+ lines = section.split('\n', 1)
83
+ title = lines[0].strip()
84
+ content = lines[1].strip() if len(lines) > 1 else ""
85
+
86
+ titles.append(title)
87
+ content_dict[title] = content
88
+ return content_dict, titles
89
+
90
+ @staticmethod
91
+ def parse_citation_indices(s):
92
+ """
93
+ Extracts citation indexes from the provided content string and returns them as a list of integers.
94
+
95
+ Args:
96
+             s (str): The content string containing citations in the format [number].
97
+
98
+ Returns:
99
+             List[int]: A list of citation indexes extracted from the content, in the order they appear (duplicates are preserved).
100
+ """
101
+ matches = re.findall(r'\[\d+\]', s)
102
+ return [int(index[1:-1]) for index in matches]
103
+
104
+ @staticmethod
105
+ def remove_uncompleted_sentences_with_citations(text):
106
+ """
107
+ Removes uncompleted sentences and standalone citations from the input text. Sentences are identified
108
+ by their ending punctuation (.!?), optionally followed by a citation in square brackets (e.g., "[1]").
109
+ Grouped citations (e.g., "[1, 2]") are split into individual ones (e.g., "[1] [2]"). Only text up to
110
+ and including the last complete sentence and its citation is retained.
111
+
112
+ Args:
113
+ text (str): The input text from which uncompleted sentences and their citations are to be removed.
114
+
115
+ Returns:
116
+ str: The processed string with uncompleted sentences and standalone citations removed, leaving only
117
+ complete sentences and their associated citations if present.
118
+ """
119
+
120
+ # Convert citations like [1, 2, 3] to [1][2][3].
121
+ def replace_with_individual_brackets(match):
122
+ numbers = match.group(1).split(', ')
123
+ return ' '.join(f'[{n}]' for n in numbers)
124
+
125
+ # Deduplicate and sort individual groups of citations.
126
+ def deduplicate_group(match):
127
+ citations = match.group(0)
128
+ unique_citations = list(set(re.findall(r'\[\d+\]', citations)))
129
+ sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]')))
130
+ # Return the sorted unique citations as a string
131
+ return ''.join(sorted_citations)
132
+
133
+ text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text)
134
+ text = re.sub(r'(\[\d+\])+', deduplicate_group, text)
135
+
136
+ # Deprecated: Remove sentence without proper ending punctuation and citations.
137
+ # Split the text into sentences (including citations).
138
+ # sentences_with_trailing = re.findall(r'([^.!?]*[.!?].*?)(?=[^.!?]*[.!?]|$)', text)
139
+
140
+ # Filter sentences to ensure they end with a punctuation mark and properly formatted citations
141
+ # complete_sentences = []
142
+ # for sentence in sentences_with_trailing:
143
+ # # Check if the sentence ends with properly formatted citations
144
+ # if re.search(r'[.!?]( \[\d+\])*$|^[^.!?]*[.!?]$', sentence.strip()):
145
+ # complete_sentences.append(sentence.strip())
146
+
147
+ # combined_sentences = ' '.join(complete_sentences)
148
+
149
+ # Check for and append any complete citations that follow the last sentence
150
+ # trailing_citations = re.findall(r'(\[\d+\]) ', text[text.rfind(combined_sentences) + len(combined_sentences):])
151
+ # if trailing_citations:
152
+ # combined_sentences += ' '.join(trailing_citations)
153
+
154
+ # Regex pattern to match sentence endings, including optional citation markers.
155
+ eos_pattern = r'([.!?])\s*(\[\d+\])?\s*'
156
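+         # Keep text only up to the end of the last complete sentence (and its trailing citation, if any).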
+ matches = list(re.finditer(eos_pattern, text))
157
+ if matches:
158
+ last_match = matches[-1]
159
+ text = text[:last_match.end()].strip()
160
+
161
+ return text
162
+
163
+ @staticmethod
164
+ def clean_up_citation(conv):
165
+ for turn in conv.dlg_history:
166
+             # Guard: str.find returns -1 when the marker is absent, which would silently drop the last character.
+             if 'References:' in turn.agent_utterance:
+                 turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')]
167
+             if 'Sources:' in turn.agent_utterance:
+                 turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')]
168
+ turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip()
169
+ try:
170
+ max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)])
171
+ except Exception as e:
172
+ max_ref_num = 0
173
+ if max_ref_num > len(turn.search_results):
174
+ for i in range(len(turn.search_results), max_ref_num + 1):
175
+ turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '')
176
+ turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
177
+ turn.agent_utterance)
178
+
179
+ return conv
180
+
181
+ @staticmethod
182
+ def clean_up_outline(outline, topic=""):
183
+ output_lines = []
184
+ current_level = 0 # To track the current section level
185
+
186
+ for line in outline.split('\n'):
187
+ stripped_line = line.strip()
188
+
189
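+             # If the page-title header itself appears, restart collection so that only the outline below it is kept.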
+ if topic != "" and f"# {topic.lower()}" in stripped_line.lower():
190
+ output_lines = []
191
+
192
+ # Check if the line is a section header
193
+ if stripped_line.startswith('#') and stripped_line != '#':
194
+ current_level = stripped_line.count('#')
195
+ output_lines.append(stripped_line)
196
+ # Check if the line is a bullet point
197
+ # elif stripped_line.startswith('-'):
198
+ # subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip()
199
+ # output_lines.append(subsection_header)
200
+ # Preserve lines with @
201
+ elif stripped_line.startswith('@'):
202
+ output_lines.append(stripped_line)
203
+
204
+ outline = '\n'.join(output_lines)
205
+
206
+ # Remove references.
207
+ outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL)
208
+ outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL)
209
+ outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL)
210
+ outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL)
211
+ outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL)
212
+ outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL)
213
+ outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL)
214
+         outline = re.sub(r"#[#]? Further reading.*?(?=##|$)", '', outline, flags=re.DOTALL)
215
+         outline = re.sub(r"#[#]? Further Reading.*?(?=##|$)", '', outline, flags=re.DOTALL)
216
+ outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL)
217
+ outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL)
218
+ outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL)
219
+
220
+ return outline
221
+
222
+
223
+ @staticmethod
224
+ def clean_up_section(text):
225
+ """Clean up a section:
226
+ 1. Remove uncompleted sentences (usually due to output token limitation).
227
+ 2. Deduplicate individual groups of citations.
228
+ 3. Remove unnecessary summary."""
229
+
230
+ paragraphs = text.split('\n')
231
+ output_paragraphs = []
232
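+         # summary_sec_flag marks that we are inside a "Summary"/"Conclusion" section, whose paragraphs are dropped until the next header.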
+ summary_sec_flag = False
233
+ for p in paragraphs:
234
+ p = p.strip()
235
+ if len(p) == 0:
236
+ continue
237
+ if not p.startswith('#'):
238
+ p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p)
239
+ if summary_sec_flag:
240
+ if p.startswith('#'):
241
+ summary_sec_flag = False
242
+ else:
243
+ continue
244
+ if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'):
245
+ continue
246
+ if "# Summary" in p or '# Conclusion' in p:
247
+ summary_sec_flag = True
248
+ continue
249
+ output_paragraphs.append(p)
250
+
251
+ return '\n\n'.join(output_paragraphs) # Join with '\n\n' for markdown format.
252
+
253
+ @staticmethod
254
+ def update_citation_index(s, citation_map):
255
+ """Update citation index in the string based on the citation map."""
256
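+         # Two passes via placeholders so that a remapping such as [2] -> [3] does not collide with a later [3] -> [4] rewrite.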
+ for original_citation in citation_map:
257
+ s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__")
258
+ for original_citation, unify_citation in citation_map.items():
259
+ s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]")
260
+
261
+ return s
262
+
263
+ @staticmethod
264
+ def parse_article_into_dict(input_string):
265
+ """
266
+ Parses a structured text into a nested dictionary. The structure of the text
267
+ is defined by markdown-like headers (using '#' symbols) to denote sections
268
+ and subsections. Each section can contain content and further nested subsections.
269
+
270
+ The resulting dictionary captures the hierarchical structure of sections, where
271
+ each section is represented as a key (the section's title) mapping to a value
272
+ that is another dictionary. This dictionary contains two keys:
273
+ - 'content': content of the section
274
+ - 'subsections': a list of dictionaries, each representing a nested subsection
275
+ following the same structure.
276
+
277
+ Args:
278
+ input_string (str): A string containing the structured text to parse.
279
+
280
+ Returns:
281
+ A dictionary representing contains the section title as the key, and another dictionary
282
+ as the value, which includes the 'content' and 'subsections' keys as described above.
283
+ """
284
+ lines = input_string.split('\n')
285
+ lines = [line for line in lines if line.strip()]
286
+ root = {'content': '', 'subsections': {}}
287
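+         # Maintain a stack of (section_dict, level); each new header pops back to its parent level before nesting under it.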
+ current_path = [(root, -1)] # (current_dict, level)
288
+
289
+ for line in lines:
290
+ if line.startswith('#'):
291
+ level = line.count('#')
292
+ title = line.strip('# ').strip()
293
+ new_section = {'content': '', 'subsections': {}}
294
+
295
+ # Pop from stack until find the parent level
296
+ while current_path and current_path[-1][1] >= level:
297
+ current_path.pop()
298
+
299
+ # Append new section to the nearest upper level's subsections
300
+ current_path[-1][0]['subsections'][title] = new_section
301
+ current_path.append((new_section, level))
302
+ else:
303
+ current_path[-1][0]['content'] += line + '\n'
304
+
305
+ return root['subsections']
306
+
307
+
308
+ class FileIOHelper:
309
+ @staticmethod
310
+ def dump_json(obj, file_name, encoding="utf-8"):
311
+ with open(file_name, 'w', encoding=encoding) as fw:
312
+ json.dump(obj, fw, default=FileIOHelper.handle_non_serializable, ensure_ascii=False)
313
+
314
+ @staticmethod
315
+ def handle_non_serializable(obj):
316
+ return "non-serializable contents" # mark the non-serializable part
317
+
318
+ @staticmethod
319
+ def load_json(file_name, encoding="utf-8"):
320
+ with open(file_name, 'r', encoding=encoding) as fr:
321
+ return json.load(fr)
322
+
323
+ @staticmethod
324
+ def write_str(s, path):
325
+ with open(path, 'w') as f:
326
+ f.write(s)
327
+
328
+ @staticmethod
329
+ def load_str(path):
330
+ with open(path, 'r') as f:
331
+             return f.read()
332
+
333
+ @staticmethod
334
+ def dump_pickle(obj, path):
335
+ with open(path, 'wb') as f:
336
+ pickle.dump(obj, f)
337
+
338
+ @staticmethod
339
+ def load_pickle(path):
340
+ with open(path, 'rb') as f:
341
+ return pickle.load(f)
342
+
343
+ class ArticlePolishingModule():
344
+ """
345
+     The interface for the article polishing stage. Given a draft article, it removes duplicated
346
+     content and cleans up the article before final output.
347
+ """
348
+
349
+ def __init__(self,
350
+ article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel],
351
+ article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
352
+ self.article_gen_lm = article_gen_lm
353
+ self.article_polish_lm = article_polish_lm
354
+
355
+ self.polish_page = PolishPageModule(
356
+ write_lead_engine=self.article_gen_lm,
357
+ polish_engine=self.article_polish_lm
358
+ )
359
+
360
+ def polish_article(self,
361
+ topic: str,
362
+ draft_article,
363
+ remove_duplicate: bool = False):
364
+ """
365
+ Polish article.
366
+
367
+ Args:
368
+ topic (str): The topic of the article.
369
+ draft_article (StormArticle): The draft article.
370
+ remove_duplicate (bool): Whether to use one additional LM call to remove duplicates from the article.
371
+ """
372
+
373
+ article_text = draft_article.to_string()
374
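+         # NOTE: duplicate removal is always forced on here, so the remove_duplicate argument is effectively ignored.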
+ remove_duplicate = True
375
+ polish_result = self.polish_page(topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate)
376
+
377
+ polished_article = polish_result.page
378
+
379
+ polished_article_dict = ArticleTextProcessing.parse_article_into_dict(polished_article)
380
+ polished_article = copy.deepcopy(draft_article)
381
+ polished_article.insert_or_create_section(article_dict=polished_article_dict)
382
+ polished_article.post_processing()
383
+ return polished_article
384
+
385
+
386
+
387
+ class PolishPage(dspy.Signature):
388
+ """
389
+     You are a faithful text editor that is good at finding repeated information in the article and deleting it to make sure there is no repetition in the article.
390
+ You won't delete any non-repeated part in the article.
391
+ You will keep the inline citations and article structure (indicated by "#", "##", etc.) appropriately.
392
+ In the article, do not include references.
393
+ Do your job for the following article.
394
+ """
395
+
396
+ article = dspy.InputField(prefix="The article you need to polish:\n", format=str)
397
+ page = dspy.OutputField(
398
+ prefix="Your revised article:\n",
399
+ format=str)
400
+
401
+
402
+ class PolishPageModule(dspy.Module):
403
+ def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
404
+ polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]):
405
+ super().__init__()
406
+ self.write_lead_engine = write_lead_engine
407
+ self.polish_engine = polish_engine
408
+ self.polish_page = dspy.Predict(PolishPage)
409
+
410
+ def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True):
411
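+         # topic and polish_whole_page are currently unused; the draft is simply passed through the polish engine.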
+
412
+ with dspy.settings.context(lm=self.polish_engine):
413
+ page = self.polish_page(article=draft_page).page
414
+
415
+ return dspy.Prediction(page=page)
416
+
417
+