# Source: broadfield-dev/RL — app.py
# Last commit b7ed5ce: "Update app.py" (verified)
# app.py
from flask import Flask, jsonify, render_template
import numpy as np
# Flask application object; the '/' and '/step' routes below are registered on it.
app = Flask(__name__)
# Define the maze as a grid indexed [row, col]. Cells equal to 1 are the
# penalized "wall" cells (get_reward returns -10 for them); 0 cells are open.
maze = np.array([
    [0, 1, 0, 0, 0],
    [0, 1, 0, 1, 0],
    [0, 0, 0, 1, 0],
    [0, 1, 1, 1, 0],
    [0, 0, 0, 0, 0],
])
# Goal cell as (row, col); reaching it ends a training episode.
target = (4, 4)
# Agent's current (row, col): /step advances it and train_q_table resets it.
agent_position = (0, 0)
# Q-table indexed as q_table[row, col, action]. Its grid dimensions are taken
# from the maze so the two cannot drift apart if the maze is resized.
# 4 actions: up, down, left, right.
q_table = np.zeros(maze.shape + (4,))
# Reward for the cell the agent lands on.
def get_reward(agent_position):
    """Return the reward for occupying ``agent_position`` (a (row, col) tuple).

    Returns 100 at ``target``, -10 on a cell where ``maze`` is 1, and -1 for
    any other cell (a per-step cost).
    """
    if agent_position == target:
        return 100
    return -10 if maze[agent_position] == 1 else -1
# Action index -> (row delta, col delta). The index order matches the last
# axis of q_table.
actions = dict(enumerate([
    (-1, 0),  # 0: up
    (1, 0),   # 1: down
    (0, -1),  # 2: left
    (0, 1),   # 3: right
]))
# Define the training function
def train_q_table(episodes=1000, alpha=0.1, gamma=0.95, epsilon=0.1):
    """Train the module-level ``q_table`` in place with tabular Q-learning.

    Each episode starts at (0, 0) and runs until the agent reaches ``target``.
    Actions are chosen epsilon-greedily from ``q_table``. A move that would
    leave the grid keeps the agent where it is. When training finishes,
    ``agent_position`` is reset to (0, 0).

    Args:
        episodes: Number of training episodes.
        alpha: Learning rate, i.e. the weight given to the new estimate.
        gamma: Discount factor applied to the best next-state Q-value.
        epsilon: Probability of taking a uniformly random action.
    """
    # q_table is only mutated in place, so it needs no global declaration;
    # agent_position is rebound, so it does.
    global agent_position
    n_rows, n_cols = maze.shape
    for _ in range(episodes):
        state = (0, 0)
        while state != target:
            if np.random.uniform(0, 1) < epsilon:
                action = int(np.random.choice([0, 1, 2, 3]))
            else:
                action = int(np.argmax(q_table[state]))
            d_row, d_col = actions[action]
            row, col = state[0] + d_row, state[1] + d_col
            # Off-grid moves are rejected: the agent stays put and receives
            # the reward for its current cell.
            # NOTE(review): walls (maze == 1) do not block movement here; they
            # are only penalized by get_reward. Confirm this is intended.
            if 0 <= row < n_rows and 0 <= col < n_cols:
                next_state = (row, col)
            else:
                next_state = state
            reward = get_reward(next_state)
            old_value = q_table[state + (action,)]
            next_max = np.max(q_table[next_state])
            q_table[state + (action,)] = (
                (1 - alpha) * old_value + alpha * (reward + gamma * next_max)
            )
            state = next_state
    agent_position = (0, 0)
# Train once at import time, before any route is served, so /step can read
# the learned Q-table immediately. Hyperparameters are spelled out here.
train_q_table(episodes=1000, alpha=0.1, gamma=0.95, epsilon=0.1)
@app.route('/')
def index():
    """Render the ``index.html`` template for the root URL."""
    page = 'index.html'
    return render_template(page)
@app.route('/step')
def step():
    """Move the agent one greedy step along the learned Q-table policy.

    The agent takes the action with the highest Q-value for its current cell.
    A move that would leave the grid keeps it in place. Once the agent is on
    ``target`` it stays there. Training never updates the Q-values at the
    target, so the greedy action there would just walk it off the goal.

    Returns:
        JSON with ``agent_position`` and ``target`` as ``[row, col]`` lists of
        ints and ``maze`` as a nested list.
    """
    global agent_position
    if agent_position != target:
        action = int(np.argmax(q_table[agent_position]))
        d_row, d_col = actions[action]
        row = int(agent_position[0]) + d_row
        col = int(agent_position[1]) + d_col
        n_rows, n_cols = maze.shape
        # Off-grid moves are rejected: the agent stays where it is.
        if 0 <= row < n_rows and 0 <= col < n_cols:
            agent_position = (row, col)
    # Cast to built-in int. NumPy integer scalars (e.g. np.int64 produced by
    # array arithmetic) are not JSON serializable, and list() of a tuple of
    # them would still contain np.int64.
    return jsonify({
        'agent_position': [int(c) for c in agent_position],
        'target': [int(c) for c in target],
        'maze': maze.tolist(),
    })
if __name__ == '__main__':
    # Bind to all network interfaces (0.0.0.0) on port 7860.
    host, port = '0.0.0.0', 7860
    app.run(host=host, port=port)