Update workflow.py

workflow.py (CHANGED: +40 -26)
@@ -23,18 +23,13 @@ class AgentState(TypedDict):
 
 class ResearchWorkflow:
     """
-    A multi-step research workflow
-
-
-    - Legal Research
-    - Environmental and Energy Studies
-    - Competitive Programming and Theoretical Computer Science
-    - Social Sciences
-    This implementation normalizes the domain and uses domain-specific prompts and fallbacks.
+    A multi-step research workflow employing Retrieval-Augmented Generation (RAG) with an additional verification step.
+    This workflow supports multiple domains (e.g., Biomedical, Legal, Environmental, Competitive Programming, Social Sciences)
+    and integrates domain-specific prompts, iterative refinement, and a final verification to reduce hallucinations.
     """
     def __init__(self) -> None:
         self.processor = EnhancedCognitiveProcessor()
-        self.workflow = StateGraph(AgentState)
+        self.workflow = StateGraph(AgentState)
         self._build_workflow()
         self.app = self.workflow.compile()
 
@@ -44,6 +39,8 @@ class ResearchWorkflow:
         self.workflow.add_node("analyze", self.analyze_content)
         self.workflow.add_node("validate", self.validate_output)
         self.workflow.add_node("refine", self.refine_results)
+        # New verify node to further cross-check the output
+        self.workflow.add_node("verify", self.verify_output)
         self.workflow.set_entry_point("ingest")
         self.workflow.add_edge("ingest", "retrieve")
         self.workflow.add_edge("retrieve", "analyze")
@@ -52,17 +49,17 @@ class ResearchWorkflow:
             self._quality_check,
             {"valid": "validate", "invalid": "refine"}
         )
-        self.workflow.add_edge("validate",
+        self.workflow.add_edge("validate", "verify")
         self.workflow.add_edge("refine", "retrieve")
         # Extended node for multi-modal enhancement
        self.workflow.add_node("enhance", self.enhance_analysis)
-        self.workflow.add_edge("
+        self.workflow.add_edge("verify", "enhance")
         self.workflow.add_edge("enhance", END)
 
     def ingest_query(self, state: Dict) -> Dict:
         try:
             query = state["messages"][-1].content
-            # Normalize domain string
+            # Normalize the domain string; default to 'biomedical research'
             domain = state.get("context", {}).get("domain", "Biomedical Research").strip().lower()
             new_context = {
                 "raw_query": query,
@@ -83,7 +80,7 @@ class ResearchWorkflow:
     def retrieve_documents(self, state: Dict) -> Dict:
         try:
             query = state["context"]["raw_query"]
-            #
+            # Placeholder retrieval: currently returns an empty list (simulate no documents)
             docs = []
             logger.info(f"Retrieved {len(docs)} documents for query.")
             return {
@@ -102,18 +99,16 @@ class ResearchWorkflow:
 
     def analyze_content(self, state: Dict) -> Dict:
         try:
-            # Normalize domain and use it for prompt generation
             domain = state["context"].get("domain", "biomedical research").strip().lower()
             docs = state["context"].get("documents", [])
-            # Use retrieved documents if available; else, use raw query as fallback.
             if docs:
                 docs_text = "\n\n".join([d.page_content for d in docs])
             else:
                 docs_text = state["context"].get("raw_query", "")
-                logger.info("No documents retrieved;
-            #
-            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain,
-
+                logger.info("No documents retrieved; switching to dynamic synthesis (RAG mode).")
+            # Use domain-specific prompt; for legal research, inject legal-specific guidance.
+            domain_prompt = ResearchConfig.DOMAIN_PROMPTS.get(domain,
+                                                              "Provide an analysis based on the provided context.")
             full_prompt = f"Domain: {state['context'].get('domain', 'Biomedical Research')}\n" \
                           f"{domain_prompt}\n\n" + \
                           ResearchConfig.ANALYSIS_TEMPLATE.format(context=docs_text)
@@ -134,10 +129,11 @@ class ResearchWorkflow:
         try:
             analysis = state["messages"][-1].content
             validation_prompt = (
-                f"Validate the following analysis for
+                f"Validate the following analysis for accuracy and domain-specific relevance:\n{analysis}\n\n"
                 "Criteria:\n"
-                "1.
-                "2.
+                "1. Factual and technical accuracy\n"
+                "2. For legal research: inclusion of relevant precedents and statutory interpretations; "
+                "for other domains: appropriate domain insights\n"
                 "3. Logical consistency\n"
                 "4. Methodological soundness\n\n"
                 "Respond with 'VALID: [justification]' or 'INVALID: [justification]'."
@@ -152,6 +148,26 @@ class ResearchWorkflow:
             logger.exception("Error during output validation.")
             return self._error_state(f"Validation Error: {str(e)}")
 
+    def verify_output(self, state: Dict) -> Dict:
+        try:
+            # New verify step: cross-check the analysis using an external fact-checking prompt.
+            analysis = state["messages"][-1].content
+            verification_prompt = (
+                f"Verify the following analysis by comparing it with established external legal databases and reference texts:\n{analysis}\n\n"
+                "Identify any discrepancies or hallucinations and provide a brief correction if necessary."
+            )
+            response = self.processor.process_query(verification_prompt)
+            logger.info("Output verification completed.")
+            # Here, you can merge the verification feedback with the analysis.
+            verified_analysis = analysis + "\n\nVerification Feedback: " + response.get('choices', [{}])[0].get('message', {}).get('content', '')
+            return {
+                "messages": [AIMessage(content=verified_analysis)],
+                "context": state["context"]
+            }
+        except Exception as e:
+            logger.exception("Error during output verification.")
+            return self._error_state(f"Verification Error: {str(e)}")
+
     def refine_results(self, state: Dict) -> Dict:
         try:
             current_count = state["context"].get("refine_count", 0)
@@ -167,8 +183,7 @@ class ResearchWorkflow:
                 f"Domain: {domain}\n"
                 "You are given the following series of refinement outputs:\n" +
                 "\n---\n".join(refinement_history) +
-                "\n\nSynthesize these into a final, concise
-                "Focus on improving accuracy and relevance for legal research."
+                "\n\nSynthesize these into a final, concise analysis report with improved accuracy and verifiable details."
             )
             meta_response = self.processor.process_query(meta_prompt)
             logger.info("Meta-refinement completed.")
@@ -180,8 +195,7 @@ class ResearchWorkflow:
             refinement_prompt = (
                 f"Domain: {domain}\n"
                 f"Refine this analysis (current difficulty level: {difficulty_level}):\n{current_analysis}\n\n"
-                "
-                "Then, improve the analysis with clear references to legal precedents and statutory language."
+                "Identify and correct any weaknesses or hallucinations in the analysis, providing verifiable details."
            )
            response = self.processor.process_query(refinement_prompt)
            logger.info("Refinement completed.")