# Reqxtract-v2 / api/solutions.py
# Lucas ARRIESSE
# WIP solution drafting
# e97be0e
# raw | history | blame
# 9.26 kB
import asyncio
import json
import logging
from fastapi import APIRouter, Depends, HTTPException, Response
from httpx import AsyncClient
from jinja2 import Environment, TemplateNotFound
from litellm.router import Router
from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates
from typing import Awaitable, Callable, TypeVar
from schemas import _RefinedSolutionModel, _BootstrappedSolutionModel, _SolutionCriticismOutput, CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList, PriorArtSearchRequest, PriorArtSearchResponse, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel, SolutionBootstrapResponse, SolutionBootstrapRequest, TechnologyData
# Shared APIRouter on which the endpoints below are registered; every route is
# grouped under a single OpenAPI tag.
router = APIRouter(
    tags=["solution generation and critique"],
)
# ============== utilities ===========================
T = TypeVar("T")
A = TypeVar("A")


async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Await ``func(arg)`` repeatedly until ``predicate`` accepts its result.

    ``func`` is always awaited once. After that it is re-awaited at most
    ``max_retries`` more times, each time the previous result was rejected by
    ``predicate``.

    Args:
        func: Async callable producing a candidate result from ``arg``.
        arg: Argument passed unchanged to every ``func`` call.
        predicate: Returns ``True`` when a result is acceptable.
        max_retries: Maximum number of additional attempts; ``<= 0`` means a
            single attempt.

    Returns:
        The first result accepted by ``predicate``; otherwise the result of the
        final attempt, which is returned without being checked.
    """
    result = await func(arg)
    remaining = max_retries
    # The predicate is only consulted while a retry is still available, so the
    # result of the last permitted attempt is never evaluated.
    while remaining > 0 and not predicate(result):
        result = await func(arg)
        remaining -= 1
    return result
# =================================================== Search solutions ============================================================================
@router.post("/bootstrap_solutions")
async def bootstrap_solutions(req: SolutionBootstrapRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router), http_client: AsyncClient = Depends(get_http_client)) -> SolutionBootstrapResponse:
    """
    Bootstraps a solution for each of the passed in requirement categories using Insight Finder's API.

    Every category in ``req.categories`` is processed concurrently. For each one, this:
      1. asks the LLM to reformat the category into Insight Finder constraints,
      2. posts those constraints to Insight Finder's ``process-constraints`` endpoint to fetch technologies,
      3. asks the LLM to synthesize a solution from the category, the technologies and ``req.user_constraints``.

    Processing is best-effort. If a category's pipeline raises, the error is logged and
    that category is left out of the response; the rest of the request still succeeds.
    """
    logger = logging.getLogger(__name__)

    async def _bootstrap_solution_inner(cat: ReqGroupingCategory) -> SolutionModel:
        # process requirements into insight finder format
        fmt_completion = await llm_router.acompletion("gemini-v2", messages=[
            {
                "role": "user",
                "content": await prompt_env.get_template("format_requirements.txt").render_async(**{
                    "category": cat.model_dump(),
                    "response_schema": InsightFinderConstraintsList.model_json_schema()
                })
            }], response_format=InsightFinderConstraintsList)
        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate from a structured output to a dict for insights finder
        formatted_constraints = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # fetch technologies from insight finder. Check the status first, so an HTTP
        # failure is reported as such rather than as a validation error on the error body.
        technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", content=json.dumps(formatted_constraints))
        technologies_req.raise_for_status()
        technologies = TechnologyData.model_validate(technologies_req.json())

        # ======================= synthesize solution using LLM =======================
        format_solution = await llm_router.acompletion("gemini-v2", messages=[{
            "role": "user",
            "content": await prompt_env.get_template("bootstrap_solution.txt").render_async(**{
                "category": cat.model_dump(),
                "technologies": technologies.model_dump()["technologies"],
                "user_constraints": req.user_constraints,
                "response_schema": _BootstrappedSolutionModel.model_json_schema()
            })}
        ], response_format=_BootstrappedSolutionModel)
        format_solution_model = _BootstrappedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        # requirement_ids are produced by the LLM and used as indices into cat.requirements.
        # Drop out-of-range ids. Otherwise an IndexError discards the whole category, and a
        # negative id silently selects a requirement counted from the end of the list.
        requirement_ids = format_solution_model.requirement_ids
        req_count = len(cat.requirements)
        valid_ids = [i for i in requirement_ids if 0 <= i < req_count]
        invalid_ids = [i for i in requirement_ids if not 0 <= i < req_count]
        if invalid_ids:
            logger.warning(
                "Ignoring out-of-range requirement ids %s for category %s",
                invalid_ids, cat.id)

        return SolutionModel(
            context="",
            requirements=[cat.requirements[i].requirement for i in valid_ids],
            problem_description=format_solution_model.problem_description,
            solution_description=format_solution_model.solution_description,
            references=[],
            category_id=cat.id,
        )

    results = await asyncio.gather(*[_bootstrap_solution_inner(cat) for cat in req.categories], return_exceptions=True)

    final_solutions = []
    # gather preserves input order, so each result lines up with its category
    for cat, result in zip(req.categories, results):
        # return_exceptions=True can also yield BaseException subclasses that are not
        # Exception (e.g. asyncio.CancelledError); keep them out of the response model.
        if isinstance(result, BaseException):
            logger.error("Failed to bootstrap a solution for category %s", cat.id, exc_info=result)
            continue
        final_solutions.append(result)
    return SolutionBootstrapResponse(solutions=final_solutions)
@router.post("/criticize_solution", response_model=CritiqueResponse)
async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse:
    """Criticize the challenges, weaknesses and limitations of the provided solutions.

    Each solution in ``params.solutions`` is critiqued independently and concurrently,
    using the ``criticize.txt`` prompt. Critiques are returned in the same order as the
    input solutions. Any failure propagates and fails the whole request.

    Raises:
        HTTPException: 500 if the LLM output for a solution contains no criticism.
    """
    async def __criticize_single(solution: SolutionModel) -> SolutionCriticism:
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })
        req_completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )
        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )
        # Only one solution is sent per prompt, so only the first criticism is used.
        # An empty list would otherwise fail with an opaque IndexError.
        if not criticism_out.criticisms:
            raise HTTPException(
                status_code=500,
                detail="The LLM returned no criticism for one of the solutions",
            )
        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)
    return CritiqueResponse(critiques=critiques)
# =================================================================== Refine solution ====================================
@router.post("/refine_solutions", response_model=SolutionBootstrapResponse)
async def refine_solutions(params: CritiqueResponse, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionBootstrapResponse:
    """Refines the previously critiqued solutions.

    For every (solution, criticism) pair in ``params.critiques``, the LLM rewrites the
    problem and solution descriptions using the ``refine_solution.txt`` prompt. All other
    fields are carried over from a deep copy of the original solution. Pairs are refined
    concurrently, and the results keep the input order.
    """

    async def _refine_one(critique: SolutionCriticism):
        template = prompt_env.get_template("refine_solution.txt")
        prompt = await template.render_async(
            solution=critique.solution.model_dump(),
            criticism=critique.criticism,
            response_schema=_RefinedSolutionModel.model_json_schema(),
        )

        completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": prompt}],
            response_format=_RefinedSolutionModel,
        )
        raw_output = completion.choices[0].message.content
        rewritten = _RefinedSolutionModel.model_validate_json(raw_output)

        # Mutate a deep copy so the incoming request model is left untouched.
        updated = critique.solution.model_copy(deep=True)
        updated.problem_description = rewritten.problem_description
        updated.solution_description = rewritten.solution_description
        return updated

    pending = [_refine_one(critique) for critique in params.critiques]
    refined = await asyncio.gather(*pending, return_exceptions=False)
    return SolutionBootstrapResponse(solutions=refined)
@router.post("/search_prior_art")
async def search_prior_art(req: PriorArtSearchRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> PriorArtSearchResponse:
    """Performs a comprehensive prior art search / FTO search against the provided topics for a drafted solution.

    Each topic in ``req.topics`` is searched by the LLM with the ``googleSearch`` tool enabled.
    At most ``max_concurrent_searches`` LLM searches run at once. The per-topic findings
    (``{"topic", "content"}`` dicts) are then passed to ``search/build_final_report.txt``
    and consolidated into a single report by a final LLM call.
    """
    # cap on simultaneous topic-search LLM calls
    max_concurrent_searches = 4
    sema = asyncio.Semaphore(max_concurrent_searches)

    async def __search_topic(topic: str) -> dict:
        search_prompt = await prompt_env.get_template("search/search_topic.txt").render_async(**{
            "topic": topic
        })
        # `async with` releases the slot only if acquire() actually succeeded. For example,
        # it does not release when the task is cancelled while waiting, which keeps the
        # semaphore from being over-released.
        async with sema:
            search_completion = await llm_router.acompletion(model="gemini-v2", messages=[
                {"role": "user", "content": search_prompt}
            ], temperature=0.3, tools=[{"googleSearch": {}}])
        return {"topic": topic, "content": search_completion.choices[0].message.content}

    # Dispatch the individual tasks for topic search
    topics = await asyncio.gather(*[__search_topic(top) for top in req.topics], return_exceptions=False)

    consolidation_prompt = await prompt_env.get_template("search/build_final_report.txt").render_async(**{
        "searches": topics
    })

    # Then consolidate everything into a single detailed report
    consolidation_completion = await llm_router.acompletion(model="gemini-v2", messages=[
        {"role": "user", "content": consolidation_prompt}
    ], temperature=0.5)
    return PriorArtSearchResponse(content=consolidation_completion.choices[0].message.content, references=[])