brichett commited on
Commit
280ee80
·
verified ·
1 Parent(s): 2dc3529

Update src/gradio_server.py

Browse files
Files changed (1) hide show
  1. src/gradio_server.py +137 -132
src/gradio_server.py CHANGED
@@ -1,133 +1,138 @@
1
- import gradio as gr
2
- from typing import Annotated
3
-
4
- from fastapi import FastAPI, Form, UploadFile
5
- from pydantic import BaseModel
6
- from hamilton import driver
7
- from pandas import DataFrame
8
-
9
- from data_module import data_pipeline, embedding_pipeline, vectorstore
10
- from classification_module import semantic_similarity, dio_support_detector
11
- from enforcement_module import policy_enforcement_decider
12
-
13
- from decouple import config
14
-
15
- app = FastAPI()
16
-
17
- config = {"loader": "pd",
18
- "embedding_service": "openai",
19
- "api_key": config("OPENAI_API_KEY"),
20
- "model_name": "text-embedding-ada-002",
21
- "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
22
- "ner_public_url": config("NER_PUBLIC_URL")
23
- } # or "pd"
24
-
25
- dr = (
26
- driver.Builder()
27
- .with_config(config)
28
- .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
29
- .build()
30
- )
31
-
32
- dr_enforcement = (
33
- driver.Builder()
34
- .with_config(config)
35
- .with_modules(policy_enforcement_decider)
36
- .build()
37
- )
38
-
39
- class RadicalizationDetectionRequest(BaseModel):
40
- user_text: str
41
-
42
- class PolicyEnforcementRequest(BaseModel):
43
- user_text: str
44
- violation_context: dict
45
-
46
- class RadicalizationDetectionResponse(BaseModel):
47
- """Response to the /detect endpoint"""
48
- values: dict
49
-
50
- class PolicyEnforcementResponse(BaseModel):
51
- """Response to the /generate_policy_enforcement endpoint"""
52
- values: dict
53
-
54
- @app.post("/detect_radicalization")
55
- def detect_radicalization(
56
- request: RadicalizationDetectionRequest
57
- ) -> RadicalizationDetectionResponse:
58
-
59
- results = dr.execute(
60
- final_vars=["detect_glorification"],
61
- inputs={"project_root": ".", "user_input": request.user_text}
62
- )
63
- print(results)
64
- print(type(results))
65
- if isinstance(results, DataFrame):
66
- results = results.to_dict(orient="dict")
67
- return RadicalizationDetectionResponse(values=results)
68
-
69
- @app.post("/generate_policy_enforcement")
70
- def generate_policy_enforcement(
71
- request: PolicyEnforcementRequest
72
- ) -> PolicyEnforcementResponse:
73
- results = dr_enforcement.execute(
74
- final_vars=["get_enforcement_decision"],
75
- inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
76
- )
77
- print(results)
78
- print(type(results))
79
- if isinstance(results, DataFrame):
80
- results = results.to_dict(orient="dict")
81
- return PolicyEnforcementResponse(values=results)
82
-
83
- # Gradio Interface Functions
84
- def gradio_detect_radicalization(user_text: str):
85
- request = RadicalizationDetectionRequest(user_text=user_text)
86
- response = detect_radicalization(request)
87
- return response.values
88
-
89
- def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
90
- # violation_context needs to be provided in a valid JSON format
91
- context_dict = eval(violation_context) # Replace eval with json.loads for safer parsing if it's JSON
92
- request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
93
- response = generate_policy_enforcement(request)
94
- return response.values
95
-
96
- # Define the Gradio interface
97
- iface = gr.Interface(
98
- fn=gradio_detect_radicalization, # Function to detect radicalization
99
- inputs="text", # Single text input
100
- outputs="json", # Return JSON output
101
- title="Radicalization Detection",
102
- description="Enter text to detect glorification or radicalization."
103
- )
104
-
105
- # Second interface for policy enforcement
106
- iface2 = gr.Interface(
107
- fn=gradio_generate_policy_enforcement, # Function to generate policy enforcement
108
- inputs=["text", "text"], # Two text inputs, one for user text, one for violation context
109
- outputs="json", # Return JSON output
110
- title="Policy Enforcement Decision",
111
- description="Enter user text and context to generate a policy enforcement decision."
112
- )
113
-
114
- # Combine the interfaces in a Tabbed interface
115
- iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])
116
-
117
- # Start the Gradio interface
118
- iface_combined.launch(server_name="0.0.0.0", server_port=7861)
119
-
120
-
121
- if __name__ == "__main__":
122
- import uvicorn
123
- from threading import Thread
124
-
125
- # Run FastAPI server in a separate thread
126
- def run_fastapi():
127
- uvicorn.run(app, host="0.0.0.0", port=8000)
128
-
129
- fastapi_thread = Thread(target=run_fastapi)
130
- fastapi_thread.start()
131
-
132
- # Launch Gradio Interface
 
 
 
 
 
