Spaces:
Running
Running
import os
from typing import Optional

from fastapi import FastAPI, HTTPException
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_groq import ChatGroq
from pydantic import BaseModel
from transformers import pipeline
# ASGI application object; this is the instance passed to uvicorn when the
# file is executed directly.
app = FastAPI()
# Request model for a search call: the question plus optional context.
class SearchQuery(BaseModel):
    """Request body carrying the user's question and optional extra context.

    ``context`` is declared ``Optional[str]`` so that a request may either
    omit the field or send an explicit ``null``. Annotating it as plain
    ``str`` with a ``None`` default rejects an explicit ``null`` under
    Pydantic v2.
    """

    # Free-text question to classify and answer.
    query: str
    # Optional caller-supplied context that is interpolated into the prompt.
    context: Optional[str] = None
# Initialize LangChain with Groq.
#
# The API key is read from the GROQ_API_KEY environment variable instead of
# being embedded in the source file. A key that has been committed to a
# repository must be treated as compromised and rotated.
_groq_api_key = os.environ.get("GROQ_API_KEY")
if not _groq_api_key:
    # Fail at import time with an explicit message rather than on the first
    # request with an opaque authentication error.
    raise RuntimeError("GROQ_API_KEY environment variable is not set")

# NOTE(review): confirm that this model name is still served by Groq.
llm = ChatGroq(
    temperature=0.7,
    model_name="mixtral-8x7b-32768",
    groq_api_key=_groq_api_key,
)
# Define all prompt templates.
#
# Every topic-specific template has the same three-line shape:
#
#     Context: {context}
#     Query: {query}
#     <instruction>
#
# so only the instruction text is listed per topic and the template string is
# assembled by a helper. Insertion order is kept because the keys are later
# used as the classifier's candidate labels.
_PROMPT_INSTRUCTIONS = {
    # NOTE(review): "threat_1, threat_2, and threat_3" are sent to the model
    # literally; they look like unfilled placeholders -- confirm intent.
    "common_threats": "Provide a comprehensive overview of the most common cybersecurity threats faced by organizations on a daily basis, including details on threat_1, threat_2, and threat_3. Also, provide effective measures to mitigate these risks and protect critical data and systems.",
    "task_prioritization": "Provide a guide on how cybersecurity professionals prioritize their tasks and responsibilities, focusing on the most critical areas such as threat detection, response times, and resource allocation.",
    "network_traffic_tools": "List and describe the most effective tools and software used for monitoring network traffic. Include tools for real-time analysis, anomaly detection, and reporting.",
    "vulnerability_assessments": "Provide best practices for conducting vulnerability assessments and penetration tests, including recommended frequencies and methodologies to ensure systems are adequately tested for vulnerabilities.",
    "security_policies": "Explain the role cybersecurity professionals have in developing, updating, and enforcing security policies within an organization. Include considerations for evolving threats and compliance requirements.",
    "staying_updated": "Describe the methods and tools cybersecurity professionals use to stay up-to-date on the latest cybersecurity threats, trends, and vulnerabilities, including ongoing education and industry resources.",
    "immediate_incidents": "Identify and describe the types of cybersecurity incidents that require immediate attention, such as data breaches, malware attacks, and denial-of-service attacks. Provide guidance on how to respond to each incident type.",
    "collaboration_it_teams": "Discuss how cybersecurity professionals work with IT teams to ensure system security, focusing on areas such as patch management, incident response, and ongoing risk management.",
    "incident_investigation": "Outline the steps involved in investigating and resolving a security incident, including initial detection, containment, root cause analysis, and reporting.",
    "securing_remote_workers": "Provide strategies for securing remote workers and their devices, including the use of VPNs, multi-factor authentication, and endpoint protection measures.",
    "disaster_recovery": "Explain the responsibilities of cybersecurity professionals in ensuring that disaster recovery and business continuity plans are developed, tested, and maintained to address security challenges.",
    "user_access_management": "Describe the best practices for managing user access and privileges, including role-based access control (RBAC), least privilege principles, and audit trails for sensitive systems.",
    "cloud_security": "Provide a list of best practices for securing cloud-based infrastructure, including the use of strong authentication, data encryption, and continuous monitoring.",
    "security_kpis": "Discuss the key performance indicators (KPIs) used by cybersecurity professionals to measure the effectiveness of security programs, such as incident response times, patching cycles, and vulnerability remediation rates.",
    "employee_security_education": "Describe the methods used by cybersecurity professionals to educate employees on security best practices, including training programs, phishing simulations, and awareness campaigns.",
    "common_challenges": "Identify and discuss the common challenges that cybersecurity professionals face, including resource limitations, evolving threats, and the complexities of compliance.",
    "compliance_standards": "Provide an overview of how cybersecurity professionals ensure compliance with industry standards and regulations, such as GDPR, HIPAA, and PCI DSS, including regular audits and reporting.",
    "encryption_role": "Explain the role of encryption in protecting sensitive data, focusing on encryption methods, data-at-rest vs. data-in-transit, and how encryption helps mitigate the risks of data breaches.",
    "mobile_device_security": "Provide strategies for managing and securing mobile devices and applications, including mobile device management (MDM), app whitelisting, and secure communication methods.",
    "security_audits": "Outline the steps involved in conducting security audits and risk assessments, including identifying potential threats, assessing vulnerabilities, and recommending mitigation strategies.",
    "patch_management": "Describe the best practices for managing patch updates and ensuring software security, including patch management policies, vulnerability scanning, and prioritizing patches based on risk.",
    "wireless_iot_security": "Provide a comprehensive guide on securing wireless networks and IoT devices, including the use of encryption, network segmentation, and regular vulnerability assessments.",
}

