|
|
using System; |
|
|
using System.Collections; |
|
|
using System.Collections.Generic; |
|
|
using System.Collections.ObjectModel; |
|
|
using System.Linq; |
|
|
using UnityEngine; |
|
|
|
|
|
/// <summary>
/// Moves available to the agent on the grid.
/// <see cref="None"/> is the pseudo-action paired with the goal state; it is
/// not one of the four directional moves.
/// </summary>
/// <remarks>
/// The name shadows <see cref="System.Action"/>, so the delegate type has to
/// be written fully qualified wherever both are needed.
/// </remarks>
[Serializable]
public enum Action
{
    /// <summary>Move one cell in +Y grid direction.</summary>
    Up = 0,

    /// <summary>Move one cell in -Y grid direction.</summary>
    Down = 1,

    /// <summary>Move one cell in -X grid direction.</summary>
    Left = 2,

    /// <summary>Move one cell in +X grid direction.</summary>
    Right = 3,

    /// <summary>No movement; selected when the agent is on the goal state.</summary>
    None = 4
}
|
|
/// <summary>
/// Grid-world agent that learns a path to <see cref="FinalState"/> with
/// tabular Q-learning.
/// </summary>
/// <remarks>
/// The agent runs as a chain of coroutines: each <see cref="WaitThenAction"/>
/// run waits, asks <see cref="Q_Learning_Agent"/> for the next action and
/// invokes the matching movement method, which in turn schedules the next
/// coroutine. Moves that would hit a road block or leave the grid are encoded
/// as Q-values of <see cref="float.NegativeInfinity"/>.
/// </remarks>
[Serializable]
public class Agent : MonoBehaviour
{
    #region Fields

    /// <summary>Reward stored for the goal cell in <see cref="StateRewardGrid"/>.</summary>
    private const float GoalReward = 100f;

    /// <summary>Chance that an eligible cell receives a road block.</summary>
    private const float RoadBlockProbability = 0.2f;

    /// <summary>The four directional moves (everything except <see cref="Action.None"/>).</summary>
    private static readonly Action[] MovementActions =
    {
        Action.Up, Action.Down, Action.Left, Action.Right
    };

    /// <summary>
    /// Single shared generator for shuffling. Creating a new time-seeded
    /// generator per call can yield identical sequences for calls made in
    /// quick succession.
    /// </summary>
    private static readonly System.Random ShuffleRandom = new System.Random();

    [SerializeField]
    private int _step;

    [SerializeField]
    private int _iteration;

    [SerializeField]
    private int _currentGridX;

    [SerializeField]
    private int _currentGridY;

    [SerializeField]
    private (int,int)? _previousState = null;

    [SerializeField]
    private Action? _previousAction = null;

    [SerializeField]
    private float? _previousReward = null;

    [SerializeField]
    private GUIController _gUIController;

    [SerializeField]
    [Range(0f, 1f)]
    private float _learningRate;

    [SerializeField]
    [Range(0f, 1f)]
    private float _discountingFactor;

    [SerializeField]
    private int _mimumumStateActionPairFrequencies;

    [SerializeField]
    private float _estimatedBestPossibleRewardValue;

    [SerializeField]
    private Coroutine _waitThenActionCoroutine;

    [SerializeField]
    private bool _isPause;

    [SerializeField]
    [Range(0.001f, 30f)]
    private float _restTime;

    [SerializeField]
    private GameObject _roadBlock;

    [SerializeField]
    private GameObject _Goodies;

    /// <summary>Number of decisions taken since the last (re)initialization.</summary>
    public int Step { get => _step; set => _step = value; }

    /// <summary>Counter incremented on start and every time the goal is reached.</summary>
    public int Iteration { get => _iteration; set => _iteration = value; }

    /// <summary>Current X grid coordinate of the agent.</summary>
    public int CurrentGridX { get => _currentGridX; set => _currentGridX = value; }

    /// <summary>Current Y grid coordinate of the agent (mapped to world Z).</summary>
    public int CurrentGridY { get => _currentGridY; set => _currentGridY = value; }

