om4r932 committed on
Commit c1faac1 · 1 Parent(s): 0d2c020

Pull last commit app.py

Files changed (1)
  1. app.py +112 -58
app.py CHANGED
@@ -1,3 +1,16 @@
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import json
 import traceback
@@ -10,40 +23,54 @@ from litellm.router import Router
 from aiolimiter import AsyncLimiter
 import pandas as pd
 import asyncio
 import re
 import nltk

 nltk.download('stopwords')
 nltk.download('punkt_tab')
 nltk.download('wordnet')

-from nltk.stem import WordNetLemmatizer
-from nltk.corpus import stopwords
-from nltk.tokenize import word_tokenize
-
-import string
-import subprocess
-import requests
-from dotenv import load_dotenv
-
-load_dotenv()
-
-import os
-from lxml import etree
-import zipfile
-import io
-import warnings
-
 warnings.filterwarnings("ignore")

-from bs4 import BeautifulSoup
-
 app = FastAPI(title="Requirements Extractor")
 app.mount("/static", StaticFiles(directory="static"), name="static")
-app.add_middleware(CORSMiddleware, allow_credentials=True, allow_headers=["*"], allow_methods=["*"], allow_origins=["*"])
-llm_router = Router(model_list=[{"model_name": "gemini-v1", "litellm_params": {"model": "gemini/gemini-2.0-flash", "api_key": os.environ.get("GEMINI"), "max_retries": 10, "rpm": 15}},
-                                {"model_name": "gemini-v2", "litellm_params": {"model": "gemini/gemini-2.5-flash", "api_key": os.environ.get("GEMINI"), "max_retries": 10, "rpm": 10}}]
-                                , fallbacks=[{"gemini-v2": ["gemini-v1"]}], num_retries=10)

 limiter_mapping = {
     model["model_name"]: AsyncLimiter(model["litellm_params"]["rpm"], 60)
@@ -56,15 +83,18 @@ NSMAP = {
     'v': 'urn:schemas-microsoft-com:vml'
 }

 def lemma(text: str):
     stop_words = set(stopwords.words('english'))
     txt = text.translate(str.maketrans('', '', string.punctuation)).strip()
-    tokens = [token for token in word_tokenize(txt.lower()) if token not in stop_words]
     return [lemmatizer.lemmatize(token) for token in tokens]

 def get_docx_archive(url: str) -> zipfile.ZipFile:
     """Fetch the docx from the URL and return it as a ZipFile object"""
-    if not url.endswith("zip"):
         raise ValueError("URL must point to a ZIP file")
     doc_id = os.path.splitext(os.path.basename(url))[0]
     resp = requests.get(url, verify=False, headers={
@@ -84,7 +114,7 @@ def get_docx_archive(url: str) -> zipfile.ZipFile:

         with open(input_path, "wb") as f:
             f.write(docx_bytes)
-
         subprocess.run([
             "libreoffice",
             "--headless",
@@ -98,17 +128,19 @@ def get_docx_archive(url: str) -> zipfile.ZipFile:

         os.remove(input_path)
         os.remove(output_path)
-
         return zipfile.ZipFile(io.BytesIO(docx_bytes))

     raise ValueError("No docx/doc file found in the archive")

 def parse_document_xml(docx_zip: zipfile.ZipFile) -> etree._ElementTree:
     """Parse the main document.xml"""
     xml_bytes = docx_zip.read('word/document.xml')
     parser = etree.XMLParser(remove_blank_text=True)
     return etree.fromstring(xml_bytes, parser=parser)

 def clean_document_xml(root: etree._Element) -> None:
     """Clean the XML by modifying the tree in place"""
     # Remove <w:del> tags and their content
@@ -116,7 +148,7 @@ def clean_document_xml(root: etree._Element) -> None:
         parent = del_elem.getparent()
         if parent is not None:
             parent.remove(del_elem)
-
     # Unwrap <w:ins> tags
     for ins_elem in root.xpath('//w:ins', namespaces=NSMAP):
         parent = ins_elem.getparent()
@@ -125,7 +157,7 @@ def clean_document_xml(root: etree._Element) -> None:
             parent.insert(index, child)
             index += 1
         parent.remove(ins_elem)
-
     # Clean up comments
     for tag in ['w:commentRangeStart', 'w:commentRangeEnd', 'w:commentReference']:
         for elem in root.xpath(f'//{tag}', namespaces=NSMAP):
@@ -133,16 +165,17 @@ def clean_document_xml(root: etree._Element) -> None:
             if parent is not None:
                 parent.remove(elem)

 def create_modified_docx(original_zip: zipfile.ZipFile, modified_root: etree._Element) -> bytes:
     """Create a new docx with the modified XML"""
     output = io.BytesIO()
-
     with zipfile.ZipFile(output, 'w', compression=zipfile.ZIP_DEFLATED) as new_zip:
         # Copy all unmodified files
         for file in original_zip.infolist():
             if file.filename != 'word/document.xml':
                 new_zip.writestr(file, original_zip.read(file.filename))
-
         # Add the modified document.xml
         xml_str = etree.tostring(
             modified_root,
@@ -151,10 +184,11 @@ def create_modified_docx(original_zip: zipfile.ZipFile, modified_root: etree._El
             pretty_print=True
         )
         new_zip.writestr('word/document.xml', xml_str)
-
     output.seek(0)
     return output.getvalue()

 def docx_to_txt(doc_id: str, url: str):
     docx_zip = get_docx_archive(url)
     root = parse_document_xml(docx_zip)
@@ -165,7 +199,7 @@ def docx_to_txt(doc_id: str, url: str):
     output_path = f"/tmp/{doc_id}_cleaned.txt"
     with open(input_path, "wb") as f:
         f.write(modified_bytes)
-
     subprocess.run([
         "libreoffice",
         "--headless",
@@ -181,18 +215,20 @@ def docx_to_txt(doc_id: str, url: str):
     os.remove(output_path)
     return txt_data

 @app.get("/")
 def render_page():
     return FileResponse("index.html")

 @app.post("/get_meetings", response_model=MeetingsResponse)
 def get_meetings(req: MeetingsRequest):
     working_group = req.working_group
     tsg = re.sub(r"\d+", "", working_group)
     wg_number = re.search(r"\d", working_group).group(0)
-    print(tsg, wg_number)
     url = "https://www.3gpp.org/ftp/tsg_" + tsg
-    print(url)
     resp = requests.get(url, verify=False)
     soup = BeautifulSoup(resp.text, "html.parser")
     meeting_folders = []
@@ -205,22 +241,27 @@ def get_meetings(req: MeetingsRequest):
             break

     url += "/" + selected_folder
-    print(url)

     if selected_folder:
         resp = requests.get(url, verify=False)
         soup = BeautifulSoup(resp.text, "html.parser")
-        meeting_folders = [item.get_text() for item in soup.select("tr td a") if item.get_text().startswith("TSG") or (item.get_text().startswith("CT") and "-" in item.get_text())]
-        all_meetings = [working_group + "#" + meeting.split("_", 1)[1].replace("_", " ").replace("-", " ") if meeting.startswith('TSG') else meeting.replace("-","#") for meeting in meeting_folders]
-
     return MeetingsResponse(meetings=dict(zip(all_meetings, meeting_folders)))

 @app.post("/get_dataframe", response_model=DataResponse)
 def get_change_request_dataframe(req: DataRequest):
     working_group = req.working_group
     tsg = re.sub(r"\d+", "", working_group)
     wg_number = re.search(r"\d", working_group).group(0)
     url = "https://www.3gpp.org/ftp/tsg_" + tsg
     resp = requests.get(url, verify=False)
     soup = BeautifulSoup(resp.text, "html.parser")
     wg_folders = [item.get_text() for item in soup.select("tr td a")]
@@ -233,18 +274,21 @@ def get_change_request_dataframe(req: DataRequest):
     url += "/" + selected_folder + "/" + req.meeting + "/docs"
     resp = requests.get(url, verify=False)
     soup = BeautifulSoup(resp.text, "html.parser")
-    files = [item.get_text() for item in soup.select("tr td a") if item.get_text().endswith(".xlsx")]

     def gen_url(tdoc: str):
         return f"{url}/{tdoc}.zip"

     df = pd.read_excel(str(url + "/" + files[0]).replace("#", "%23"))
-    filtered_df = df[(((df["Type"] == "CR") & ((df["CR category"] == "B") | (df["CR category"] == "C"))) | (df["Type"] == "pCR")) & ~(df["Uploaded"].isna())][["TDoc", "Title", "CR category", "Source", "Type", "Agenda item", "Agenda item description", "TDoc Status"]]
     filtered_df["URL"] = filtered_df["TDoc"].apply(gen_url)

     df = filtered_df.fillna("")
     return DataResponse(data=df[["TDoc", "Title", "Type", "TDoc Status", "Agenda item description", "URL"]].to_dict(orient="records"))

 @app.post("/download_tdocs")
 def download_tdocs(req: DownloadRequest):
     documents = req.documents
@@ -290,13 +334,17 @@ def download_tdocs(req: DownloadRequest):
         media_type="application/zip"
     )

 @app.post("/generate_requirements", response_model=RequirementsResponse)
 async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
     documents = req.documents
     n_docs = len(documents)
     def prompt(doc_id, full):
         return f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"
-
     async def process_document(doc):
         doc_id = doc.document
         url = doc.url
@@ -305,13 +353,14 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
         except Exception as e:
             traceback.print_exception(e)
             return RequirementsResponse(requirements=[DocRequirements(document=doc_id, context="Error LLM", requirements=[])]).requirements
-
         try:
             model_used = "gemini-v2"  # Adjust if fallback is enabled
             async with limiter_mapping[model_used]:
                 resp_ai = await llm_router.acompletion(
                     model=model_used,
-                    messages=[{"role":"user","content": prompt(doc_id, full)}],
                     response_format=RequirementsResponse
                 )
                 return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
@@ -322,7 +371,8 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
             async with limiter_mapping[model_used]:
                 resp_ai = await llm_router.acompletion(
                     model=model_used,
-                    messages=[{"role":"user","content": prompt(doc_id, full)}],
                     response_format=RequirementsResponse
                 )
                 return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements
@@ -332,46 +382,50 @@ async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
             else:
                 traceback.print_exception(e)
                 return RequirementsResponse(requirements=[DocRequirements(document=doc_id, context="Error LLM", requirements=[])]).requirements
-
     async def process_batch(batch):
         results = await asyncio.gather(*(process_document(doc) for doc in batch))
         return [item for sublist in results for item in sublist]
-
     all_requirements = []
-
     if n_docs <= 30:
         batch_results = await process_batch(documents)
         all_requirements.extend(batch_results)
     else:
         batch_size = 30
-        batches = [documents[i:i + batch_size] for i in range(0, n_docs, batch_size)]
-
         for i, batch in enumerate(batches):
             batch_results = await process_batch(batch)
             all_requirements.extend(batch_results)
-
             if i < len(batches) - 1:
                 background_tasks.add_task(asyncio.sleep, 60)
     return RequirementsResponse(requirements=all_requirements)

 @app.post("/get_reqs_from_query", response_model=ReqSearchResponse)
 def find_requirements_from_problem_description(req: ReqSearchRequest):
     requirements = req.requirements
     query = req.query

-    requirements_text = "\n".join([f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for r in requirements])
-
     print("Called the LLM")
     resp_ai = llm_router.completion(
         model="gemini-v2",
-        messages=[{"role":"user","content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
         response_format=ReqSearchLLMResponse
     )
     print("Answered")
     print(resp_ai.choices[0].message.content)
-
-    out_llm = ReqSearchLLMResponse.model_validate_json(resp_ai.choices[0].message.content).selected
     if max(out_llm) > len(requirements) - 1:
-        raise HTTPException(status_code=500, detail="LLM error : Generated a wrong index, please try again.")

-    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
 
+from bs4 import BeautifulSoup
+import warnings
+import io
+import zipfile
+from lxml import etree
+import os
+from dotenv import load_dotenv
+import requests
+import subprocess
+import string
+from nltk.tokenize import word_tokenize
+from nltk.corpus import stopwords
+from nltk.stem import WordNetLemmatizer
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import json
 import traceback

 from aiolimiter import AsyncLimiter
 import pandas as pd
 import asyncio
+import logging
 import re
 import nltk

+load_dotenv()
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]: %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+
 nltk.download('stopwords')
 nltk.download('punkt_tab')
 nltk.download('wordnet')

 warnings.filterwarnings("ignore")

 app = FastAPI(title="Requirements Extractor")
 app.mount("/static", StaticFiles(directory="static"), name="static")
+app.add_middleware(CORSMiddleware, allow_credentials=True, allow_headers=[
+    "*"], allow_methods=["*"], allow_origins=["*"])
+
+llm_router = Router(model_list=[
+    {
+        "model_name": "gemini-v1",
+        "litellm_params":
+        {
+            "model": "gemini/gemini-2.0-flash",
+            "api_key": os.environ.get("GEMINI"),
+            "max_retries": 10,
+            "rpm": 15,
+            "allowed_fails": 1,
+            "cooldown": 30,
+        }
+    },
+    {
+        "model_name": "gemini-v2",
+        "litellm_params":
+        {
+            "model": "gemini/gemini-2.5-flash",
+            "api_key": os.environ.get("GEMINI"),
+            "max_retries": 10,
+            "rpm": 10,
+            "allowed_fails": 1,
+            "cooldown": 30,
+        }
+    }], fallbacks=[{"gemini-v2": ["gemini-v1"]}], num_retries=10, retry_after=30)

 limiter_mapping = {
     model["model_name"]: AsyncLimiter(model["litellm_params"]["rpm"], 60)
 
     'v': 'urn:schemas-microsoft-com:vml'
 }

+
 def lemma(text: str):
     stop_words = set(stopwords.words('english'))
     txt = text.translate(str.maketrans('', '', string.punctuation)).strip()
+    tokens = [token for token in word_tokenize(
+        txt.lower()) if token not in stop_words]
     return [lemmatizer.lemmatize(token) for token in tokens]

+
 def get_docx_archive(url: str) -> zipfile.ZipFile:
     """Fetch the docx from the URL and return it as a ZipFile object"""
+    if not url.endswith("zip"):
         raise ValueError("URL must point to a ZIP file")
     doc_id = os.path.splitext(os.path.basename(url))[0]
     resp = requests.get(url, verify=False, headers={
 

         with open(input_path, "wb") as f:
             f.write(docx_bytes)
+
         subprocess.run([
             "libreoffice",
             "--headless",

         os.remove(input_path)
         os.remove(output_path)
+
         return zipfile.ZipFile(io.BytesIO(docx_bytes))

     raise ValueError("No docx/doc file found in the archive")

+
 def parse_document_xml(docx_zip: zipfile.ZipFile) -> etree._ElementTree:
     """Parse the main document.xml"""
     xml_bytes = docx_zip.read('word/document.xml')
     parser = etree.XMLParser(remove_blank_text=True)
     return etree.fromstring(xml_bytes, parser=parser)

+
 def clean_document_xml(root: etree._Element) -> None:
     """Clean the XML by modifying the tree in place"""
     # Remove <w:del> tags and their content

         parent = del_elem.getparent()
         if parent is not None:
             parent.remove(del_elem)
+
     # Unwrap <w:ins> tags
     for ins_elem in root.xpath('//w:ins', namespaces=NSMAP):
         parent = ins_elem.getparent()
 
             parent.insert(index, child)
             index += 1
         parent.remove(ins_elem)
+
     # Clean up comments
     for tag in ['w:commentRangeStart', 'w:commentRangeEnd', 'w:commentReference']:
         for elem in root.xpath(f'//{tag}', namespaces=NSMAP):

             if parent is not None:
                 parent.remove(elem)

+
 def create_modified_docx(original_zip: zipfile.ZipFile, modified_root: etree._Element) -> bytes:
     """Create a new docx with the modified XML"""
     output = io.BytesIO()
+
     with zipfile.ZipFile(output, 'w', compression=zipfile.ZIP_DEFLATED) as new_zip:
         # Copy all unmodified files
         for file in original_zip.infolist():
             if file.filename != 'word/document.xml':
                 new_zip.writestr(file, original_zip.read(file.filename))
+
         # Add the modified document.xml
         xml_str = etree.tostring(
             modified_root,

             pretty_print=True
         )
         new_zip.writestr('word/document.xml', xml_str)
+
     output.seek(0)
     return output.getvalue()

+
 def docx_to_txt(doc_id: str, url: str):
     docx_zip = get_docx_archive(url)
     root = parse_document_xml(docx_zip)

     output_path = f"/tmp/{doc_id}_cleaned.txt"
     with open(input_path, "wb") as f:
         f.write(modified_bytes)
+
     subprocess.run([
         "libreoffice",
         "--headless",
 
     os.remove(output_path)
     return txt_data

+
 @app.get("/")
 def render_page():
     return FileResponse("index.html")

+
 @app.post("/get_meetings", response_model=MeetingsResponse)
 def get_meetings(req: MeetingsRequest):
     working_group = req.working_group
     tsg = re.sub(r"\d+", "", working_group)
     wg_number = re.search(r"\d", working_group).group(0)
+    logging.debug(tsg, wg_number)
     url = "https://www.3gpp.org/ftp/tsg_" + tsg
+    logging.debug(url)
     resp = requests.get(url, verify=False)
     soup = BeautifulSoup(resp.text, "html.parser")
     meeting_folders = []

             break

     url += "/" + selected_folder
+    logging.debug(url)

     if selected_folder:
         resp = requests.get(url, verify=False)
         soup = BeautifulSoup(resp.text, "html.parser")
+        meeting_folders = [item.get_text() for item in soup.select("tr td a") if item.get_text(
+        ).startswith("TSG") or (item.get_text().startswith("CT") and "-" in item.get_text())]
+        all_meetings = [working_group + "#" + meeting.split("_", 1)[1].replace("_", " ").replace(
+            "-", " ") if meeting.startswith('TSG') else meeting.replace("-", "#") for meeting in meeting_folders]
+
     return MeetingsResponse(meetings=dict(zip(all_meetings, meeting_folders)))

+
 @app.post("/get_dataframe", response_model=DataResponse)
 def get_change_request_dataframe(req: DataRequest):
     working_group = req.working_group
     tsg = re.sub(r"\d+", "", working_group)
     wg_number = re.search(r"\d", working_group).group(0)
     url = "https://www.3gpp.org/ftp/tsg_" + tsg
+    logging.info("Fetching TDocs dataframe")
+
     resp = requests.get(url, verify=False)
     soup = BeautifulSoup(resp.text, "html.parser")
     wg_folders = [item.get_text() for item in soup.select("tr td a")]
 
     url += "/" + selected_folder + "/" + req.meeting + "/docs"
     resp = requests.get(url, verify=False)
     soup = BeautifulSoup(resp.text, "html.parser")
+    files = [item.get_text() for item in soup.select("tr td a")
+             if item.get_text().endswith(".xlsx")]

     def gen_url(tdoc: str):
         return f"{url}/{tdoc}.zip"

     df = pd.read_excel(str(url + "/" + files[0]).replace("#", "%23"))
+    filtered_df = df[(((df["Type"] == "CR") & ((df["CR category"] == "B") | (df["CR category"] == "C"))) | (df["Type"] == "pCR")) & ~(
+        df["Uploaded"].isna())][["TDoc", "Title", "CR category", "Source", "Type", "Agenda item", "Agenda item description", "TDoc Status"]]
     filtered_df["URL"] = filtered_df["TDoc"].apply(gen_url)

     df = filtered_df.fillna("")
     return DataResponse(data=df[["TDoc", "Title", "Type", "TDoc Status", "Agenda item description", "URL"]].to_dict(orient="records"))

+
 @app.post("/download_tdocs")
 def download_tdocs(req: DownloadRequest):
     documents = req.documents

         media_type="application/zip"
     )

+
 @app.post("/generate_requirements", response_model=RequirementsResponse)
 async def gen_reqs(req: RequirementsRequest, background_tasks: BackgroundTasks):
     documents = req.documents
     n_docs = len(documents)
+
+    logging.info("Generating requirements for documents: {}".format([doc.document for doc in documents]))
+
     def prompt(doc_id, full):
         return f"Here's the document whose ID is {doc_id} : {full}\n\nExtract all requirements and group them by context, returning a list of objects where each object includes a document ID, a concise description of the context where the requirements apply (not a chapter title or copied text), and a list of associated requirements; always return the result as a list, even if only one context is found. Remove the errors"
+
     async def process_document(doc):
         doc_id = doc.document
         url = doc.url
 
         except Exception as e:
             traceback.print_exception(e)
             return RequirementsResponse(requirements=[DocRequirements(document=doc_id, context="Error LLM", requirements=[])]).requirements
+
         try:
             model_used = "gemini-v2"  # Adjust if fallback is enabled
             async with limiter_mapping[model_used]:
                 resp_ai = await llm_router.acompletion(
                     model=model_used,
+                    messages=[
+                        {"role": "user", "content": prompt(doc_id, full)}],
                     response_format=RequirementsResponse
                 )
                 return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements

             async with limiter_mapping[model_used]:
                 resp_ai = await llm_router.acompletion(
                     model=model_used,
+                    messages=[
+                        {"role": "user", "content": prompt(doc_id, full)}],
                     response_format=RequirementsResponse
                 )
                 return RequirementsResponse.model_validate_json(resp_ai.choices[0].message.content).requirements

             else:
                 traceback.print_exception(e)
                 return RequirementsResponse(requirements=[DocRequirements(document=doc_id, context="Error LLM", requirements=[])]).requirements
+
     async def process_batch(batch):
         results = await asyncio.gather(*(process_document(doc) for doc in batch))
         return [item for sublist in results for item in sublist]
+
     all_requirements = []
+
     if n_docs <= 30:
         batch_results = await process_batch(documents)
         all_requirements.extend(batch_results)
     else:
         batch_size = 30
+        batches = [documents[i:i + batch_size]
+                   for i in range(0, n_docs, batch_size)]
+
         for i, batch in enumerate(batches):
             batch_results = await process_batch(batch)
             all_requirements.extend(batch_results)
+
             if i < len(batches) - 1:
                 background_tasks.add_task(asyncio.sleep, 60)
     return RequirementsResponse(requirements=all_requirements)

+
 @app.post("/get_reqs_from_query", response_model=ReqSearchResponse)
 def find_requirements_from_problem_description(req: ReqSearchRequest):
     requirements = req.requirements
     query = req.query

+    requirements_text = "\n".join(
+        [f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]" for r in requirements])
     print("Called the LLM")
     resp_ai = llm_router.completion(
         model="gemini-v2",
+        messages=[{"role": "user", "content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
         response_format=ReqSearchLLMResponse
     )
     print("Answered")
     print(resp_ai.choices[0].message.content)
+
+    out_llm = ReqSearchLLMResponse.model_validate_json(
+        resp_ai.choices[0].message.content).selected
     if max(out_llm) > len(requirements) - 1:
+        raise HTTPException(
+            status_code=500, detail="LLM error : Generated a wrong index, please try again.")

+    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
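
Note on the rate-limiting setup introduced by this commit: limiter_mapping pairs each model name with an aiolimiter.AsyncLimiter built from the model's rpm value, and every llm_router.acompletion call is wrapped in "async with limiter_mapping[model_used]". The following is a minimal standalone sketch of that pattern, not part of the commit; the rpm values mirror the Router config above, while call_model and the asyncio.sleep stand-in for llm_router.acompletion are hypothetical illustration only.

import asyncio
from aiolimiter import AsyncLimiter

# rpm values mirror the Router config in this commit (15 and 10 requests per minute)
model_list = [
    {"model_name": "gemini-v1", "litellm_params": {"rpm": 15}},
    {"model_name": "gemini-v2", "litellm_params": {"rpm": 10}},
]

# One AsyncLimiter per model name, allowing `rpm` acquisitions per 60-second window
limiter_mapping = {
    model["model_name"]: AsyncLimiter(model["litellm_params"]["rpm"], 60)
    for model in model_list
}

async def call_model(model_used: str, doc_id: str) -> str:
    # Each acquisition counts against the model's per-minute budget;
    # extra concurrent callers wait here instead of hitting the API.
    async with limiter_mapping[model_used]:
        await asyncio.sleep(0.1)  # stand-in for llm_router.acompletion(...)
        return f"{model_used} processed {doc_id}"

async def main():
    results = await asyncio.gather(
        *(call_model("gemini-v2", f"doc-{i}") for i in range(3))
    )
    print(results)

if __name__ == "__main__":
    asyncio.run(main())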