# AnnaAgent-Demo / src / emotion_pertuber.py
# Hugging Face page header: sci-m-wang — "Upload 14 files" (commit 1d4c295, verified), 3.55 kB.
import random
# from collections import defaultdict
def calculate_total_weight(current_state, states, category_distances, distance_weights):
    """Return the summed transition weight of all states, seen from ``current_state``.

    Every state contributes the weight that ``distance_weights`` assigns to the
    distance between its class and the class containing ``current_state``.
    Distances with no entry in ``distance_weights`` contribute 0. Note that
    ``current_state`` itself (and the rest of its own class) is included in
    the sum.

    Args:
        current_state: The state whose class is used as the origin.
        states: Mapping of class name -> list of state names.
        category_distances: Nested mapping ``[from_class][to_class] -> distance``.
        distance_weights: Mapping of distance -> per-state weight.

    Returns:
        The total weight, i.e. sum over classes of ``weight * class size``.

    Raises:
        ValueError: If ``current_state`` is not listed in any class.
    """
    # The first class that lists the state is treated as its home class.
    home_class = next(
        (name for name, members in states.items() if current_state in members),
        None,
    )
    if home_class is None:
        raise ValueError("Current state not found in any class.")

    distances_from_home = category_distances[home_class]
    return sum(
        distance_weights.get(distances_from_home[name], 0) * len(members)
        for name, members in states.items()
    )
def calculate_probabilities(current_state, states, category_distances, distance_weights):
    """Return the transition distribution from ``current_state`` to every other state.

    Each candidate state (every state except ``current_state``) gets a raw
    weight equal to ``distance_weights[d]``. Here ``d`` is the distance between
    the candidate's class and the class of ``current_state``; distances
    missing from ``distance_weights`` weigh 0. The raw weights are then
    divided by their sum, so the returned values add up to 1.

    Args:
        current_state: The state to transition away from.
        states: Mapping of class name -> list of state names.
        category_distances: Nested mapping ``[from_class][to_class] -> distance``.
        distance_weights: Mapping of distance -> per-state weight.

    Returns:
        Dict mapping each candidate state to its probability, in the order
        the states appear in ``states``. The result is empty when
        ``current_state`` is the only state.

    Raises:
        ValueError: If ``current_state`` is not listed in any class, or if
            candidates exist but their weights sum to zero or less.
    """
    current_class = next(
        (cls for cls, state_list in states.items() if current_state in state_list),
        None,
    )
    if current_class is None:
        raise ValueError("Current state not found in any class.")

    raw_weights = {}
    for cls, state_list in states.items():
        # Use the per-state weight, not weight * len(class). Multiplying by the
        # class size again over-counts large classes: a far-away 14-state class
        # would outweigh a near 3-state one. It also breaks normalisation.
        weight = distance_weights.get(category_distances[current_class][cls], 0)
        for state in state_list:
            if state != current_state:
                raw_weights[state] = weight

    if not raw_weights:
        return {}

    # Normalise over the candidates only; the current state is excluded
    # from both the numerator and the denominator.
    total_weight = sum(raw_weights.values())
    if total_weight <= 0:
        raise ValueError("No candidate state has a positive transition weight.")
    return {state: weight / total_weight for state, weight in raw_weights.items()}
def perturb_state(current_state):
    """Randomly pick the emotion state to move to from ``current_state``.

    Emotions are grouped into four classes laid out on a line:
    Positive - Neutral - Ambiguous - Negative. The distance between two
    classes is the number of steps between them on that line, and each
    distance maps to a weight (closer classes weigh more).
    ``calculate_probabilities`` turns those weights into a distribution over
    every state except ``current_state``. One state is then drawn from it
    using the module-level ``random`` generator.

    NOTE(review): the label set resembles GoEmotions but omits e.g. "grief"
    and "surprise"; if the caller can pass such labels, this raises — confirm.

    Args:
        current_state: Name of the current emotion; must be one of the labels below.

    Returns:
        The sampled next emotion (never equal to ``current_state``).

    Raises:
        ValueError: If ``current_state`` is not in the taxonomy (raised by
            ``calculate_probabilities``).
    """
    positive_emotions = [
        "admiration", "amusement", "approval", "caring", "curiosity",
        "desire", "excitement", "gratitude", "joy", "love",
        "optimism", "pride", "realization", "relief",
    ]
    ambiguous_emotions = ["confusion", "disappointment", "nervousness"]
    negative_emotions = [
        "anger", "annoyance", "disapproval", "disgust",
        "embarrassment", "fear", "sadness", "remorse",
    ]
    # Insertion order matters: it fixes the order of the sampled population.
    states = {
        'Positive': positive_emotions,
        'Neutral': ['neutral'],
        'Ambiguous': ambiguous_emotions,
        'Negative': negative_emotions,
    }

    # Class distance = number of steps apart on the linear class ordering.
    class_order = ['Positive', 'Neutral', 'Ambiguous', 'Negative']
    category_distances = {
        src: {dst: abs(i - j) for j, dst in enumerate(class_order)}
        for i, src in enumerate(class_order)
    }

    # Per-state weight for each class distance.
    distance_weights = {
        0: 10,  # same class
        1: 5,   # adjacent class
        2: 2,   # one class in between
        3: 1,   # two classes in between
    }

    probabilities = calculate_probabilities(current_state, states, category_distances, distance_weights)
    candidates = list(probabilities)
    weights = list(probabilities.values())
    return random.choices(candidates, weights=weights, k=1)[0]
# Example usage:
# current_state = 'confusion'
# next_state = perturb_state(current_state)
# print(f"Next state: {next_state}")
# Empirically check the sampling distribution
# (requires `from collections import defaultdict`):
# state_counts = defaultdict(int)
# for _ in range(1000):
#     next_state = perturb_state(current_state)
#     state_counts[next_state] += 1
# print("\nProbability distribution:")
# for state, count in state_counts.items():
#     print(f"{state}: {count / 1000:.2f}")