    /// <summary>State seen on the previous decision, or null before the first one.</summary>
    public (int, int)? PreviousState { get => _previousState; set => _previousState = value; }

    /// <summary>Action chosen on the previous decision, or null before the first one.</summary>
    public Action? PreviousAction { get => _previousAction; set => _previousAction = value; }

    /// <summary>Reward received on the previous decision, or null before the first one.</summary>
    public float? PreviousReward { get => _previousReward; set => _previousReward = value; }

    /// <summary>Optional UI sink for the step and iteration counters.</summary>
    public GUIController GUIController { get => _gUIController; set => _gUIController = value; }

    /// <summary>Step size applied to the temporal-difference error in the Q update.</summary>
    public float LearningRate { get => _learningRate; set => _learningRate = value; }

    /// <summary>Weight applied to the best next-state Q-value in the Q update.</summary>
    public float DiscountingFactor { get => _discountingFactor; set => _discountingFactor = value; }

    /// <summary>Not read anywhere in this class.</summary>
    public int MimumumStateActionPairFrequencies { get => _mimumumStateActionPairFrequencies; set => _mimumumStateActionPairFrequencies = value; }

    /// <summary>Not read anywhere in this class.</summary>
    public float EstimatedBestPossibleRewardValue { get => _estimatedBestPossibleRewardValue; set => _estimatedBestPossibleRewardValue = value; }

    /// <summary>Handle of the most recently started decision coroutine; null when none is tracked.</summary>
    public Coroutine WaitThenActionCoroutine { get => _waitThenActionCoroutine; set => _waitThenActionCoroutine = value; }

    /// <summary>While true, the pending decision coroutine idles frame by frame.</summary>
    public bool IsPause { get => _isPause; set => _isPause = value; }

    /// <summary>Seconds to wait between a move and the next decision.</summary>
    public float RestTime { get => _restTime; set => _restTime = value; }

    /// <summary>Prefab instantiated on blocked cells.</summary>
    public GameObject RoadBlock { get => _roadBlock; set => _roadBlock = value; }

    /// <summary>Not read anywhere in this class.</summary>
    public GameObject Goodies { get => _Goodies; set => _Goodies = value; }

    /// <summary>Cell the agent is returned to after reaching the goal.</summary>
    public (int,int) StartState;

    /// <summary>Goal cell; overwritten from <c>Grid.instance.goalPosition</c> in <see cref="Start"/>.</summary>
    public (int,int) FinalState = (7,9);

    public int StartX;

    public int StartY;

    /// <summary>Number of grid cells along X.</summary>
    public int GrizSizeX;

    /// <summary>Number of grid cells along Y.</summary>
    public int GrizSizeY;

    /// <summary>Q-table keyed by (state, action). Negative infinity marks a forbidden move.</summary>
    public Dictionary<((int,int),Action),float> StateActionPairQValue { get; set; }

    /// <summary>Reward received for standing on each cell.</summary>
    public Dictionary<(int, int), float> StateRewardGrid { get; set; }

    /// <summary>Maps each action to the method that executes it.</summary>
    public Dictionary<Action, System.Action> ActionDelegatesDictonary { get; set; }

    #endregion

    #region Q_Learning_Agent

    /// <summary>
    /// One Q-learning decision: updates the Q-value of the previous
    /// (state, action) pair from the current percept, then picks the action
    /// for the current state.
    /// </summary>
    /// <param name="currentState">Cell the agent is standing on.</param>
    /// <param name="rewardSignal">Reward associated with <paramref name="currentState"/>.</param>
    /// <returns>The action to execute next.</returns>
    private Action Q_Learning_Agent((int,int) currentState, float rewardSignal)
    {
        UpdateStep();

        // When the previous state was the goal, its (goal, None) entry is
        // seeded with the current reward before the regular update below.
        if (PreviousState == FinalState)
        {
            StateActionPairQValue[(PreviousState.Value, Action.None)] = rewardSignal;
        }

        if (PreviousState.HasValue)
        {
            ((int, int), Action) previousPair = (PreviousState.Value, PreviousAction.Value);
            float oldValue = StateActionPairQValue[previousPair];
            float target = PreviousReward.Value
                + (DiscountingFactor * MaxStateActionPairQValue(currentState));

            // Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
            StateActionPairQValue[previousPair] = oldValue + LearningRate * (target - oldValue);
        }

        PreviousState = currentState;
        PreviousAction = ArgMaxActionExploration(currentState);
        PreviousReward = rewardSignal;
        return PreviousAction.Value;
    }

