Spaces:
Sleeping
Sleeping
from typing import List, Optional, Union, Dict, Any | |
from langgraph.graph.message import MessagesState | |
from langchain_core.messages import AIMessage, HumanMessage | |
from pydantic import BaseModel | |
from collections import OrderedDict | |
import re | |
class AttackState(MessagesState, total=False): | |
"""State for the ATT&CK Navigator workflow""" | |
attack_json: Optional[Dict[str, Any]] = None | |
scenario: Optional[str] = None | |
is_valid_context: Optional[bool] = None | |
extracted_user_scenario: Optional[str] = None | |
extracted_user_layer_operation: Optional[str] = None | |
class SecuritySurveyState(MessagesState, total=False): | |
"""State for the Security Survey workflow""" | |
security_checklist: OrderedDict = None # 質問リスト | |
current_part: str = None # 今のパート名 | |
current_question_index: int = 0 # 今のパート内の質問番号 | |
answers: dict = None # 回答を格納 | |
is_survey_complete: bool = False # 全質問終了フラグ | |
awaiting_clear_answer: bool = False # 明確な回答待ちフラグ | |
last_question: str = None # 直近の質問内容 | |
expecting_answer: bool = False # 回答待ちフラグ | |
is_new_session: bool = True # 新規セッションフラグ | |
def get_initial_state() -> AttackState: | |
"""Get the initial state for the workflow""" | |
return AttackState( | |
messages=[], | |
attack_json=None, | |
scenario=None, | |
is_valid_context=None, | |
extracted_user_scenario=None, | |
extracted_user_layer_operation=None | |
) | |
def get_security_survey_initial_state(security_checklist: OrderedDict) -> SecuritySurveyState: | |
"""Get the initial state for the security survey workflow""" | |
first_part = next(iter(security_checklist)) | |
return SecuritySurveyState( | |
messages=[], | |
security_checklist=security_checklist, | |
current_part=first_part, | |
current_question_index=0, | |
answers={part: {} for part in security_checklist}, | |
is_survey_complete=False, | |
awaiting_clear_answer=False, | |
last_question=None, | |
expecting_answer=False, | |
is_new_session=True | |
) | |
def add_user_message(state: AttackState, content: str) -> AttackState: | |
"""Add a user message to the state""" | |
state['messages'].append(HumanMessage(content=content)) | |
return state | |
def add_ai_message(state: AttackState, content: str) -> AttackState: | |
"""Add an AI message to the state""" | |
state['messages'].append(AIMessage(content=content)) | |
return state | |
def set_attack_json(state: AttackState, attack_json: Dict[str, Any]) -> AttackState: | |
"""Set the ATT&CK JSON in the state""" | |
state['attack_json'] = attack_json | |
return state | |
def set_scenario(state: AttackState, scenario: str) -> AttackState: | |
"""Set the scenario text in the state""" | |
state['scenario'] = scenario | |
return state | |
def set_valid_context(state: AttackState, is_valid: bool) -> AttackState: | |
"""Set the context validity in the state""" | |
state['is_valid_context'] = is_valid | |
return state | |
def evaluate_answer(answer: str) -> Optional[bool]: | |
"""ユーザの回答からTrue/Falseを判定する。判断できない場合はNoneを返す""" | |
positive_patterns = [ | |
r'(はい|イエス|yes|hai|true|正しい|実施|行[っな]て|対策済み|している|いる|やっている|やってる|対応している|対策している|やっております|している)', | |
r'導入しています', | |
r'設定しています', | |
r'確認しています', | |
r'実施しています' | |
] | |
negative_patterns = [ | |
r'(いいえ|ノー|no|iie|false|違う|違います|してない|していない|いない|やっていない|対応していない|対策していない|行[っな]ていない)', | |
r'導入していません', | |
r'設定していません', | |
r'確認していません', | |
r'実施していません' | |
] | |
answer = answer.lower() | |
for pattern in positive_patterns: | |
if re.search(pattern, answer): | |
return True | |
for pattern in negative_patterns: | |
if re.search(pattern, answer): | |
return False | |
return None # 判断できない場合 | |
def process_answer(state: SecuritySurveyState, answer: str) -> SecuritySurveyState: | |
"""ユーザの回答を処理する""" | |
if not state.get('expecting_answer', False): | |
# 回答待ちでなければ何もしない | |
return state | |
part = state['current_part'] | |
idx = state['current_question_index'] | |
questions = list(state['security_checklist'][part].keys()) | |
question = questions[idx] | |
# 回答を評価 | |
answer_value = evaluate_answer(answer) | |
if answer_value is None: | |
# 曖昧な回答の場合 | |
state['awaiting_clear_answer'] = True | |
else: | |
# 明確な回答の場合 | |
state['awaiting_clear_answer'] = False | |
state['answers'][part][question] = answer_value | |
# 次の質問インデックスを設定 | |
if idx + 1 < len(questions): | |
state['current_question_index'] += 1 | |
else: | |
# 次のパートへ | |
parts = list(state['security_checklist'].keys()) | |
current_part_idx = parts.index(part) | |
if current_part_idx + 1 < len(parts): | |
state['current_part'] = parts[current_part_idx + 1] | |
state['current_question_index'] = 0 | |
else: | |
# 全て終了 | |
state['is_survey_complete'] = True | |
# 回答待ちフラグをオフ | |
state['expecting_answer'] = False | |
return state | |
def get_next_question(state: SecuritySurveyState) -> Optional[str]: | |
"""次の質問を取得する。全て終了している場合はNoneを返す""" | |
if state['is_survey_complete']: | |
return None | |
part = state['current_part'] | |
idx = state['current_question_index'] | |
questions = list(state['security_checklist'][part].keys()) | |
question_text = questions[idx] | |
# 明確な回答が必要な場合は追加メッセージをつける | |
if state.get('awaiting_clear_answer', False): | |
full_question = f"【{part}】\n{question_text}\n\n※「はい」か「いいえ」ではっきりお答えください" | |
else: | |
full_question = f"【{part}】\n{question_text}" | |
# 質問を保存 | |
state['last_question'] = full_question | |
state['expecting_answer'] = True | |
return full_question | |
def has_unanswered_questions(state: SecuritySurveyState) -> bool: | |
"""未回答の質問があるかチェック""" | |
return not state['is_survey_complete'] |