Update src/gradio_server.py
Browse files- src/gradio_server.py +24 -17
src/gradio_server.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
-
from
|
| 5 |
-
|
| 6 |
-
from fastapi import FastAPI, Form, UploadFile
|
| 7 |
from pydantic import BaseModel
|
| 8 |
from hamilton import driver
|
| 9 |
from pandas import DataFrame
|
|
|
|
| 10 |
|
| 11 |
# Add the src directory to the Python path
|
| 12 |
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
|
@@ -19,13 +18,23 @@ from decouple import config
|
|
| 19 |
|
| 20 |
app = FastAPI()
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
dr = (
|
| 31 |
driver.Builder()
|
|
@@ -49,13 +58,11 @@ class PolicyEnforcementRequest(BaseModel):
|
|
| 49 |
violation_context: dict
|
| 50 |
|
| 51 |
class RadicalizationDetectionResponse(BaseModel):
|
| 52 |
-
"""Response to the /detect endpoint"""
|
| 53 |
values: dict
|
| 54 |
|
| 55 |
class PolicyEnforcementResponse(BaseModel):
|
| 56 |
-
"""Response to the /generate_policy_enforcement endpoint"""
|
| 57 |
values: dict
|
| 58 |
-
|
| 59 |
@app.post("/detect_radicalization")
|
| 60 |
def detect_radicalization(
|
| 61 |
request: RadicalizationDetectionRequest
|
|
@@ -65,8 +72,6 @@ def detect_radicalization(
|
|
| 65 |
final_vars=["detect_glorification"],
|
| 66 |
inputs={"project_root": ".", "user_input": request.user_text}
|
| 67 |
)
|
| 68 |
-
print(results)
|
| 69 |
-
print(type(results))
|
| 70 |
if isinstance(results, DataFrame):
|
| 71 |
results = results.to_dict(orient="dict")
|
| 72 |
return RadicalizationDetectionResponse(values=results)
|
|
@@ -79,8 +84,6 @@ def generate_policy_enforcement(
|
|
| 79 |
final_vars=["get_enforcement_decision"],
|
| 80 |
inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
|
| 81 |
)
|
| 82 |
-
print(results)
|
| 83 |
-
print(type(results))
|
| 84 |
if isinstance(results, DataFrame):
|
| 85 |
results = results.to_dict(orient="dict")
|
| 86 |
return PolicyEnforcementResponse(values=results)
|
|
@@ -118,3 +121,7 @@ iface2 = gr.Interface(
|
|
| 118 |
|
| 119 |
# Combine the interfaces in a Tabbed interface
|
| 120 |
iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
+
from fastapi import FastAPI
|
|
|
|
|
|
|
| 5 |
from pydantic import BaseModel
|
| 6 |
from hamilton import driver
|
| 7 |
from pandas import DataFrame
|
| 8 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
|
| 10 |
# Add the src directory to the Python path
|
| 11 |
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
|
|
|
| 18 |
|
| 19 |
app = FastAPI()
|
| 20 |
|
| 21 |
+
# Enable CORS for Gradio to communicate with FastAPI
|
| 22 |
+
app.add_middleware(
|
| 23 |
+
CORSMiddleware,
|
| 24 |
+
allow_origins=["*"],
|
| 25 |
+
allow_credentials=True,
|
| 26 |
+
allow_methods=["*"],
|
| 27 |
+
allow_headers=["*"],
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
config = {
|
| 31 |
+
"loader": "pd",
|
| 32 |
+
"embedding_service": "openai",
|
| 33 |
+
"api_key": config("OPENAI_API_KEY"),
|
| 34 |
+
"model_name": "text-embedding-ada-002",
|
| 35 |
+
"mistral_public_url": config("MISTRAL_PUBLIC_URL"),
|
| 36 |
+
"ner_public_url": config("NER_PUBLIC_URL"),
|
| 37 |
+
}
|
| 38 |
|
| 39 |
dr = (
|
| 40 |
driver.Builder()
|
|
|
|
| 58 |
violation_context: dict
|
| 59 |
|
| 60 |
class RadicalizationDetectionResponse(BaseModel):
|
|
|
|
| 61 |
values: dict
|
| 62 |
|
| 63 |
class PolicyEnforcementResponse(BaseModel):
|
|
|
|
| 64 |
values: dict
|
| 65 |
+
|
| 66 |
@app.post("/detect_radicalization")
|
| 67 |
def detect_radicalization(
|
| 68 |
request: RadicalizationDetectionRequest
|
|
|
|
| 72 |
final_vars=["detect_glorification"],
|
| 73 |
inputs={"project_root": ".", "user_input": request.user_text}
|
| 74 |
)
|
|
|
|
|
|
|
| 75 |
if isinstance(results, DataFrame):
|
| 76 |
results = results.to_dict(orient="dict")
|
| 77 |
return RadicalizationDetectionResponse(values=results)
|
|
|
|
| 84 |
final_vars=["get_enforcement_decision"],
|
| 85 |
inputs={"project_root": ".", "user_input": request.user_text, "violation_context": request.violation_context}
|
| 86 |
)
|
|
|
|
|
|
|
| 87 |
if isinstance(results, DataFrame):
|
| 88 |
results = results.to_dict(orient="dict")
|
| 89 |
return PolicyEnforcementResponse(values=results)
|
|
|
|
| 121 |
|
| 122 |
# Combine the interfaces in a Tabbed interface
|
| 123 |
iface_combined = gr.TabbedInterface([iface, iface2], ["Detect Radicalization", "Policy Enforcement"])
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
# Launch Gradio interface (no need to launch Uvicorn separately)
|
| 127 |
+
iface_combined.launch(server_name="0.0.0.0", server_port=7860)
|