    /// <summary>
    /// Returns the largest Q-value available in <paramref name="currentState"/>.
    /// For the goal state this is the (goal, <see cref="Action.None"/>) entry.
    /// </summary>
    private float MaxStateActionPairQValue((int, int) currentState)
    {
        if (currentState == FinalState)
        {
            return StateActionPairQValue[(currentState, Action.None)];
        }

        // A maximum does not depend on visiting order, so no shuffle is needed.
        float best = float.NegativeInfinity;
        foreach (Action action in MovementActions)
        {
            best = Mathf.Max(StateActionPairQValue[(currentState, action)], best);
        }

        return best;
    }

    /// <summary>Returns the four movement actions in a random order.</summary>
    private static Action[] SuffledActions()
    {
        return MovementActions.OrderBy(_ => ShuffleRandom.Next()).ToArray();
    }

    /// <summary>
    /// Picks the action with the highest Q-value in <paramref name="currentState"/>.
    /// Candidates are visited in random order so that ties are broken randomly.
    /// Returns <see cref="Action.None"/> on the goal state.
    /// </summary>
    private Action ArgMaxActionExploration((int, int) currentState)
    {
        if (currentState == FinalState)
        {
            return Action.None;
        }

        Action bestAction = Action.None;
        float bestValue = float.NegativeInfinity;

        foreach (Action candidate in SuffledActions())
        {
            float candidateValue = StateActionPairQValue[(currentState, candidate)];

            // ">=" lets a later candidate replace an equal one (random tie-break).
            if (candidateValue >= bestValue)
            {
                bestValue = candidateValue;
                bestAction = candidate;
            }
        }

        return bestAction;
    }

    /// <summary>Moves one cell towards -X and schedules the next decision.</summary>
    private void Left()
    {
        Move(-1, 0);
    }

    /// <summary>Moves one cell towards +X and schedules the next decision.</summary>
    private void Right()
    {
        Move(1, 0);
    }

    /// <summary>Moves one cell towards +Y (world +Z) and schedules the next decision.</summary>
    private void Up()
    {
        Move(0, 1);
    }

    /// <summary>Moves one cell towards -Y (world -Z) and schedules the next decision.</summary>
    private void Down()
    {
        Move(0, -1);
    }

    /// <summary>
    /// Handler for <see cref="Action.None"/>: sends the agent back to the
    /// start cell, counts a new iteration and schedules the next decision.
    /// </summary>
    private void None()
    {
        ResetAgentToStart();
        UpdateIteration();
        ScheduleNextDecision();
    }

    /// <summary>
    /// Shifts the transform and the grid coordinates by the given offset
    /// (grid Y maps to world Z), then schedules the next decision.
    /// </summary>
    private void Move(int deltaX, int deltaY)
    {
        transform.position += new Vector3(deltaX, 0f, deltaY);
        CurrentGridX += deltaX;
        CurrentGridY += deltaY;
        ScheduleNextDecision();
    }

    /// <summary>Starts the coroutine that will decide the next move from the current cell.</summary>
    private void ScheduleNextDecision()
    {
        WaitThenActionCoroutine = StartCoroutine(WaitThenAction(RestTime, (CurrentGridX, CurrentGridY)));
    }

    /// <summary>Places the agent on <see cref="StartState"/> and clears the grid colouring.</summary>
    private void ResetAgentToStart()
    {
        transform.position = new Vector3(StartState.Item1, 1f, StartState.Item2);
        CurrentGridX = StartState.Item1;
        CurrentGridY = StartState.Item2;
        Grid.instance.ClearColors();
    }

