security-survey-bot / models /chat_state.py
nyasukun's picture
initial commit for security survey
3966c38
from typing import List, Optional, Union, Dict, Any
from langgraph.graph.message import MessagesState
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel
from collections import OrderedDict
import re
class AttackState(MessagesState, total=False):
"""State for the ATT&CK Navigator workflow"""
attack_json: Optional[Dict[str, Any]] = None
scenario: Optional[str] = None
is_valid_context: Optional[bool] = None
extracted_user_scenario: Optional[str] = None
extracted_user_layer_operation: Optional[str] = None
class SecuritySurveyState(MessagesState, total=False):
"""State for the Security Survey workflow"""
security_checklist: OrderedDict = None # 質問リスト
current_part: str = None # 今のパート名
current_question_index: int = 0 # 今のパート内の質問番号
answers: dict = None # 回答を格納
is_survey_complete: bool = False # 全質問終了フラグ
awaiting_clear_answer: bool = False # 明確な回答待ちフラグ
last_question: str = None # 直近の質問内容
expecting_answer: bool = False # 回答待ちフラグ
is_new_session: bool = True # 新規セッションフラグ
def get_initial_state() -> AttackState:
"""Get the initial state for the workflow"""
return AttackState(
messages=[],
attack_json=None,
scenario=None,
is_valid_context=None,
extracted_user_scenario=None,
extracted_user_layer_operation=None
)
def get_security_survey_initial_state(security_checklist: OrderedDict) -> SecuritySurveyState:
"""Get the initial state for the security survey workflow"""
first_part = next(iter(security_checklist))
return SecuritySurveyState(
messages=[],
security_checklist=security_checklist,
current_part=first_part,
current_question_index=0,
answers={part: {} for part in security_checklist},
is_survey_complete=False,
awaiting_clear_answer=False,
last_question=None,
expecting_answer=False,
is_new_session=True
)
def add_user_message(state: AttackState, content: str) -> AttackState:
"""Add a user message to the state"""
state['messages'].append(HumanMessage(content=content))
return state
def add_ai_message(state: AttackState, content: str) -> AttackState:
"""Add an AI message to the state"""
state['messages'].append(AIMessage(content=content))
return state
def set_attack_json(state: AttackState, attack_json: Dict[str, Any]) -> AttackState:
"""Set the ATT&CK JSON in the state"""
state['attack_json'] = attack_json
return state
def set_scenario(state: AttackState, scenario: str) -> AttackState:
"""Set the scenario text in the state"""
state['scenario'] = scenario
return state
def set_valid_context(state: AttackState, is_valid: bool) -> AttackState:
"""Set the context validity in the state"""
state['is_valid_context'] = is_valid
return state
def evaluate_answer(answer: str) -> Optional[bool]:
"""ユーザの回答からTrue/Falseを判定する。判断できない場合はNoneを返す"""
positive_patterns = [
r'(はい|イエス|yes|hai|true|正しい|実施|行[っな]て|対策済み|している|いる|やっている|やってる|対応している|対策している|やっております|している)',
r'導入しています',
r'設定しています',
r'確認しています',
r'実施しています'
]
negative_patterns = [
r'(いいえ|ノー|no|iie|false|違う|違います|してない|していない|いない|やっていない|対応していない|対策していない|行[っな]ていない)',
r'導入していません',
r'設定していません',
r'確認していません',
r'実施していません'
]
answer = answer.lower()
for pattern in positive_patterns:
if re.search(pattern, answer):
return True
for pattern in negative_patterns:
if re.search(pattern, answer):
return False
return None # 判断できない場合
def process_answer(state: SecuritySurveyState, answer: str) -> SecuritySurveyState:
"""ユーザの回答を処理する"""
if not state.get('expecting_answer', False):
# 回答待ちでなければ何もしない
return state
part = state['current_part']
idx = state['current_question_index']
questions = list(state['security_checklist'][part].keys())
question = questions[idx]
# 回答を評価
answer_value = evaluate_answer(answer)
if answer_value is None:
# 曖昧な回答の場合
state['awaiting_clear_answer'] = True
else:
# 明確な回答の場合
state['awaiting_clear_answer'] = False
state['answers'][part][question] = answer_value
# 次の質問インデックスを設定
if idx + 1 < len(questions):
state['current_question_index'] += 1
else:
# 次のパートへ
parts = list(state['security_checklist'].keys())
current_part_idx = parts.index(part)
if current_part_idx + 1 < len(parts):
state['current_part'] = parts[current_part_idx + 1]
state['current_question_index'] = 0
else:
# 全て終了
state['is_survey_complete'] = True
# 回答待ちフラグをオフ
state['expecting_answer'] = False
return state
def get_next_question(state: SecuritySurveyState) -> Optional[str]:
"""次の質問を取得する。全て終了している場合はNoneを返す"""
if state['is_survey_complete']:
return None
part = state['current_part']
idx = state['current_question_index']
questions = list(state['security_checklist'][part].keys())
question_text = questions[idx]
# 明確な回答が必要な場合は追加メッセージをつける
if state.get('awaiting_clear_answer', False):
full_question = f"【{part}】\n{question_text}\n\n※「はい」か「いいえ」ではっきりお答えください"
else:
full_question = f"【{part}】\n{question_text}"
# 質問を保存
state['last_question'] = full_question
state['expecting_answer'] = True
return full_question
def has_unanswered_questions(state: SecuritySurveyState) -> bool:
"""未回答の質問があるかチェック"""
return not state['is_survey_complete']