Spaces:
Runtime error
Runtime error
File size: 8,819 Bytes
06cb2a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
"""
Tool for querying Neo4j about recent team news stories.
"""
import os
import sys
import re # Import regex for cleaning Cypher query
from dotenv import load_dotenv
from langchain_core.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI # No longer needed directly here
# from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain # Removed import
# Adjust path to import graph object and LLM from the parent directory
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if parent_dir not in sys.path:
sys.path.append(parent_dir)
try:
from gradio_graph import graph # Import the configured graph instance
from gradio_llm import llm # Import the configured LLM instance
except ImportError as e:
print(f"Error importing graph or llm: {e}")
print("Please ensure gradio_graph.py and gradio_llm.py exist and are configured correctly.")
sys.exit(1)
# Load environment variables if needed (though graph/llm should be configured)
load_dotenv()
# Define the prompt for translating NL query to Cypher for Team Story
CYPHER_TEAM_STORY_GENERATION_TEMPLATE = """
Task: Generate Cypher query to query a graph database for team news stories.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Based on the schema, generate a Cypher query that retrieves relevant :Team_Story nodes based on the user's question.
* Focus on searching the `summary` and `topic` properties of the :Team_Story node (aliased as `s`).
* Always `MATCH (s:Team_Story)` and potentially relate it `MATCH (s)-[:STORY_ABOUT]->(t:Team {{name: 'San Francisco 49ers'}})` if the query implies 49ers context.
* Use `toLower()` for case-insensitive matching on properties like `topic` or keywords in `summary`.
* Return relevant properties like `s.summary`, `s.link_to_article`, `s.topic`.
* Limit the results to a reasonable number (e.g., LIMIT 10).
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher query.
Do not include any text except the generated Cypher query. Output ONLY the Cypher query.
The question is:
{query}
Cypher Query:
"""
CYPHER_TEAM_STORY_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "query"], template=CYPHER_TEAM_STORY_GENERATION_TEMPLATE
)
# Placeholder for structured data caching
LAST_TEAM_STORY_DATA = []
def get_last_team_story_data():
"""Returns the structured data from the last team story query."""
return LAST_TEAM_STORY_DATA
def clean_cypher_query(query_text):
""" Basic cleaning of LLM-generated Cypher query. """
# Remove ```cypher ... ``` markdown fences if present
match = re.search(r"```(?:cypher)?\s*(.*?)\s*```", query_text, re.DOTALL | re.IGNORECASE)
if match:
query = match.group(1).strip()
else:
query = query_text.strip()
# Remove potential leading/trailing quotes if the LLM added them
query = query.strip('"\'')
return query
def team_story_qa(query: str) -> dict:
"""
Queries the Neo4j database for team news stories based on the user query.
Manually generates Cypher, executes it, and formats the results.
Args:
query: The natural language query from the user.
Returns:
A dictionary containing the 'output' text and structured 'team_story_data'.
"""
global LAST_TEAM_STORY_DATA
LAST_TEAM_STORY_DATA = [] # Clear previous results
structured_results = []
output_text = "Sorry, I encountered an error trying to find team news."
print(f"--- Running Team Story QA for query: {query} ---")
try:
# 1. Generate Cypher query using LLM
print("Generating Cypher query...")
cypher_generation_result = llm.invoke(
CYPHER_TEAM_STORY_GENERATION_PROMPT.format(
schema=graph.schema,
query=query
)
)
generated_cypher = cypher_generation_result.content # Extract text content
cleaned_cypher = clean_cypher_query(generated_cypher)
print(f"Generated Cypher (cleaned):\n{cleaned_cypher}")
# 2. Execute the generated Cypher query
if cleaned_cypher:
print("Executing Cypher query...")
# Assuming the generated query doesn't need parameters for now
# If parameters are needed, the prompt/parsing would need adjustment
neo4j_results = graph.query(cleaned_cypher)
print(f"Neo4j Results: {neo4j_results}")
# 3. Process results and extract structured data
if neo4j_results:
for record in neo4j_results:
# Check if record is a dictionary (expected from graph.query)
if isinstance(record, dict):
story_data = {
'summary': record.get('s.summary', 'Summary not available'),
'link_to_article': record.get('s.link_to_article', '#'),
'topic': record.get('s.topic', 'Topic not available')
}
# Basic check if data seems valid
if story_data['link_to_article'] != '#':
structured_results.append(story_data)
else:
print(f"Warning: Skipping unexpected record format: {record}")
else:
print("Warning: No Cypher query was generated.")
output_text = "I couldn't formulate a query to find the specific news you asked for."
# --- Limit the number of results stored and returned --- #
MAX_STORIES_TO_SHOW = 3
LAST_TEAM_STORY_DATA = structured_results[:MAX_STORIES_TO_SHOW]
# --- End limiting --- #
# 4. Format the text output based on the limited structured results
if not LAST_TEAM_STORY_DATA: # Check the potentially limited list now
# Keep default error or no-query message unless results were empty after valid query
if cleaned_cypher and not neo4j_results:
output_text = "I found no specific news articles matching your query in the database."
elif not cleaned_cypher:
pass # Keep the "couldn't formulate" message
else: # Error occurred during query execution or processing
pass # Keep the default error message
else:
# Base the text output on the *limited* list
output_text = "Here's what I found related to your query:\n\n"
for i, story in enumerate(LAST_TEAM_STORY_DATA): # Iterate over the limited list
output_text += f"{i+1}. {story['summary']}\n[Link: {story['link_to_article']}]\n\n"
# Optionally, mention if more were found originally (before limiting)
if len(structured_results) > MAX_STORIES_TO_SHOW:
output_text += f"... displaying the top {MAX_STORIES_TO_SHOW} of {len(structured_results)} relevant articles found."
except Exception as e:
import traceback
print(f"Error during team_story_qa: {e}")
print(traceback.format_exc()) # Print full traceback for debugging
output_text = "Sorry, I encountered an unexpected error trying to find team news."
LAST_TEAM_STORY_DATA = [] # Ensure cache is clear on error
print(f"--- Team Story QA output: {output_text} ---")
print(f"--- Team Story QA structured data: {LAST_TEAM_STORY_DATA} ---")
return {"output": output_text, "team_story_data": LAST_TEAM_STORY_DATA}
# Example usage (for testing)
if __name__ == '__main__':
# Ensure graph and llm are available for standalone testing if needed
print("Testing team_story_qa...")
test_query = "What is the latest news about the 49ers draft?"
print(f"\nTesting with query: {test_query}")
response = team_story_qa(test_query)
print("\nResponse Text:")
print(response.get("output"))
# print("\nStructured Data:")
# print(response.get("team_story_data"))
test_query_2 = "Any updates on the roster?"
print(f"\nTesting with query: {test_query_2}")
response_2 = team_story_qa(test_query_2)
print("\nResponse Text:")
print(response_2.get("output"))
# print("\nStructured Data:")
# print(response_2.get("team_story_data"))
test_query_3 = "Tell me about non-existent news"
print(f"\nTesting with query: {test_query_3}")
response_3 = team_story_qa(test_query_3)
print("\nResponse Text:")
print(response_3.get("output"))
# print("\nStructured Data:")
# print(response_3.get("team_story_data")) |