    /// <summary>
    /// Idles while <see cref="IsPause"/> is set, waits
    /// <paramref name="waitTime"/> seconds, then runs one Q-learning decision
    /// for <paramref name="GridCoordinate"/> and executes the chosen action.
    /// </summary>
    private IEnumerator WaitThenAction(float waitTime, (int,int) GridCoordinate)
    {
        while (IsPause)
        {
            yield return null;
        }

        yield return new WaitForSeconds(waitTime);

        Action chosen = Q_Learning_Agent(GridCoordinate, StateRewardGrid[GridCoordinate]);
        ActionDelegatesDictonary[chosen]();
    }

    #endregion

    #region Unity

    /// <summary>
    /// Unity start-up: reads the goal from the grid, wires the action
    /// handlers, picks a random start cell and builds the learning tables.
    /// </summary>
    private void Start()
    {
        FinalState = Grid.instance.goalPosition;

        ActionDelegatesDictonary = new Dictionary<Action, System.Action>
        {
            [Action.Left] = Left,
            [Action.Right] = Right,
            [Action.Up] = Up,
            [Action.Down] = Down,
            [Action.None] = None
        };

        StartX = UnityEngine.Random.Range(0, GrizSizeX);
        StartY = UnityEngine.Random.Range(0, GrizSizeY);
        Initialized();
    }

    /// <summary>
    /// Builds fresh Q and reward tables, scatters road blocks and marks every
    /// move into a road block or off the grid as forbidden.
    /// </summary>
    private void Initialized()
    {
        ResetProgressAndPosition();

        StateActionPairQValue = new Dictionary<((int, int), Action), float>();
        StateRewardGrid = new Dictionary<(int, int), float>();

        for (int x = 0; x < GrizSizeX; x++)
        {
            for (int y = 0; y < GrizSizeY; y++)
            {
                foreach (Action action in Enum.GetValues(typeof(Action)))
                {
                    StateActionPairQValue[((x, y), action)] = 0f;
                }

                StateRewardGrid[(x, y)] = 0f;
            }
        }

        StateRewardGrid[FinalState] = GoalReward;

        for (int x = 0; x < GrizSizeX; x++)
        {
            for (int y = 0; y < GrizSizeY; y++)
            {
                // Short-circuit keeps the random draw limited to eligible cells.
                if (IsRoadBlockCandidate(x, y)
                    && UnityEngine.Random.Range(0f, 1f) <= RoadBlockProbability)
                {
                    PlaceRoadBlock(x, y);
                }

                // Rewards are deliberately left untouched here: every cell was
                // already zeroed above, and resetting border cells again would
                // erase the goal reward whenever the goal lies on the border.
                ForbidMovesLeavingGrid(x, y);
            }
        }
    }

    /// <summary>
    /// A cell may hold a road block only when it shares neither row nor
    /// column with the start or the goal cell.
    /// </summary>
    private bool IsRoadBlockCandidate(int x, int y)
    {
        return x != StartState.Item1
            && x != FinalState.Item1
            && y != StartState.Item2
            && y != FinalState.Item2;
    }

    /// <summary>
    /// Spawns the road-block prefab at the cell and forbids, from each
    /// in-grid neighbour, the move that would step onto it.
    /// </summary>
    private void PlaceRoadBlock(int x, int y)
    {
        Instantiate(RoadBlock, new Vector3(x, 0.5f, y), Quaternion.identity);

        if (x + 1 < GrizSizeX)
        {
            StateActionPairQValue[((x + 1, y), Action.Left)] = float.NegativeInfinity;
        }

        if (x - 1 >= 0)
        {
            StateActionPairQValue[((x - 1, y), Action.Right)] = float.NegativeInfinity;
        }

        if (y + 1 < GrizSizeY)
        {
            StateActionPairQValue[((x, y + 1), Action.Down)] = float.NegativeInfinity;
        }

        if (y - 1 >= 0)
        {
            StateActionPairQValue[((x, y - 1), Action.Up)] = float.NegativeInfinity;
        }
    }