# The catch-all template additionally prepends a fixed expert persona to the
# caller-supplied context.
_GENERAL_CONTEXT_PREFIX = "You are a cybersecurity expert with extensive experience in all sub-streams of the industry, including but not limited to network security, application security, cloud security, threat intelligence, penetration testing, and incident response. "
_GENERAL_INSTRUCTION = "Please provide a detailed and professional response to the query based on your expertise in cybersecurity and the provided context."


def _make_prompt(instruction, context_prefix=""):
    """Build a PromptTemplate with the shared Context/Query/instruction layout.

    ``{context}`` and ``{query}`` are left as literal placeholders for
    PromptTemplate to fill. The pieces are joined rather than formatted so
    the braces are not interpreted here.
    """
    template_text = "".join(
        [
            "\nContext: ", context_prefix, "{context}\n",
            "Query: {query}\n",
            instruction, "\n",
        ]
    )
    return PromptTemplate(
        input_variables=["query", "context"],
        template=template_text,
    )


prompt_templates = {
    topic: _make_prompt(instruction)
    for topic, instruction in _PROMPT_INSTRUCTIONS.items()
}
# "general" is added last so it stays the final key, as in the original literal.
prompt_templates["general"] = _make_prompt(
    _GENERAL_INSTRUCTION,
    context_prefix=_GENERAL_CONTEXT_PREFIX,
)
# One LLMChain per prompt template, keyed by the template's name.
chains = {
    template_name: LLMChain(llm=llm, prompt=template)
    for template_name, template in prompt_templates.items()
}

# Zero-shot classifier used to route a query to one of the templates.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)

# Candidate labels for the classifier: the template names, in definition order.
question_types = list(prompt_templates)
# Route a query to a prompt-template name via the zero-shot classifier.
def classify_query(query: str) -> str:
    """Return the prompt-template name that best matches ``query``.

    The zero-shot classifier is called with the template names as candidate
    labels and the first (top-ranked) label is used. ``"general"`` is
    returned instead when the top score is below 0.5 or when anything in
    the classification step raises.
    """
    try:
        outcome = classifier(query, candidate_labels=question_types)
        # The first label/score pair is taken as the best match.
        top_label, top_score = outcome["labels"][0], outcome["scores"][0]
        if top_score >= 0.5:
            return top_label
        # Not confident enough in any specific topic: use the catch-all.
        print(f"Low confidence ({top_score}) for query '{query}', falling back to 'general'")
        return "general"
    except Exception as e:
        # Best-effort routing: a classifier failure must not fail the request.
        print(f"Error in classification: {e}")
        return "general"
@app.post("/search")
async def process_search(search_query: SearchQuery):
    """Classify the query, run the matching chain and return a structured reply.

    NOTE(review): the handler was not registered on ``app`` anywhere in the
    visible source; the ``/search`` path used here is an assumption and
    should be confirmed against existing API clients.

    NOTE(review): ``classify_query`` and ``chain.run`` are synchronous calls
    made from an ``async`` handler, so they occupy the event loop while they
    run.

    Args:
        search_query: Request body with the user's ``query`` and an optional
            ``context``.

    Returns:
        A dict with ``status``, the structured ``response`` and the
        ``classified_type`` that was selected for the query.

    Raises:
        HTTPException: Status 500 with the error text when any step fails.
    """
    try:
        # A missing or empty context falls back to a generic persona.
        context = search_query.context or "You are a cybersecurity expert."

        # Pick the prompt template for this query.
        query_type = classify_query(search_query.query)

        # Unknown labels are served by the catch-all chain.
        chain = chains.get(query_type, chains["general"])
        raw_response = chain.run(query=search_query.query, context=context)

        # Structure the response according to the desired format.
        structured_response = {
            "Clearly articulate your task and desired outcome": f"The task is to address the query: '{search_query.query}'. The desired outcome is a detailed, actionable response.",
            "Offer relevant background information to guide the AI’s understanding": f"The query was processed with the context: '{context}', guiding the response to align with cybersecurity expertise.",
            # The raw model output from Groq, with surrounding whitespace removed.
            "Use Clear Language: Avoid ambiguity and complex wording": raw_response.strip(),
            "Experiment with different prompt structures and learn from the results": f"This response uses the '{query_type}' template. Try rephrasing the query for alternative perspectives or more specificity.",
        }
        return {
            "status": "success",
            "response": structured_response,
            # Returned for debugging which template was chosen.
            "classified_type": query_type,
        }
    except Exception as e:
        # Keep the original exception as the cause so tracebacks show the
        # real failure behind the 500 response.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
async def root():
    """Return a static message confirming that the API is running.

    The handler is registered on ``/`` so that it is actually reachable; the
    visible source defined the function without attaching it to ``app``.
    """
    return {"message": "Search API with structured response is running"}
# Local testing entry point: start the server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)