Reqxtract-v2 / api /requirements.py
Lucas ARRIESSE
WIP solution drafting
e97be0e
raw
history blame
4.56 kB
import logging
from fastapi import APIRouter, Depends, HTTPException
from jinja2 import Environment
from litellm.router import Router
from dependencies import get_llm_router, get_prompt_templates
from schemas import _ReqGroupingCategory, _ReqGroupingOutput, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse
# Router grouping the requirement processing endpoints declared in this module
# (requirement search and requirement categorization); all of its routes are
# tagged "requirement processing".
router = APIRouter(tags=["requirement processing"])
@router.post("/get_reqs_from_query", response_model=ReqSearchResponse)
def find_requirements_from_problem_description(req: ReqSearchRequest, llm_router: Router = Depends(get_llm_router)):
    """Find the requirements that address a given problem description.

    The requirements from the request are serialized into a single prompt and
    the LLM is asked to return the 'Selection ID' of those relevant to the
    query. The returned values are then used as positions into
    ``req.requirements`` to build the response.

    Args:
        req: Request holding the list of extracted requirements and the
            problem description (``query``).
        llm_router: LiteLLM router used to run the completion.

    Returns:
        A ``ReqSearchResponse`` with the selected requirements; the list is
        empty when the LLM selects nothing.

    Raises:
        HTTPException: 500 when the LLM returns an index outside
            ``[0, len(req.requirements))``.
    """
    requirements = req.requirements
    query = req.query

    # NOTE(review): the prompt exposes `r.req_id` as the selection ID, while the
    # answer is used below as a position in `requirements` -- this presumably
    # relies on req_id matching the list position; confirm against callers.
    requirements_text = "\n".join(
        f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]"
        for r in requirements)

    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{"role": "user", "content": f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", return a list of 'Selection ID' for the most relevant corresponding requirements that reference or best cover the problem. If none of the requirements covers the problem, simply return an empty list"}],
        response_format=ReqSearchLLMResponse
    )

    out_llm = ReqSearchLLMResponse.model_validate_json(
        resp_ai.choices[0].message.content).selected

    logging.info("Found %d reqs matching case.", len(out_llm))

    # Validate every index individually: an empty selection is a legitimate
    # answer (max() on it would raise), and negative values must be rejected
    # too since they would otherwise silently index from the end of the list.
    n_requirements = len(requirements)
    if any(i < 0 or i >= n_requirements for i in out_llm):
        raise HTTPException(
            status_code=500, detail="LLM error : Generated a wrong index, please try again.")

    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
@router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    The requirements are rendered into the ``classify.txt`` prompt template and
    the LLM is asked to group them by index. If some requirement indices end up
    in no category, the LLM is asked again (keeping the conversation history)
    up to ``MAX_ATTEMPTS`` times.

    Args:
        params: Request holding the requirements to group and the maximum
            number of categories to produce.
        prompt_env: Jinja2 environment used to load the prompt template.
        llm_router: LiteLLM router used to run the completions.

    Returns:
        A ``ReqGroupingResponse`` whose categories reference the original
        requirement objects; indices that do not match any requirement are
        dropped.

    Raises:
        HTTPException: 500 when the LLM still leaves requirements
            uncategorized after ``MAX_ATTEMPTS`` attempts.
    """
    MAX_ATTEMPTS = 5

    n_requirements = len(params.requirements)
    # every requirement index that must appear in at least one category
    valid_ids_universe = set(range(n_requirements))

    # categorize the requirements using their indices
    req_prompt = await prompt_env.get_template("classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    # start the conversation with the prompt containing the requirements
    messages = [{"role": "user", "content": req_prompt}]

    # ensure all requirements items are processed
    for _ in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(
            model="gemini-v2", messages=messages, response_format=_ReqGroupingOutput)
        reply = req_completion.choices[0].message
        output = _ReqGroupingOutput.model_validate_json(reply.content)

        # quick check to ensure no requirement was left out by the LLM:
        # every valid index must be contained in at least one category
        # (hallucinated ids outside the universe are ignored by the difference)
        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}
        unassigned_ids = valid_ids_universe - assigned_ids

        if not unassigned_ids:
            break

        # feed the incomplete answer back and ask for the missing items
        messages.append(reply)
        messages.append(
            {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
    else:
        # loop ran out of attempts without a complete categorization
        raise HTTPException(
            status_code=500, detail="Failed to classify all requirements")

    # build the final category objects
    # remove the invalid (likely hallucinated) requirement IDs; the lower bound
    # matters because a negative index would otherwise pick from the list end
    final_categories = [
        ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if 0 <= i < n_requirements],
        )
        for idx, cat in enumerate(output.categories)
    ]

    return ReqGroupingResponse(categories=final_categories)