File size: 4,283 Bytes
280ee80 cab69db 280ee80 cab69db 280ee80 55c951d 280ee80 cab69db 280ee80 cab69db 280ee80 cab69db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import gradio as gr
import os
import sys
from fastapi import FastAPI
from pydantic import BaseModel
from hamilton import driver
from pandas import DataFrame
from fastapi.middleware.cors import CORSMiddleware
# Add the src directory to the Python path
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from src.data_module import data_pipeline, embedding_pipeline, vectorstore
from src.classification_module import semantic_similarity, dio_support_detector
from src.enforcement_module import policy_enforcement_decider
from decouple import config
app = FastAPI()
# CORS policy for the FastAPI routes below.
# NOTE(review): the Gradio UI in this file calls the route functions directly
# in-process (not over HTTP), so this policy only affects browser clients that
# reach `app` when it is served by an ASGI server. Wildcard origins combined
# with allow_credentials=True may let any site make credentialed requests —
# confirm that is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Configuration shared by both Hamilton drivers. The API key and service URLs
# are read through python-decouple's `config` helper.
# NOTE(review): this assignment rebinds the module-level name `config`,
# shadowing decouple's `config` function; any later `config("NAME")` call in
# this module would fail. Consider renaming (e.g. `pipeline_config`).
config = dict(
    loader="pd",
    embedding_service="openai",
    api_key=config("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002",
    mistral_public_url=config("MISTRAL_PUBLIC_URL"),
    ner_public_url=config("NER_PUBLIC_URL"),
)
def _build_driver(*modules):
    """Return a Hamilton driver over ``modules`` using the module-level ``config``."""
    return driver.Builder().with_config(config).with_modules(*modules).build()


# Driver over the data, embedding, vector-store and classification modules;
# used by the /detect_radicalization route.
dr = _build_driver(
    data_pipeline,
    embedding_pipeline,
    vectorstore,
    semantic_similarity,
    dio_support_detector,
)
# Separate driver containing only the policy-enforcement module; used by the
# /generate_policy_enforcement route.
dr_enforcement = _build_driver(policy_enforcement_decider)
class RadicalizationDetectionRequest(BaseModel):
    """Request body for ``POST /detect_radicalization``."""

    user_text: str  # text passed to the detection pipeline as `user_input`
class PolicyEnforcementRequest(BaseModel):
    """Request body for ``POST /generate_policy_enforcement``."""

    user_text: str  # text passed to the enforcement pipeline as `user_input`
    violation_context: dict  # free-form context forwarded unchanged to the pipeline
class RadicalizationDetectionResponse(BaseModel):
    """Response body for ``POST /detect_radicalization``."""

    values: dict  # pipeline outputs (a DataFrame result is converted to a dict first)
class PolicyEnforcementResponse(BaseModel):
    """Response body for ``POST /generate_policy_enforcement``."""

    values: dict  # pipeline outputs (a DataFrame result is converted to a dict first)
@app.post("/detect_radicalization")
def detect_radicalization(
    request: RadicalizationDetectionRequest
) -> RadicalizationDetectionResponse:
    """Run the detection driver on ``request.user_text``.

    Asks ``dr`` for the ``detect_glorification`` output. If the driver returns
    a pandas DataFrame, it is flattened with ``to_dict(orient="dict")`` (a
    ``{column: {index: value}}`` mapping) so it fits the response model.
    Otherwise the result is passed through unchanged.

    NOTE(review): ``project_root`` is the literal ``"."``, so it presumably
    resolves against the process's working directory rather than this file's
    directory — confirm the server is always started from the project root.
    """
    pipeline_inputs = {"project_root": ".", "user_input": request.user_text}
    outputs = dr.execute(final_vars=["detect_glorification"], inputs=pipeline_inputs)
    payload = (
        outputs.to_dict(orient="dict") if isinstance(outputs, DataFrame) else outputs
    )
    return RadicalizationDetectionResponse(values=payload)
@app.post("/generate_policy_enforcement")
def generate_policy_enforcement(
    request: PolicyEnforcementRequest
) -> PolicyEnforcementResponse:
    """Run the enforcement driver on the user text and violation context.

    Asks ``dr_enforcement`` for the ``get_enforcement_decision`` output. A
    DataFrame result is flattened with ``to_dict(orient="dict")``. Any other
    result is returned as-is inside the response model.
    """
    pipeline_inputs = {
        "project_root": ".",
        "user_input": request.user_text,
        "violation_context": request.violation_context,
    }
    outputs = dr_enforcement.execute(
        final_vars=["get_enforcement_decision"], inputs=pipeline_inputs
    )
    payload = (
        outputs.to_dict(orient="dict") if isinstance(outputs, DataFrame) else outputs
    )
    return PolicyEnforcementResponse(values=payload)
# Gradio Interface Functions
def gradio_detect_radicalization(user_text: str):
    """Gradio adapter for the detection route.

    Wraps ``user_text`` in a request model, calls the route function
    in-process (no HTTP round-trip) and returns its ``values`` dict.
    """
    return detect_radicalization(
        RadicalizationDetectionRequest(user_text=user_text)
    ).values
def gradio_generate_policy_enforcement(user_text: str, violation_context: str):
    """Gradio adapter for the ``generate_policy_enforcement`` route function.

    Parses the violation-context textbox into a dict, builds a
    ``PolicyEnforcementRequest`` and calls the route function in-process
    (not over HTTP).

    Args:
        user_text: Text from the first textbox.
        violation_context: Text from the second textbox. Expected to be a JSON
            object such as ``{"key": "value"}``. A plain Python dict literal
            (single-quoted strings, ``True``/``None``) is also accepted, so
            literal input that worked with the earlier ``eval``-based parsing
            keeps working.

    Returns:
        The ``values`` dict of the resulting ``PolicyEnforcementResponse``.

    Raises:
        ValueError: If ``violation_context`` is not a JSON object or a Python
            dict literal.
    """
    import ast
    import json

    # SECURITY: this string comes straight from the web form (the UI is
    # launched on 0.0.0.0), so it is parsed strictly as data. Calling `eval`
    # here would let anyone who can reach the form run arbitrary Python on
    # the server.
    try:
        context_dict = json.loads(violation_context)
    except json.JSONDecodeError:
        try:
            # literal_eval only accepts literals (no calls, names or
            # attribute access), so it cannot execute code.
            context_dict = ast.literal_eval(violation_context)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise ValueError(
                'violation_context must be a JSON object, e.g. {"key": "value"}'
            ) from err
    # Reject lists, strings, numbers, etc. up front with a clear message;
    # the request model requires a dict anyway.
    if not isinstance(context_dict, dict):
        raise ValueError(
            'violation_context must be a JSON object, e.g. {"key": "value"}'
        )
    request = PolicyEnforcementRequest(user_text=user_text, violation_context=context_dict)
    response = generate_policy_enforcement(request)
    return response.values
# Tab 1: a single textbox whose contents go to the detection pipeline; the
# result is rendered as JSON.
iface = gr.Interface(
    title="Radicalization Detection",
    description="Enter text to detect glorification or radicalization.",
    fn=gradio_detect_radicalization,
    inputs="text",
    outputs="json",
)
# Tab 2: two textboxes (user text, then the violation context, which the
# adapter function converts to a dict); the result is rendered as JSON.
iface2 = gr.Interface(
    title="Policy Enforcement Decision",
    description="Enter user text and context to generate a policy enforcement decision.",
    fn=gradio_generate_policy_enforcement,
    inputs=["text", "text"],
    outputs="json",
)
# Present both interfaces as tabs of one Gradio app.
iface_combined = gr.TabbedInterface(
    [iface, iface2],
    ["Detect Radicalization", "Policy Enforcement"],
)

if __name__ == "__main__":
    # Serve the Gradio UI on all interfaces, port 7860. Only the Gradio app is
    # served here: the FastAPI `app` routes are not mounted by this call and
    # are reachable only if `app` is run by an ASGI server separately.
    iface_combined.launch(server_name="0.0.0.0", server_port=7860)
|