133
  iface_combined.launch()
 
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ from typing import Annotated
5
+
6
+ from fastapi import FastAPI, Form, UploadFile
7
+ from pydantic import BaseModel
8
+ from hamilton import driver
9
+ from pandas import DataFrame
10
+
11
+ # Add the src directory to the Python path
12
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
13
+
14
+ from data_module import data_pipeline, embedding_pipeline, vectorstore
15
+ from classification_module import semantic_similarity, dio_support_detector
16
+ from enforcement_module import policy_enforcement_decider
17
+
18
+ from decouple import config
19
+
20
+ app = FastAPI()
21
+
22
+ config = {"loader": "pd",
23
+ "embedding_service": "openai",
24
+ "api_key": config("OPENAI_API_KEY"),
25
+ "model_name": "text-embedding-ada-002",
26
+ "mistral_public_url": config("MISTRAL_PUBLIC_URL"),
27
+ "ner_public_url": config("NER_PUBLIC_URL")
28
+ } # or "pd"
29
+
30
+ dr = (
31
+ driver.Builder()
32
+ .with_config(config)
33
+ .with_modules(data_pipeline, embedding_pipeline, vectorstore, semantic_similarity, dio_support_detector)
34
+ .build()
35
+ )
36
+
37
+ dr_enforcement = (
38
+ driver.Builder()
39
+ .with_config(config)
40
+ .with_modules(policy_enforcement_decider)
41
+ .build()
42
+ )
43
+
44
+ class RadicalizationDetectionRequest(BaseModel):
45
+ user_text: str
46
+
47
+ class PolicyEnforcementRequest(BaseModel):
48
+ user_text: str
49
+ violation_context: dict
50
+
51
+ class RadicalizationDetectionResponse(BaseModel):
52
+ """Response to the /detect endpoint"""
53
+ values: dict
54
+
55
+ class PolicyEnforcementResponse(BaseModel):
56
+ """Response to the /generate_policy_enforcement endpoint"""
57
+ values: dict
58
+
59
+ @app.post("/detect_radicalization")
60
+ def detect_radicalization(
61
+ request: RadicalizationDetectionRequest
62
+ ) -> RadicalizationDetectionResponse:
63
+
64
+ results = dr.execute(
65
+ final_vars=["detect_glorification"],
66
+ inputs={"project_root": ".", "user_input": request.user_text}
67
+ )
68
+ print(results)
69
+ print(type(results))
70
+ if isinstance(results, DataFrame):
71
+ results = results.to_dict(orient="dict")
72
+ return RadicalizationDetectionResponse(values=results)
73
+
74
+ @app.post("/generate_policy_enforcement")
75
+ def generate_policy_enforcement(
76
+ request: PolicyEnforcementRequest
77
+ ) -> PolicyEnforcementResponse:
78
+ results = dr_enforcement.execute(
79
+ final_vars=["get_enforcement_decision"],
80
+ inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
81
+ )
82
+ print(results)
83
+ print(type(results))
84
+ if isinstance(results, DataFrame):
85
+ results = results.to_dict(orient="dict")
86
+ return PolicyEnforcementResponse(values=results)
87
+
88
+ # Gradio Interface Functions
89
+ def gradio_detect_radicalization(user_text: str):
90
+ request = RadicalizationDetectionRequest(user_text=user_text)
91
+ response = detect_radicalization(request)
92
+ return response.values
93
+
94
+ def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
95
+ # violation_context needs to be provided in a valid JSON format
96
+ context_dict = eval(violation_context) # Replace eval with json.loads for safer parsing if it's JSON
97
+ request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
98
+ response = generate_policy_enforcement(request)
99
+ return response.values
100
+
101
+ # Define the Gradio interface
102
+ iface = gr.Interface(
103
+ fn=gradio_detect_radicalization, # Function to detect radicalization
104
+ inputs="text", # Single text input
105
+ outputs="json", # Return JSON output
106
+ title="Radicalization Detection",
107
+ description="Enter text to detect glorification or radicalization."
108
+ )
109
+
110
+ # Second interface for policy enforcement
111
+ iface2 = gr.Interface(
112
+ fn=gradio_generate_policy_enforcement, # Function to generate policy enforcement
113
+ inputs=["text", "text"], # Two text inputs, one for user text, one for violation context
114
+ outputs="json", # Return JSON output
115
+ title="Policy Enforcement Decision",
116
+ description="Enter user text and context to generate a policy enforcement decision."
117
+ )
118
+
119
+ # Combine the interfaces in a Tabbed interface
120
+ iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])
121
+
122
+ # Start the Gradio interface
123
+ iface_combined.launch(server_name="0.0.0.0", server_port=7861)
124
+
125
+
126
+ if __name__ == "__main__":
127
+ import uvicorn
128
+ from threading import Thread
129
+
130
+ # Run FastAPI server in a separate thread
131
+ def run_fastapi():
132
+ uvicorn.run(app, host="0.0.0.0", port=8000)
133
+
134
+ fastapi_thread = Thread(target=run_fastapi)
135
+ fastapi_thread.start()
136
+
137
+ # Launch Gradio Interface
138
  iface_combined.launch()