ppsingh committed
Commit 78519a5 · verified · 1 Parent(s): 939ee59

Delete app/main.py

Files changed (1)
  1. app/main.py +0 -180
app/main.py DELETED
@@ -1,180 +0,0 @@
- import gradio as gr
- from gradio_client import Client
- from langgraph.graph import StateGraph, START, END
- from typing import TypedDict
- import io
- from PIL import Image
- import os
-
- # Open question: should we pass all params from the orchestrator to the nodes instead of setting them in each module?
- HF_TOKEN = os.environ.get("HF_TOKEN")
-
- # Define the state schema
- class GraphState(TypedDict):
-     query: str
-     context: str
-     result: str
-     # Orchestrator-level filter parameters (addresses the open question above)
-     reports_filter: str
-     sources_filter: str
-     subtype_filter: str
-     year_filter: str
-
- # Node 1: retriever
- def retrieve_node(state: GraphState) -> GraphState:
-     client = Client("giz/chatfed_retriever", hf_token=HF_TOKEN)  # HF Space repo name
-     context = client.predict(
-         query=state["query"],
-         reports_filter=state.get("reports_filter", ""),
-         sources_filter=state.get("sources_filter", ""),
-         subtype_filter=state.get("subtype_filter", ""),
-         year_filter=state.get("year_filter", ""),
-         api_name="/retrieve"
-     )
-     return {"context": context}
-
- # Node 2: generator
- def generate_node(state: GraphState) -> GraphState:
-     client = Client("giz/chatfed_generator", hf_token=HF_TOKEN)
-     result = client.predict(
-         query=state["query"],
-         context=state["context"],
-         api_name="/generate"
-     )
-     return {"result": result}
-
- # Build the graph
- workflow = StateGraph(GraphState)
-
- # Add nodes
- workflow.add_node("retrieve", retrieve_node)
- workflow.add_node("generate", generate_node)
-
- # Add edges
- workflow.add_edge(START, "retrieve")
- workflow.add_edge("retrieve", "generate")
- workflow.add_edge("generate", END)
-
- # Compile the graph
- graph = workflow.compile()
-
- # Single tool for processing queries
- def process_query(
-     query: str,
-     reports_filter: str = "",
-     sources_filter: str = "",
-     subtype_filter: str = "",
-     year_filter: str = ""
- ) -> str:
-     """
-     Execute the ChatFed orchestration pipeline to process a user query.
-
-     This function orchestrates a two-step workflow:
-     1. Retrieve relevant context using the ChatFed retriever service with optional filters
-     2. Generate a response using the ChatFed generator service with the retrieved context
-
-     Args:
-         query (str): The user's input query/question to be processed
-         reports_filter (str, optional): Filter for specific report types. Defaults to "".
-         sources_filter (str, optional): Filter for specific data sources. Defaults to "".
-         subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
-         year_filter (str, optional): Filter for specific years. Defaults to "".
-
-     Returns:
-         str: The generated response from the ChatFed generator service
-     """
-     initial_state = {
-         "query": query,
-         "context": "",
-         "result": "",
-         "reports_filter": reports_filter or "",
-         "sources_filter": sources_filter or "",
-         "subtype_filter": subtype_filter or "",
-         "year_filter": year_filter or ""
-     }
-     final_state = graph.invoke(initial_state)
-     return final_state["result"]
-
- # Simple testing interface
- ui = gr.Interface(
-     fn=process_query,
-     inputs=gr.Textbox(lines=2, placeholder="Enter query here"),
-     outputs="text",
-     flagging_mode="never"
- )
-
- # Generate the graph visualization for display
- def get_graph_visualization():
-     """Generate and return the LangGraph workflow visualization as a PIL Image."""
-     # Render the graph as PNG bytes
-     graph_png_bytes = graph.get_graph().draw_mermaid_png()
-     # Convert the bytes to a PIL Image for Gradio display
-     graph_image = Image.open(io.BytesIO(graph_png_bytes))
-     return graph_image
-
- # Guidance UI for ChatUI; can be removed later. A front end may not even be necessary, but it is handy for showing the graph.
- with gr.Blocks(title="ChatFed Orchestrator") as demo:
-     gr.Markdown("# ChatFed Orchestrator")
-     gr.Markdown("This LangGraph server exposes MCP endpoints for the ChatUI module to call (which triggers the graph).")
-
-     with gr.Row():
-         # Left column: graph visualization
-         with gr.Column(scale=1):
-             gr.Markdown("**Workflow Visualization**")
-             graph_display = gr.Image(
-                 value=get_graph_visualization(),
-                 label="LangGraph Workflow",
-                 interactive=False,
-                 height=300
-             )
-
-             # Refresh button for the graph
-             refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
-             refresh_graph_btn.click(
-                 fn=get_graph_visualization,
-                 outputs=graph_display
-             )
-
-         # Right column: interface and documentation
-         with gr.Column(scale=2):
-             gr.Markdown("**Available MCP Tools:**")
-
-             with gr.Accordion("MCP Endpoint Information", open=True):
-                 gr.Markdown("""
-                 **MCP Server Endpoint:** https://giz-chatfed-orchestrator.hf.space/gradio_api/mcp/sse
-
-                 **For ChatUI Integration:**
-                 ```python
-                 from gradio_client import Client
-
-                 # Connect to the orchestrator
-                 orchestrator_client = Client("https://giz-chatfed-orchestrator.hf.space")
-
-                 # Basic usage (no filters)
-                 response = orchestrator_client.predict(
-                     query="query",
-                     api_name="/process_query"
-                 )
-
-                 # Advanced usage with any combination of filters
-                 response = orchestrator_client.predict(
-                     query="query",
-                     reports_filter="annual_reports",
-                     sources_filter="internal",
-                     year_filter="2024",
-                     api_name="/process_query"
-                 )
-                 ```
-                 """)
-
-             with gr.Accordion("Quick Testing Interface", open=True):
-                 ui.render()
-
- if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         mcp_server=True,
-         show_error=True
-     )
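
For reference, the deleted `process_query` tool reduced to a single `graph.invoke` call over the `GraphState` schema above. A minimal sketch of the equivalent direct invocation, bypassing Gradio; the query and filter values here are illustrative, and `HF_TOKEN` must be set so the retriever and generator Spaces are reachable:

```python
# Minimal sketch: invoke the compiled graph directly (illustrative values).
initial_state = {
    "query": "What do the annual reports say about 2024 funding?",  # illustrative
    "context": "",
    "result": "",
    "reports_filter": "annual_reports",  # illustrative filter values
    "sources_filter": "",
    "subtype_filter": "",
    "year_filter": "2024",
}
final_state = graph.invoke(initial_state)
print(final_state["result"])
```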