    /// <summary>For a border cell, forbids each move that would leave the grid.</summary>
    private void ForbidMovesLeavingGrid(int x, int y)
    {
        if (x == 0)
        {
            StateActionPairQValue[((x, y), Action.Left)] = float.NegativeInfinity;
        }

        if (y == 0)
        {
            StateActionPairQValue[((x, y), Action.Down)] = float.NegativeInfinity;
        }

        if (x == GrizSizeX - 1)
        {
            StateActionPairQValue[((x, y), Action.Right)] = float.NegativeInfinity;
        }

        if (y == GrizSizeY - 1)
        {
            StateActionPairQValue[((x, y), Action.Up)] = float.NegativeInfinity;
        }
    }

    /// <summary>
    /// Clears learning progress while keeping the existing layout: every
    /// Q-value is reset to zero except the negative-infinity entries that
    /// mark forbidden moves.
    /// </summary>
    private void ReInitialized()
    {
        ResetProgressAndPosition();

        for (int x = 0; x < GrizSizeX; x++)
        {
            for (int y = 0; y < GrizSizeY; y++)
            {
                foreach (Action action in Enum.GetValues(typeof(Action)))
                {
                    ((int, int), Action) key = ((x, y), action);
                    bool isForbidden = StateActionPairQValue.TryGetValue(key, out float current)
                        && float.IsNegativeInfinity(current);

                    if (!isForbidden)
                    {
                        StateActionPairQValue[key] = 0f;
                    }
                }
            }
        }
    }

    /// <summary>
    /// Forgets the previous percept, zeroes the counters and puts the agent
    /// back on (<see cref="StartX"/>, <see cref="StartY"/>).
    /// </summary>
    private void ResetProgressAndPosition()
    {
        PreviousAction = null;
        PreviousReward = null;
        PreviousState = null;
        Step = 0;
        Iteration = 0;

        StartState = (StartX, StartY);
        transform.position = new Vector3(StartX, 1f, StartY);
        CurrentGridX = StartState.Item1;
        CurrentGridY = StartState.Item2;
    }

    /// <summary>Unity per-frame callback: reports the agent's cell to the grid for colouring.</summary>
    private void Update()
    {
        Grid.instance.UpdateColor(CurrentGridX, CurrentGridY);
    }

    /// <summary>
    /// Counts a new iteration and starts the decision loop from
    /// <see cref="StartState"/> after a one-second delay.
    /// </summary>
    /// <remarks>
    /// Nothing here cancels a loop that is already running; calling this
    /// twice without <see cref="Stop"/> in between starts a second loop.
    /// </remarks>
    public void StartExploring()
    {
        UpdateIteration();
        WaitThenActionCoroutine = StartCoroutine(WaitThenAction(1f, StartState));
    }

    /// <summary>
    /// Cancels the pending decision coroutine (if any) and resets the agent
    /// and its learned Q-values.
    /// </summary>
    public void Stop()
    {
        // Cancel first so no decision can run against half-reset state, and
        // guard against Stop() being called before any loop was started.
        if (WaitThenActionCoroutine != null)
        {
            StopCoroutine(WaitThenActionCoroutine);
            WaitThenActionCoroutine = null;
        }

        ReInitialized();
    }

    /// <summary>Increments <see cref="Step"/> and mirrors it to the UI when one is attached.</summary>
    private void UpdateStep()
    {
        Step++;

        // Explicit comparison rather than "?.": if GUIController derives from
        // UnityEngine.Object, only the overloaded operator detects a
        // destroyed instance.
        if (GUIController != null)
        {
            GUIController.UpdateStepText(Step.ToString());
        }
    }

    /// <summary>Increments <see cref="Iteration"/> and mirrors it to the UI when one is attached.</summary>
    private void UpdateIteration()
    {
        Iteration++;

        if (GUIController != null)
        {
            GUIController.UpdateInterationText(Iteration.ToString());
        }
    }

    #endregion
}