"""HTTP service that scores prompts with a prompt-injection classifier.

Exposes a single POST endpoint, ``/classify/``, which runs the submitted
prompt through a Hugging Face text-classification pipeline and maps the
resulting label/score pair onto a coarse severity level.
"""

import logging
from typing import Any, Dict, List, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch

logger = logging.getLogger(__name__)

# Hugging Face model identifier used for both the tokenizer and the model.
MODEL_NAME = "ProtectAI/deberta-v3-base-prompt-injection"


class Guardrail:
    """Wrap a text-classification pipeline used to screen prompts."""

    def __init__(self, model_name: str = MODEL_NAME) -> None:
        """Load the tokenizer and model and build the classification pipeline.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the sequence-classification model.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Run on the GPU when one is available, otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            # Ask the tokenizer to truncate inputs to at most 512 tokens.
            truncation=True,
            max_length=512,
            device=device,
        )

    def guard(self, prompt: str) -> List[Dict[str, Any]]:
        """Classify ``prompt`` and return the pipeline's raw output.

        The caller in this module reads ``result[0]['label']`` and
        ``result[0]['score']`` from the returned value.
        """
        return self.classifier(prompt)

    def determine_level(self, label: str, score: float) -> Tuple[int, str]:
        """Translate a classifier label and score into a severity level.

        Args:
            label: Label reported by the classifier. ``"SAFE"`` is treated as
                harmless; every other label is graded by ``score``.
            score: Confidence reported by the classifier for ``label``.

        Returns:
            A ``(level, severity_label)`` tuple: ``(0, "safe")`` for a
            ``"SAFE"`` label, otherwise ``(4, "high")`` when ``score > 0.9``,
            ``(3, "medium")`` when ``score > 0.75``, ``(2, "low")`` when
            ``score > 0.5`` and ``(1, "very low")`` for anything lower.
        """
        if label == "SAFE":
            return 0, "safe"
        if score > 0.9:
            return 4, "high"
        if score > 0.75:
            return 3, "medium"
        if score > 0.5:
            return 2, "low"
        return 1, "very low"


class TextPrompt(BaseModel):
    """Request body for ``/classify/``."""

    # Text to classify.
    prompt: str


class ClassificationResult(BaseModel):
    """Response body for ``/classify/``."""

    # Label reported by the classifier.
    label: str
    # Score reported by the classifier for that label.
    score: float
    # Numeric severity produced by ``Guardrail.determine_level``.
    level: int
    # Human-readable name for ``level``.
    severity_label: str


app = FastAPI()
# The model is loaded once at import time and shared by every request.
guardrail = Guardrail()


@app.post("/classify/", response_model=ClassificationResult)
def classify_text(text_prompt: TextPrompt):
    """Classify the submitted prompt and report its severity.

    Args:
        text_prompt: Request body carrying the prompt to classify.

    Returns:
        A dict with ``label``, ``score``, ``level`` and ``severity_label``.

    Raises:
        HTTPException: With status 500 when classification fails for any
            reason; ``detail`` carries the text of the underlying error.
    """
    try:
        top_prediction = guardrail.guard(text_prompt.prompt)[0]
        label = top_prediction["label"]
        score = top_prediction["score"]
        level, severity_label = guardrail.determine_level(label, score)
    except Exception as e:
        # Record the traceback server-side before converting to an HTTP error,
        # otherwise the root cause is lost once the 500 is returned.
        logger.exception("Prompt classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "label": label,
        "score": score,
        "level": level,
        "severity_label": severity_label,
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)