Spaces:
Runtime error
Runtime error
| from functools import partial | |
| import math | |
| import chex | |
| import jax | |
| import jax.numpy as jnp | |
| from flax.serialization import to_state_dict | |
| from jax2d.engine import ( | |
| PhysicsEngine, | |
| calculate_collision_matrix, | |
| calc_inverse_mass_polygon, | |
| calc_inverse_mass_circle, | |
| calc_inverse_inertia_circle, | |
| calc_inverse_inertia_polygon, | |
| recalculate_mass_and_inertia, | |
| select_shape, | |
| ) | |
| from jax2d.sim_state import SimState, RigidBody, Joint, Thruster | |
| from jax2d.maths import rmat | |
| from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams | |
| from kinetix.environment.ued.ued_state import UEDParams | |
| from kinetix.environment.ued.util import ( | |
| count_roles, | |
| is_space_for_joint, | |
| make_velocities_zero, | |
| sample_dimensions, | |
| random_position_on_polygon, | |
| random_position_on_circle, | |
| get_role, | |
| is_space_for_shape, | |
| are_there_shapes_present, | |
| ) | |
| from kinetix.util.saving import load_world_state_pickle | |
| from flax import struct | |
| from kinetix.environment.env import create_empty_env | |
| from kinetix.environment.ued.util import make_do_dummy_step | |
| def mutate_add_shape( | |
| rng, | |
| state: EnvState, | |
| params: EnvParams, | |
| static_env_params: StaticEnvParams, | |
| ued_params: UEDParams, | |
| force_no_fixate: bool = False, | |
| ): | |
| def do_dummy(rng, state): | |
| return state | |
| def do_add(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| _rngs = jax.random.split(_rng, 9) | |
| space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons | |
| space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles | |
| is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0]) | |
| is_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=is_rect_p) | |
| rect_index = jnp.argmin(state.polygon.active) | |
| circle_index = jnp.argmin(state.circle.active) | |
| shape_role = get_role(_rngs[1], state, static_env_params) | |
| max_shape_size = ( | |
| jnp.array([1.0, ued_params.goal_body_size_factor, ued_params.goal_body_size_factor, 1.0])[shape_role] | |
| * ued_params.max_shape_size | |
| ) | |
| vertices, half_dimensions, radius = sample_dimensions( | |
| _rngs[2], | |
| static_env_params, | |
| is_rect, | |
| ued_params, | |
| max_shape_size=max_shape_size, | |
| ) | |
| n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4) | |
| largest = jnp.max(jnp.array([half_dimensions[0] * jnp.sqrt(2), half_dimensions[1] * jnp.sqrt(2), radius])) | |
| screen_dim_world = ( | |
| static_env_params.screen_dim[0] / params.pixels_per_unit, | |
| static_env_params.screen_dim[1] / params.pixels_per_unit, | |
| ) | |
| min_x = largest | |
| max_x = screen_dim_world[0] - largest | |
| min_y = largest + 0.4 | |
| max_y = screen_dim_world[1] - largest | |
| def _og_minmax(): | |
| return min_x, max_x, min_y, max_y | |
| def _opposite_minmax(): | |
| return jax.lax.switch( | |
| shape_role, | |
| [ | |
| (lambda: (min_x, max_x, min_y, max_y)), | |
| (lambda: (min_x, max_x - screen_dim_world[0] / 2, min_y, max_y)), | |
| (lambda: (min_x + screen_dim_world[0] / 2, max_x, min_y, max_y)), | |
| (lambda: (min_x, max_x, min_y, max_y)), | |
| ], | |
| ) | |
| min_x, max_x, min_y, max_y = jax.lax.cond( | |
| jax.random.uniform(_rngs[4], shape=()) < ued_params.goal_body_opposide_side_chance, | |
| _opposite_minmax, | |
| _og_minmax, | |
| ) | |
| position = jax.random.uniform(_rngs[5], shape=(2,)) * jnp.array( | |
| [ | |
| max_x - min_x, | |
| max_y - min_y, | |
| ] | |
| ) + jnp.array([min_x, min_y]) | |
| rotation = jax.random.uniform(_rngs[6], shape=()) * 2 * math.pi | |
| velocity = jnp.array([0.0, 0.0]) | |
| angular_velocity = 0.0 | |
| density = 1.0 | |
| inverse_mass = jax.lax.select( | |
| is_rect, | |
| calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0], | |
| calc_inverse_mass_circle(radius, density), | |
| ) | |
| inverse_inertia = jax.lax.select( | |
| is_rect, | |
| calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density), | |
| calc_inverse_inertia_circle(radius, density), | |
| ) | |
| fixate_chance = ued_params.fixate_chance_min + (1.0 / inverse_mass) * ued_params.fixate_chance_scale | |
| fixate_chance = jnp.minimum(fixate_chance, ued_params.fixate_chance_max) | |
| is_fixated = jax.random.uniform(_rngs[7], shape=()) < fixate_chance | |
| is_fixated &= ~force_no_fixate | |
| inverse_mass *= 1 - is_fixated | |
| inverse_inertia *= 1 - is_fixated | |
| # We want to bias fixated shapes to starting nearer the bottom half of the screen | |
| fixate_shape_bottom_bias = ( | |
| ued_params.fixate_shape_bottom_bias + ued_params.fixate_shape_bottom_bias_special_role * (shape_role != 0) | |
| ) | |
| is_forcing_bottom = jax.random.uniform(_rngs[8]) < fixate_shape_bottom_bias | |
| half_screen_height = (static_env_params.screen_dim[1] / params.pixels_per_unit) / 2.0 | |
| position = jax.lax.select( | |
| is_fixated & is_forcing_bottom & (position[1] >= half_screen_height), | |
| position.at[1].add(-half_screen_height), | |
| position, | |
| ) | |
| # This could be either a rect or a circle | |
| new_rigid_body = RigidBody( | |
| position=position, | |
| velocity=velocity, | |
| inverse_mass=inverse_mass, | |
| inverse_inertia=inverse_inertia, | |
| rotation=rotation, | |
| angular_velocity=angular_velocity, | |
| radius=radius, | |
| active=True, | |
| friction=1.0, | |
| vertices=vertices, | |
| n_vertices=n_vertices, | |
| collision_mode=1, | |
| restitution=0.0, | |
| ) | |
| state = state.replace( | |
| polygon=jax.tree.map( | |
| lambda x, y: jax.lax.select(is_rect, y.at[rect_index].set(x), y), new_rigid_body, state.polygon | |
| ), | |
| circle=jax.tree.map( | |
| lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[circle_index].set(x), y), | |
| new_rigid_body, | |
| state.circle, | |
| ), | |
| polygon_shape_roles=jax.lax.select( | |
| is_rect, | |
| state.polygon_shape_roles.at[rect_index].set(shape_role), | |
| state.polygon_shape_roles, | |
| ), | |
| circle_shape_roles=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle_shape_roles.at[circle_index].set(shape_role), | |
| state.circle_shape_roles, | |
| ), | |
| ) | |
| return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities) | |
| return jax.lax.cond(is_space_for_shape(state), do_add, do_dummy, rng, state) | |
| def mutate_add_connected_shape( | |
| rng, | |
| state: EnvState, | |
| params: EnvParams, | |
| static_env_params: StaticEnvParams, | |
| ued_params: UEDParams, | |
| force_rjoint: bool = False, | |
| ): | |
| def do_dummy(rng, state): | |
| return state, False | |
| def do_add(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| _rngs = jax.random.split(_rng, 21) | |
| # Select a random index amongst the currently active shapes. | |
| p_rect = state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False) | |
| p_circle = state.circle.active | |
| p_rect = p_rect.astype(jnp.float32) | |
| p_circle = p_circle.astype(jnp.float32) | |
| p_rect *= (state.polygon.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + ( | |
| state.polygon.inverse_mass != 0 | |
| ) * 1.0 | |
| p_circle *= (state.circle.inverse_mass == 0) * ued_params.connect_to_fixated_prob_coeff + ( | |
| state.circle.inverse_mass != 0 | |
| ) * 1.0 | |
| # Bias based on number of existing connections | |
| rect_connections = jnp.zeros(static_env_params.num_polygons) | |
| circle_connections = jnp.zeros(static_env_params.num_circles) | |
| rect_connections = rect_connections.at[state.joint.a_index].add( | |
| jnp.ones(static_env_params.num_joints) | |
| * state.joint.active | |
| * (state.joint.a_index < static_env_params.num_polygons) | |
| ) | |
| rect_connections = rect_connections.at[state.joint.b_index].add( | |
| jnp.ones(static_env_params.num_joints) | |
| * state.joint.active | |
| * (state.joint.b_index < static_env_params.num_polygons) | |
| ) | |
| circle_connections = circle_connections.at[state.joint.a_index - static_env_params.num_polygons].add( | |
| jnp.ones(static_env_params.num_joints) | |
| * state.joint.active | |
| * (state.joint.a_index >= static_env_params.num_polygons) | |
| ) | |
| circle_connections = circle_connections.at[state.joint.b_index - static_env_params.num_polygons].add( | |
| jnp.ones(static_env_params.num_joints) | |
| * state.joint.active | |
| * (state.joint.b_index >= static_env_params.num_polygons) | |
| ) | |
| # Rectangles can have up to 2 connections | |
| p_rect *= (-rect_connections + 2.0) / 2.0 | |
| p_rect = jnp.maximum(p_rect, 0.0) | |
| # Circles can have 1 connection | |
| p_circle *= circle_connections == 0 | |
| # To sample a target rect/circle, we have to have at least one. | |
| target_rect_p = jnp.array( | |
| [ | |
| (state.polygon.active.astype(int).sum() > static_env_params.num_static_fixated_polys) * 1.0, | |
| (state.circle.active.astype(int).sum() > 0) * 1.0, | |
| ] | |
| ) | |
| # Don't connect to a circle if no connection-free ones exist | |
| target_rect_p = target_rect_p.at[1].mul(p_circle.sum() > 0) | |
| space_for_new_rect = state.polygon.active.astype(int).sum() < static_env_params.num_polygons | |
| space_for_new_circle = state.circle.active.astype(int).sum() < static_env_params.num_circles | |
| is_target_rect = jax.random.choice(_rngs[0], jnp.array([True, False], dtype=bool), p=target_rect_p) | ( | |
| ~space_for_new_rect | |
| ) | |
| is_rect_p = jnp.array([space_for_new_rect * 1.0, space_for_new_circle * 1.0]) | |
| is_rect = jax.random.choice(_rngs[1], jnp.array([True, False], dtype=bool), p=is_rect_p) | ( | |
| ~is_target_rect & space_for_new_rect | |
| ) | |
| shape_index = jax.lax.select( | |
| is_rect, | |
| jnp.argmin(state.polygon.active), | |
| jnp.argmin(state.circle.active), | |
| ) | |
| unified_shape_index = shape_index + (~is_rect) * static_env_params.num_polygons | |
| vertices, half_dimensions, radius = sample_dimensions( | |
| _rngs[2], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size | |
| ) | |
| n_vertices = jax.lax.select(ued_params.generate_triangles, jax.random.choice(_rngs[3], jnp.array([3, 4])), 4) | |
| rotation = jax.random.uniform(_rngs[4], shape=()) * 2 * math.pi | |
| velocity = jnp.array([0.0, 0.0]) | |
| angular_velocity = 0.0 | |
| density = 1.0 | |
| inverse_mass = jax.lax.select( | |
| is_rect, | |
| calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density)[0], | |
| calc_inverse_mass_circle(radius, density), | |
| ) | |
| inverse_inertia = jax.lax.select( | |
| is_rect, | |
| calc_inverse_inertia_polygon(vertices, n_vertices, static_env_params, density), | |
| calc_inverse_inertia_circle(radius, density), | |
| ) | |
| # Joint | |
| current_num_rjoints = (jnp.logical_not(state.joint.is_fixed_joint) * state.joint.active).sum() | |
| is_rjoint = jnp.logical_or( | |
| jnp.logical_or(jax.random.uniform(_rngs[5]) < 0.5, force_rjoint), | |
| current_num_rjoints < ued_params.min_rjoints_bias, | |
| ) | |
| joint_index = jnp.argmin(state.joint.active) | |
| local_joint_position_rect = random_position_on_polygon(_rngs[6], vertices, n_vertices, static_env_params) | |
| local_joint_position_circle = random_position_on_circle(_rngs[7], radius, on_centre_chance=1.0) | |
| local_joint_position = jax.lax.select(is_rect, local_joint_position_rect, local_joint_position_circle) | |
| p_rect = jax.lax.select(p_rect.sum() == 0, state.polygon.active.astype(jnp.float32), p_rect) | |
| p_circle = jax.lax.select(p_circle.sum() == 0, state.circle.active.astype(jnp.float32), p_circle) | |
| target_index = jax.lax.select( | |
| is_target_rect, | |
| jax.random.choice( | |
| _rngs[8], | |
| jnp.arange(static_env_params.num_polygons), | |
| p=p_rect, | |
| ), | |
| jax.random.choice( | |
| _rngs[9], | |
| jnp.arange(static_env_params.num_circles), | |
| p=p_circle, | |
| ), | |
| ) | |
| unified_target_index = target_index + jnp.logical_not(is_target_rect) * static_env_params.num_polygons | |
| target_shape = select_shape(state, unified_target_index, static_env_params) | |
| target_joint_position_rect = random_position_on_polygon( | |
| _rngs[10], state.polygon.vertices[target_index], state.polygon.n_vertices[target_index], static_env_params | |
| ) | |
| target_joint_position_circle = random_position_on_circle( | |
| _rngs[11], state.circle.radius[target_index], on_centre_chance=1.0 | |
| ) | |
| target_joint_position = jax.lax.select(is_target_rect, target_joint_position_rect, target_joint_position_circle) | |
| # Calculate the world position of the new shape | |
| # We know the rotation of the new shape. We also know the position of the current shape, which we want to remain fixed. | |
| # Set `position` such that local_joint_position is the same as `target_joint_position` | |
| global_joint_pos = target_shape.position + jnp.matmul(rmat(target_shape.rotation), target_joint_position) | |
| position = global_joint_pos - jnp.matmul(rmat(rotation), local_joint_position) | |
| _, pos_diff = calc_inverse_mass_polygon(vertices, n_vertices, static_env_params, density) | |
| position = jax.lax.select(is_rect, position + pos_diff, position) | |
| local_joint_position = jax.lax.select(is_rect, local_joint_position - pos_diff, local_joint_position) | |
| vertices = jax.lax.select(is_rect, vertices - pos_diff[None], vertices) | |
| target_role = jax.lax.select( | |
| is_target_rect, state.polygon_shape_roles[target_index], state.circle_shape_roles[target_index] | |
| ) | |
| # We cannot have role 1 and role 2 being connected. | |
| p = jnp.array([1.0, 1.0, 1.0, 1.0]) | |
| # If role is 0, keep all probs at 1, otherwise set the target role's complement to 0 prob | |
| # 3 - role turns 1 to 2 and 2 to 1 | |
| # If the target role is three, we set everything to zero except for the default | |
| p = jax.lax.select( | |
| target_role == 0, | |
| p, | |
| jax.lax.select( | |
| target_role <= 2, | |
| p.at[3 - target_role].set(False).at[3].set(False), | |
| (p.at[2].set(False).at[1].set(False)), | |
| ), | |
| ) | |
| shape_role = get_role(_rngs[12], state, static_env_params, initial_p=p) | |
| # This could be either a rect or a circle | |
| new_rigid_body = RigidBody( | |
| position=position, | |
| velocity=velocity, | |
| inverse_mass=inverse_mass, | |
| inverse_inertia=inverse_inertia, | |
| rotation=rotation, | |
| angular_velocity=angular_velocity, | |
| radius=radius, | |
| active=True, | |
| friction=1.0, | |
| vertices=vertices, | |
| n_vertices=n_vertices, | |
| collision_mode=1, | |
| restitution=0.0, | |
| ) | |
| # Change the shape indices such that a_index is less than b_index | |
| a_index = shape_index + (1 - is_rect) * static_env_params.num_polygons | |
| b_index = target_index + (1 - is_target_rect) * static_env_params.num_polygons | |
| should_swap = a_index > b_index | |
| a_index, b_index, local_joint_position, target_joint_position, shape_a, shape_b = jax.lax.cond( | |
| should_swap, | |
| lambda x: (x[1], x[0], x[3], x[2], x[5], x[4]), # pairwise swap | |
| lambda x: x, | |
| (a_index, b_index, local_joint_position, target_joint_position, new_rigid_body, target_shape), | |
| ) | |
| motor_on = jax.random.uniform(_rngs[13], shape=()) < ued_params.motor_on_chance | |
| joint_colour = jax.random.randint(_rngs[14], shape=(), minval=0, maxval=static_env_params.num_motor_bindings) | |
| joint_rotation = shape_b.rotation - shape_a.rotation | |
| motor_speed = jax.random.uniform( | |
| _rngs[15], shape=(), minval=ued_params.motor_min_speed, maxval=ued_params.motor_max_speed | |
| ) | |
| motor_power = jax.random.uniform( | |
| _rngs[16], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.motor_max_power | |
| ) | |
| wheel_power = jax.random.uniform( | |
| _rngs[20], shape=(), minval=ued_params.motor_min_power, maxval=ued_params.wheel_max_power | |
| ) | |
| # High-powered wheels break the physics engine - this is a temporary fix | |
| motor_power = jax.lax.select(is_rect & is_target_rect, motor_power, wheel_power) | |
| motor_has_joint_limits = jax.random.uniform(_rngs[17], shape=()) < ued_params.joint_limit_chance | |
| motor_has_joint_limits &= is_rect & is_target_rect | |
| joint_limit_min = ( | |
| jax.random.uniform(_rngs[18], shape=(), minval=-ued_params.joint_limit_max, maxval=0.0) | |
| * motor_has_joint_limits | |
| ) | |
| joint_limit_max = ( | |
| jax.random.uniform(_rngs[19], shape=(), minval=0.0, maxval=ued_params.joint_limit_max) | |
| * motor_has_joint_limits | |
| ) | |
| rjoint = Joint( | |
| a_index=a_index, | |
| b_index=b_index, | |
| a_relative_pos=local_joint_position, | |
| b_relative_pos=target_joint_position, | |
| global_position=global_joint_pos, | |
| active=True, | |
| motor_speed=motor_speed, | |
| motor_power=motor_power, | |
| motor_on=motor_on, | |
| # colour=joint_colour, | |
| motor_has_joint_limits=motor_has_joint_limits, | |
| min_rotation=joint_limit_min, | |
| max_rotation=joint_limit_max, | |
| is_fixed_joint=False, | |
| rotation=0.0, | |
| acc_impulse=jnp.zeros((2,), dtype=jnp.float32), | |
| acc_r_impulse=jnp.zeros((), dtype=jnp.float32), | |
| ) | |
| fjoint = Joint( | |
| a_index=a_index, | |
| b_index=b_index, | |
| a_relative_pos=local_joint_position, | |
| b_relative_pos=target_joint_position, | |
| global_position=global_joint_pos, | |
| active=True, | |
| rotation=joint_rotation, | |
| acc_impulse=jnp.zeros((2,), dtype=jnp.float32), | |
| acc_r_impulse=jnp.zeros((), dtype=jnp.float32), | |
| is_fixed_joint=True, | |
| motor_has_joint_limits=False, | |
| min_rotation=0.0, | |
| max_rotation=0.0, | |
| motor_on=False, | |
| motor_power=0.0, | |
| motor_speed=0.0, | |
| ) | |
| state = state.replace( | |
| polygon=jax.tree.map( | |
| lambda x, y: jax.lax.select(is_rect, y.at[shape_index].set(x), y), new_rigid_body, state.polygon | |
| ), | |
| circle=jax.tree.map( | |
| lambda x, y: jax.lax.select(jnp.logical_not(is_rect), y.at[shape_index].set(x), y), | |
| new_rigid_body, | |
| state.circle, | |
| ), | |
| joint=jax.tree.map( | |
| lambda rj, fj, y: jax.lax.select(is_rjoint, y.at[joint_index].set(rj), y.at[joint_index].set(fj)), | |
| rjoint, | |
| fjoint, | |
| state.joint, | |
| ), | |
| polygon_shape_roles=jax.lax.select( | |
| is_rect, | |
| state.polygon_shape_roles.at[shape_index].set(shape_role), | |
| state.polygon_shape_roles, | |
| ), | |
| circle_shape_roles=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle_shape_roles.at[shape_index].set(shape_role), | |
| state.circle_shape_roles, | |
| ), | |
| motor_bindings=state.motor_bindings.at[joint_index].set(joint_colour), | |
| ) | |
| # We need the new collision matrix. | |
| state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint)) | |
| state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities) | |
| # Was this a valid addition? | |
| # We calculate whether (assuming the possiblity of 360 degree rotation around the joint) | |
| # both shapes can be visible | |
| # This is to remove the common degenerate pattern of connected shapes being fully inside each other | |
| def _get_min_rect_dist(r_id, local_pos): | |
| rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon) | |
| half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0 | |
| half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0 | |
| dist_x = half_width - jnp.abs(local_pos[0]) | |
| dist_y = half_height - jnp.abs(local_pos[1]) | |
| return jnp.minimum(dist_x, dist_y) | |
| def _get_max_rect_dist(r_id, local_pos): | |
| rect: RigidBody = jax.tree.map(lambda x: x[r_id], state.polygon) | |
| half_width = (jnp.max(rect.vertices[:, 0]) - jnp.min(rect.vertices[:, 0])) / 2.0 | |
| half_height = (jnp.max(rect.vertices[:, 1]) - jnp.min(rect.vertices[:, 1])) / 2.0 | |
| dist_x = jnp.maximum( | |
| jnp.abs(half_width - local_pos[0]), | |
| jnp.abs(-half_width - local_pos[0]), | |
| ) | |
| dist_y = jnp.maximum( | |
| jnp.abs(half_height - local_pos[1]), | |
| jnp.abs(-half_height - local_pos[1]), | |
| ) | |
| return jnp.sqrt(dist_x * dist_x + dist_y * dist_y) | |
| def are_both_shapes_showing(idx1, idx2, local_pos1, local_pos2): | |
| def _is_small_shape_showing(small_idx, big_idx, small_local_pos, big_local_pos): | |
| small_is_poly = small_idx < static_env_params.num_polygons | |
| big_is_poly = big_idx < static_env_params.num_polygons | |
| # CC | |
| cc_result = False | |
| # CR | |
| cr_r_dist = _get_min_rect_dist(big_idx, big_local_pos) | |
| cr_result = ( | |
| cr_r_dist + ued_params.connect_visibility_min | |
| < state.circle.radius[small_idx - static_env_params.num_polygons] | |
| ) | |
| # RC | |
| rc_r_dist = _get_max_rect_dist(small_idx, small_local_pos) | |
| rc_result = ( | |
| rc_r_dist | |
| > state.circle.radius[big_idx - static_env_params.num_polygons] + ued_params.connect_visibility_min | |
| ) | |
| # RR | |
| rr_small_dist = _get_max_rect_dist(small_idx, small_local_pos) | |
| rr_big_dist = _get_min_rect_dist(big_idx, big_local_pos) | |
| rr_result = rr_small_dist > rr_big_dist + ued_params.connect_visibility_min | |
| # Select | |
| return jax.lax.select( | |
| small_is_poly, | |
| jax.lax.select(big_is_poly, rr_result, rc_result), | |
| jax.lax.select(big_is_poly, cr_result, cc_result), | |
| ) | |
| # Are both shapes showing? | |
| return _is_small_shape_showing(idx1, idx2, local_pos1, local_pos2) & _is_small_shape_showing( | |
| idx2, idx1, local_pos2, local_pos1 | |
| ) | |
| valid = are_both_shapes_showing( | |
| unified_shape_index, unified_target_index, local_joint_position, target_joint_position | |
| ) | |
| return state, valid | |
| # To add a connected shape, we must have both at least one existing shape and space | |
| return jax.lax.cond( | |
| is_space_for_shape(state) & are_there_shapes_present(state, static_env_params) & is_space_for_joint(state), | |
| do_add, | |
| do_dummy, | |
| rng, | |
| state, | |
| ) | |
| def mutate_add_connected_shape_proper( | |
| rng, | |
| state: EnvState, | |
| params: EnvParams, | |
| static_env_params: StaticEnvParams, | |
| ued_params: UEDParams, | |
| force_rjoint: bool = False, | |
| ): | |
| return mutate_add_connected_shape(rng, state, params, static_env_params, ued_params, force_rjoint=force_rjoint)[0] | |
| def mutate_remove_shape( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| can_remove_mask = ( | |
| jnp.concatenate([state.polygon.active, state.circle.active]) | |
| .at[: static_env_params.num_static_fixated_polys] | |
| .set(False) | |
| ) | |
| def dummy(rng, state): | |
| return state | |
| def do_remove(rng, state: EnvState): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 2) | |
| p = can_remove_mask.astype(jnp.float32) | |
| index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_remove_mask.shape[0]), p=p) | |
| is_rect = index_to_remove < static_env_params.num_polygons | |
| state = state.replace( | |
| polygon=state.polygon.replace( | |
| active=jax.lax.select( | |
| is_rect, state.polygon.active.at[index_to_remove].set(False), state.polygon.active | |
| ) | |
| ), | |
| circle=state.circle.replace( | |
| active=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle.active.at[index_to_remove - static_env_params.num_polygons].set(False), | |
| state.circle.active, | |
| ) | |
| ), | |
| ) | |
| # We need to now remove any joints connected to this shape | |
| joints_to_remove = (state.joint.a_index == index_to_remove) | (state.joint.b_index == index_to_remove) | |
| thrusters_to_remove = state.thruster.object_index == index_to_remove | |
| state = state.replace( | |
| joint=state.joint.replace(active=jnp.where(joints_to_remove, False, state.joint.active)), | |
| thruster=state.thruster.replace(active=jnp.where(thrusters_to_remove, False, state.thruster.active)), | |
| ) | |
| # Now recalculate collision matrix | |
| state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint)) | |
| return state | |
| return jax.lax.cond(can_remove_mask.sum() > 0, do_remove, dummy, rng, state) | |
| def mutate_remove_joint( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| can_remove_mask = state.joint.active | |
| def dummy(rng, state): | |
| return state | |
| def do_remove(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 2) | |
| p = can_remove_mask.astype(jnp.float32) | |
| index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_remove_mask.shape[0]), p=p) | |
| state = state.replace(joint=state.joint.replace(active=state.joint.active.at[index_to_remove].set(False))) | |
| # Recalculate collision matrix. | |
| state = state.replace(collision_matrix=calculate_collision_matrix(static_env_params, state.joint)) | |
| return state | |
| return jax.lax.cond(can_remove_mask.sum() > 0, do_remove, dummy, rng, state) | |
| def mutate_swap_role( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| def _cr(*args): | |
| return count_roles(*args, include_static_polys=False) | |
| role_counts = jax.vmap(_cr, (None, None, 0))(state, static_env_params, jnp.arange(4)) | |
| are_there_multiple_roles = (role_counts > 0).sum() > 1 | |
| def dummy(rng, state): | |
| return state | |
| def do_swap(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 2) | |
| all_roles = jnp.concatenate([state.polygon_shape_roles, state.circle_shape_roles]) | |
| p = ( | |
| (jnp.concatenate([state.polygon.active, state.circle.active])) | |
| .astype(jnp.float32) | |
| .at[: static_env_params.num_static_fixated_polys] | |
| .set(0.0) | |
| ) | |
| shape_idx_a = jax.random.choice( | |
| rngs[0], jnp.arange(static_env_params.num_polygons + static_env_params.num_circles), p=p | |
| ) | |
| role_a = all_roles[shape_idx_a] | |
| p = jnp.where(all_roles == role_a, 0.0, p) | |
| shape_idx_b = jax.random.choice( | |
| rngs[1], jnp.arange(static_env_params.num_polygons + static_env_params.num_circles), p=p | |
| ) | |
| role_b = all_roles[shape_idx_b] | |
| role_a, role_b = role_b, role_a | |
| for idx, role in [(shape_idx_a, role_a), (shape_idx_b, role_b)]: | |
| is_rect = idx < static_env_params.num_polygons | |
| state = state.replace( | |
| polygon_shape_roles=jax.lax.select( | |
| is_rect, state.polygon_shape_roles.at[idx].set(role), state.polygon_shape_roles | |
| ), | |
| circle_shape_roles=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle_shape_roles.at[idx - static_env_params.num_polygons].set(role), | |
| state.circle_shape_roles, | |
| ), | |
| ) | |
| return state | |
| return jax.lax.cond(are_there_multiple_roles, do_swap, dummy, rng, state) | |
| def mutate_toggle_fixture( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| can_toggle_mask = ( | |
| jnp.concatenate([state.polygon.active, state.circle.active]) | |
| .at[: static_env_params.num_static_fixated_polys] | |
| .set(False) | |
| ) | |
| def dummy(rng, state): | |
| return state | |
| def do_toggle(rng, state: EnvState): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 2) | |
| p = can_toggle_mask.astype(jnp.float32) | |
| index_to_remove = jax.random.choice(rngs[0], jnp.arange(can_toggle_mask.shape[0]), p=p) | |
| is_rect = index_to_remove < static_env_params.num_polygons | |
| is_current_fixed = ( | |
| jax.lax.select( | |
| is_rect, | |
| state.polygon.inverse_inertia[index_to_remove], | |
| state.circle.inverse_inertia[index_to_remove - static_env_params.num_polygons], | |
| ) | |
| == 0.0 | |
| ) | |
| is_current_fixed = is_current_fixed * 1.0 # if it is fixed, we set it to 1.0 and recalc. | |
| # If it is not fixed, this is 0.0, and it makes it fixed. | |
| state = state.replace( | |
| polygon=state.polygon.replace( | |
| inverse_inertia=jax.lax.select( | |
| is_rect, | |
| state.polygon.inverse_inertia.at[index_to_remove].set(is_current_fixed), | |
| state.polygon.inverse_inertia, | |
| ), | |
| inverse_mass=jax.lax.select( | |
| is_rect, | |
| state.polygon.inverse_mass.at[index_to_remove].set(is_current_fixed), | |
| state.polygon.inverse_mass, | |
| ), | |
| ), | |
| circle=state.circle.replace( | |
| inverse_inertia=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle.inverse_inertia.at[index_to_remove - static_env_params.num_polygons].set( | |
| is_current_fixed | |
| ), | |
| state.circle.inverse_inertia, | |
| ), | |
| inverse_mass=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle.inverse_mass.at[index_to_remove - static_env_params.num_polygons].set( | |
| is_current_fixed | |
| ), | |
| state.circle.inverse_mass, | |
| ), | |
| ), | |
| ) | |
| state = recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities) | |
| return state | |
| return jax.lax.cond(can_toggle_mask.sum() > 0, do_toggle, dummy, rng, state) | |
| def mutate_add_thruster( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| is_fixated = jnp.concatenate([state.polygon.inverse_mass == 0, state.circle.inverse_mass == 0]) | |
| # is_fixated = jnp.zeros_like(is_fixated, dtype=bool) | |
| is_active = jnp.concatenate([state.polygon.active, state.circle.active]) | |
| can_add_mask = is_active & (~is_fixated) | |
| can_add_mask = jnp.logical_and(is_active, jnp.logical_not(is_fixated)) | |
| def dummy(rng, state): | |
| return state | |
| def do_add(rng, state: EnvState): | |
| rng, _rng = jax.random.split(rng) | |
| _rngs = jax.random.split(_rng, 10) | |
| p = can_add_mask.astype(jnp.float32) | |
| shape_index = jax.random.choice(_rngs[0], jnp.arange(can_add_mask.shape[0]), p=p) | |
| thruster_idx = jnp.argmin(state.thruster.active) | |
| shape = select_shape(state, shape_index, static_env_params) | |
| position_to_add_thruster = jax.lax.select( | |
| shape_index < static_env_params.num_polygons, | |
| random_position_on_polygon(_rngs[1], shape.vertices, shape.n_vertices, static_env_params), | |
| random_position_on_circle(_rngs[2], shape.radius, on_centre_chance=0.0), | |
| ) | |
| direction_to_com = ((jax.random.uniform(_rngs[3]) > 0.5) * 2 - 1) * position_to_add_thruster | |
| direction_to_com = jax.lax.select( | |
| jnp.linalg.norm(direction_to_com) == 0.0, jnp.array([1.0, 0.0]), direction_to_com | |
| ) | |
| thruster_angle = jax.lax.select( | |
| jax.random.uniform(_rngs[4]) < ued_params.thruster_align_com_prob, | |
| jnp.atan2(direction_to_com[1], direction_to_com[0]), # test this | |
| jax.random.uniform( | |
| _rngs[5], | |
| (), | |
| ) | |
| * 2 | |
| * jnp.pi, | |
| ) | |
| thruster_power = jax.random.uniform(_rngs[6]) * 1.5 + 0.5 | |
| thruster = Thruster( | |
| object_index=shape_index, | |
| active=True, | |
| relative_position=position_to_add_thruster, # jnp.array([0.0, 0.0]), # a bit of a hack but reasonable. | |
| rotation=thruster_angle, # jax.random.choice(rngs[1], jnp.arange(4) * jnp.pi / 2), | |
| power=1.0 | |
| / jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass) | |
| * ued_params.thruster_power_multiplier | |
| * thruster_power, | |
| global_position=shape.position + jnp.matmul(rmat(shape.rotation), position_to_add_thruster), | |
| ) | |
| thruster_colour = jax.random.randint( | |
| _rngs[7], shape=(), minval=0, maxval=static_env_params.num_thruster_bindings | |
| ) | |
| state = state.replace( | |
| thruster=jax.tree_map(lambda y, x: y.at[thruster_idx].set(x), state.thruster, thruster), | |
| thruster_bindings=state.thruster_bindings.at[thruster_idx].set(thruster_colour), | |
| ) | |
| return state | |
| return jax.lax.cond( | |
| jnp.logical_and((can_add_mask.sum() > 0), (jnp.logical_not(state.thruster.active).sum() > 0)), | |
| do_add, | |
| dummy, | |
| rng, | |
| state, | |
| ) | |
| def mutate_change_gravity( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 2) | |
| new_gravity = jax.lax.select( | |
| jax.random.uniform(rngs[0]) < 0.5, | |
| jnp.array([0.0, -9.8]), | |
| jnp.array([0.0, jax.random.uniform(rngs[1], minval=-9.8, maxval=0)]), | |
| ) | |
| return state.replace(gravity=new_gravity) | |
| def mutate_remove_thruster( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| are_there_thrusters = state.thruster.active | |
| def dummy(rng, state): | |
| return state | |
| def do_remove(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 2) | |
| p = are_there_thrusters.astype(jnp.float32) | |
| thruster_idx = jax.random.choice(rngs[0], jnp.arange(are_there_thrusters.shape[0]), p=p) | |
| return state.replace(thruster=state.thruster.replace(active=state.thruster.active.at[thruster_idx].set(False))) | |
| return jax.lax.cond(are_there_thrusters.sum() > 0, do_remove, dummy, rng, state) | |
| def make_mutate_change_shape_size(params, static_env_params): | |
| do_dummy_step = make_do_dummy_step(params, static_env_params) | |
| def mutate_change_shape_size( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| shape_active = jnp.concatenate( | |
| [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active] | |
| ) | |
| def dummy(rng, state): | |
| return state | |
| def do_change(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 10) | |
| p = shape_active.astype(jnp.float32) | |
| shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p) | |
| is_rect = shape_idx < static_env_params.num_polygons | |
| vertices, _, radius = sample_dimensions( | |
| rngs[1], static_env_params, is_rect, ued_params, max_shape_size=ued_params.max_shape_size | |
| ) | |
| idx_new_top_left = jnp.argmin(vertices[:, 0] * 100 + vertices[:, 1]) | |
| idx_old_top_left = jnp.argmin( | |
| state.polygon.vertices[shape_idx, :, 0] * 100 + state.polygon.vertices[shape_idx, :, 1] | |
| ) | |
| scale_rect = (vertices[idx_new_top_left]) / (state.polygon.vertices[shape_idx, idx_old_top_left]) | |
| scale_circle = radius / state.circle.radius[shape_idx - static_env_params.num_polygons] | |
| vertices = state.polygon.vertices[shape_idx] * scale_rect | |
| scale = jax.lax.select( | |
| is_rect, | |
| scale_rect, | |
| jnp.array([scale_circle, scale_circle]), | |
| ) | |
| is_a = ((state.joint.a_index == shape_idx) & state.joint.active)[:, None] | |
| is_b = ((state.joint.b_index == shape_idx) & state.joint.active)[:, None] | |
| state = state.replace( | |
| joint=state.joint.replace( | |
| a_relative_pos=(state.joint.a_relative_pos * scale[None]) * is_a | |
| + (1 - is_a) * state.joint.a_relative_pos, | |
| b_relative_pos=(state.joint.b_relative_pos * scale[None]) * is_b | |
| + (1 - is_b) * state.joint.b_relative_pos, | |
| ), | |
| polygon=state.polygon.replace( | |
| vertices=jax.lax.select( | |
| is_rect, state.polygon.vertices.at[shape_idx].set(vertices), state.polygon.vertices | |
| ), | |
| ), | |
| circle=state.circle.replace( | |
| radius=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle.radius.at[shape_idx - static_env_params.num_polygons].set(radius), | |
| state.circle.radius, | |
| ) | |
| ), | |
| ) | |
| def _ss(state, _): | |
| return do_dummy_step(state), None | |
| state = jax.lax.scan(_ss, state, jnp.arange(5))[0] | |
| return recalculate_mass_and_inertia( | |
| state, static_env_params, state.polygon_densities, state.circle_densities | |
| ) | |
| return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state) | |
| return mutate_change_shape_size | |
| def mutate_change_shape_location( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| shape_active = jnp.concatenate( | |
| [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active] | |
| ) | |
| def dummy(rng, state): | |
| return state | |
| def do_change(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 10) | |
| p = shape_active.astype(jnp.float32) | |
| shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p) | |
| delta_pos = jax.random.uniform(rngs[1], shape=(2,)) - 0.5 # [-0.5, 0.5] | |
| positions = jnp.concatenate([state.polygon.position, state.circle.position]) | |
| mask_of_shape_locations_to_change = ( | |
| (state.collision_matrix[shape_idx] == 0).at[: static_env_params.num_static_fixated_polys].set(False) | |
| ) | |
| # check the new positions, but then maybe revert if any shape becomes out of bounds now. | |
| new_positions_tentative = positions * ( | |
| 1 - mask_of_shape_locations_to_change[:, None] | |
| ) + mask_of_shape_locations_to_change[:, None] * (positions + delta_pos[None]) | |
| polys = state.polygon | |
| p_pos = new_positions_tentative[: static_env_params.num_polygons] | |
| c_pos = new_positions_tentative[static_env_params.num_polygons :] # state.circle.position | |
| rad = state.circle.radius | |
| rect_vertex_mask = jnp.arange(static_env_params.max_polygon_vertices)[None] < polys.n_vertices[:, None] | |
| rect_mask = polys.active.at[: static_env_params.num_static_fixated_polys].set(False) | |
| circ_mask = state.circle.active | |
| # check if new pos maybe goes out of bounds: | |
| min_x, max_x, min_y, max_y = ( | |
| jnp.minimum( | |
| jnp.min( | |
| p_pos[:, 0] + jnp.min(polys.vertices[:, :, 0], where=rect_vertex_mask, initial=0, axis=1), | |
| where=rect_mask, | |
| initial=jnp.inf, | |
| ), | |
| jnp.min(c_pos[:, 0] - rad, where=circ_mask, initial=jnp.inf), | |
| ), | |
| jnp.maximum( | |
| jnp.max( | |
| p_pos[:, 0] + jnp.max(polys.vertices[:, :, 0], where=rect_vertex_mask, initial=0, axis=1), | |
| where=rect_mask, | |
| initial=-jnp.inf, | |
| ), | |
| jnp.max(c_pos[:, 0] + rad, where=circ_mask, initial=-jnp.inf), | |
| ), | |
| jnp.minimum( | |
| jnp.min( | |
| p_pos[:, 1] + jnp.min(polys.vertices[:, :, 1], where=rect_vertex_mask, initial=0, axis=1), | |
| where=rect_mask, | |
| initial=jnp.inf, | |
| ), | |
| jnp.min(c_pos[:, 1] - rad, where=circ_mask, initial=jnp.inf), | |
| ), | |
| jnp.maximum( | |
| jnp.max( | |
| p_pos[:, 1] + jnp.max(polys.vertices[:, :, 1], where=rect_vertex_mask, initial=0, axis=1), | |
| where=rect_mask, | |
| initial=-jnp.inf, | |
| ), | |
| jnp.max(c_pos[:, 1] + rad, where=circ_mask, initial=-jnp.inf), | |
| ), | |
| ) | |
| how_much_oob_x_left = jnp.maximum(0, 0 - min_x) | |
| how_much_oob_x_right = jnp.maximum(0, max_x - static_env_params.screen_dim[0] / params.pixels_per_unit) | |
| how_much_oob_y_down = jnp.maximum(0, 0.4 - min_y) # this is for the floor | |
| how_much_oob_y_up = jnp.maximum(0, max_y - static_env_params.screen_dim[1] / params.pixels_per_unit) | |
| # correct by out of bounds factor | |
| positions = ( | |
| new_positions_tentative | |
| + jnp.array( | |
| [ | |
| how_much_oob_x_left - how_much_oob_x_right, | |
| how_much_oob_y_down - how_much_oob_y_up, | |
| ] | |
| )[None] | |
| * mask_of_shape_locations_to_change[:, None] | |
| ) | |
| state = state.replace( | |
| polygon=state.polygon.replace( | |
| position=positions[: static_env_params.num_polygons], | |
| ), | |
| circle=state.circle.replace( | |
| position=positions[static_env_params.num_polygons :], | |
| ), | |
| ) | |
| return recalculate_mass_and_inertia(state, static_env_params, state.polygon_densities, state.circle_densities) | |
| return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state) | |
| def make_mutate_change_shape_rotation(params, static_env_params): | |
| do_dummy_step = make_do_dummy_step(params, static_env_params) | |
| def mutate_change_shape_rotation( | |
| rng, state: EnvState, params: EnvParams, static_env_params: StaticEnvParams, ued_params: UEDParams | |
| ): | |
| shape_active = jnp.concatenate( | |
| [state.polygon.active.at[: static_env_params.num_static_fixated_polys].set(False), state.circle.active] | |
| ) | |
| def dummy(rng, state): | |
| return state | |
| def do_change(rng, state): | |
| rng, _rng = jax.random.split(rng) | |
| rngs = jax.random.split(_rng, 10) | |
| p = shape_active.astype(jnp.float32) | |
| shape_idx = jax.random.choice(rngs[0], jnp.arange(shape_active.shape[0]), p=p) | |
| is_rect = shape_idx < static_env_params.num_polygons | |
| rotation_delta = jax.random.uniform(rngs[1], shape=()) * math.pi / 2 | |
| has_fixed_joint_a = (state.joint.a_index == shape_idx) & state.joint.is_fixed_joint & state.joint.active | |
| has_fixed_joint_b = (state.joint.b_index == shape_idx) & state.joint.is_fixed_joint & state.joint.active | |
| state = state.replace( | |
| joint=state.joint.replace( | |
| rotation=jax.lax.select( | |
| has_fixed_joint_a, | |
| state.joint.rotation - rotation_delta, | |
| jax.lax.select( | |
| has_fixed_joint_b, | |
| state.joint.rotation + rotation_delta, | |
| state.joint.rotation, | |
| ), | |
| ) | |
| ), | |
| polygon=state.polygon.replace( | |
| rotation=jax.lax.select( | |
| is_rect, state.polygon.rotation.at[shape_idx].add(rotation_delta), state.polygon.rotation | |
| ), | |
| ), | |
| circle=state.circle.replace( | |
| rotation=jax.lax.select( | |
| jnp.logical_not(is_rect), | |
| state.circle.rotation.at[shape_idx - static_env_params.num_polygons].add(rotation_delta), | |
| state.circle.rotation, | |
| ) | |
| ), | |
| ) | |
| def _ss(state, _): | |
| return do_dummy_step(state), None | |
| state = jax.lax.scan(_ss, state, jnp.arange(5))[0] | |
| return recalculate_mass_and_inertia( | |
| state, static_env_params, state.polygon_densities, state.circle_densities | |
| ) | |
| return jax.lax.cond(shape_active.sum() > 0, do_change, dummy, rng, state) | |
| return mutate_change_shape_rotation | |