Merge pull request #49 from jjshoots/pole_env
Add pole balancing environments
jjshoots committed Jun 30, 2024
2 parents 1c8074b + bad2839 commit be0a846
Showing 63 changed files with 3,499 additions and 1,483 deletions.
8 changes: 8 additions & 0 deletions PyFlyt/gym_envs/__init__.py
@@ -16,6 +16,14 @@
id="PyFlyt/QuadX-Gates-v2",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_gates_env:QuadXGatesEnv",
)
register(
id="PyFlyt/QuadX-Pole-Balance-v2",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_pole_balance_env:QuadXPoleBalanceEnv",
)
register(
id="PyFlyt/QuadX-Pole-Waypoints-v2",
entry_point="PyFlyt.gym_envs.quadx_envs.quadx_pole_waypoints_env:QuadXPoleWaypointsEnv",
)

# Fixedwing Envs
register(
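The two new IDs are registered the same way as the existing QuadX environments. As a point of reference, here is a minimal usage sketch (not part of this commit); it assumes the usual pattern where importing PyFlyt.gym_envs executes the register() calls above:

```python
# Illustrative only: create the newly registered environments through Gymnasium.
import gymnasium as gym

import PyFlyt.gym_envs  # noqa: F401 -- importing this module runs the register() calls

env = gym.make("PyFlyt/QuadX-Pole-Balance-v2")
obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```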
6 changes: 3 additions & 3 deletions PyFlyt/gym_envs/fixedwing_envs/fixedwing_base_env.py
@@ -231,10 +231,10 @@ def compute_attitude(
lin_vel = raw_state[2]
lin_pos = raw_state[3]

# quarternion angles
quarternion = p.getQuaternionFromEuler(ang_pos)
# quaternion angles
quaternion = p.getQuaternionFromEuler(ang_pos)

return ang_vel, ang_pos, lin_vel, lin_pos, quarternion
return ang_vel, ang_pos, lin_vel, lin_pos, quaternion

def compute_term_trunc_reward(self) -> None:
"""compute_term_trunc_reward."""
29 changes: 23 additions & 6 deletions PyFlyt/gym_envs/fixedwing_envs/fixedwing_waypoints_env.py
@@ -79,6 +79,7 @@ def __init__(
goal_reach_distance=goal_reach_distance,
goal_reach_angle=np.inf,
flight_dome_size=flight_dome_size,
min_height=0.5,
np_random=self.np_random,
)

@@ -134,22 +135,38 @@ def compute_state(self) -> None:
- "target_deltas" (Sequence)
----- list of body_frame distances to target (vector of 3/4 values)
"""
ang_vel, ang_pos, lin_vel, lin_pos, quarternion = super().compute_attitude()
ang_vel, ang_pos, lin_vel, lin_pos, quaternion = super().compute_attitude()
aux_state = super().compute_auxiliary()

# combine everything
new_state: dict[Literal["attitude", "target_deltas"], np.ndarray] = dict()
if self.angle_representation == 0:
new_state["attitude"] = np.array(
[*ang_vel, *ang_pos, *lin_vel, *lin_pos, *self.action, *aux_state]
new_state["attitude"] = np.concatenate(
[
ang_vel,
ang_pos,
lin_vel,
lin_pos,
self.action,
aux_state,
],
axis=-1,
)
elif self.angle_representation == 1:
new_state["attitude"] = np.array(
[*ang_vel, *quarternion, *lin_vel, *lin_pos, *self.action, *aux_state]
new_state["attitude"] = np.concatenate(
[
ang_vel,
quaternion,
lin_vel,
lin_pos,
self.action,
aux_state,
],
axis=-1,
)

new_state["target_deltas"] = self.waypoints.distance_to_target(
ang_pos, lin_pos, quarternion
ang_pos, lin_pos, quaternion
)
self.distance_to_immediate = float(
np.linalg.norm(new_state["target_deltas"][0])
12 changes: 6 additions & 6 deletions PyFlyt/gym_envs/quadx_envs/quadx_base_env.py
@@ -225,7 +225,7 @@ def compute_attitude(
- ang_pos (vector of 3 values)
- lin_vel (vector of 3 values)
- lin_pos (vector of 3 values)
- quarternion (vector of 4 values)
- quaternion (vector of 4 values)
"""
raw_state = self.env.state(0)

@@ -235,10 +235,10 @@
lin_vel = raw_state[2]
lin_pos = raw_state[3]

# quarternion angles
quarternion = p.getQuaternionFromEuler(ang_pos)
# quaternion angles
quaternion = p.getQuaternionFromEuler(ang_pos)

return ang_vel, ang_pos, lin_vel, lin_pos, quarternion
return ang_vel, ang_pos, lin_vel, lin_pos, quaternion

def compute_term_trunc_reward(self) -> None:
"""compute_term_trunc_reward."""
@@ -250,8 +250,8 @@ def compute_base_term_trunc_reward(self) -> None:
if self.step_count > self.max_steps:
self.truncation |= True

# collision
if np.any(self.env.contact_array):
# if anything hits the floor, basically game over
if np.any(self.env.contact_array[self.env.planeId]):
self.reward = -100.0
self.info["collision"] = True
self.termination |= True
12 changes: 6 additions & 6 deletions PyFlyt/gym_envs/quadx_envs/quadx_gates_env.py
@@ -252,11 +252,11 @@ def compute_state(self) -> None:
- "target_deltas" (Graph)
- list of body_frame distances to target (vector of 3/4 values)
"""
ang_vel, ang_pos, lin_vel, lin_pos, quarternion = super().compute_attitude()
ang_vel, ang_pos, lin_vel, lin_pos, quaternion = super().compute_attitude()
aux_state = super().compute_auxiliary()

# rotation matrix
rotation = np.array(p.getMatrixFromQuaternion(quarternion)).reshape(3, 3).T
rotation = np.array(p.getMatrixFromQuaternion(quaternion)).reshape(3, 3).T

# drone to target
target_deltas = np.matmul(rotation, (self.targets - lin_pos).T).T
@@ -267,12 +267,12 @@
Literal["attitude", "rgba_cam", "target_deltas"], np.ndarray
] = dict()
if self.angle_representation == 0:
new_state["attitude"] = np.array(
[*ang_vel, *ang_pos, *lin_vel, *lin_pos, *self.action, *aux_state]
new_state["attitude"] = np.concatenate(
[ang_vel, ang_pos, lin_vel, lin_pos, self.action, aux_state], axis=-1
)
elif self.angle_representation == 1:
new_state["attitude"] = np.array(
[*ang_vel, *quarternion, *lin_vel, *lin_pos, *self.action, *aux_state]
new_state["attitude"] = np.concatenate(
[ang_vel, quaternion, lin_vel, lin_pos, self.action, aux_state], axis=-1
)

# grab the image
18 changes: 13 additions & 5 deletions PyFlyt/gym_envs/quadx_envs/quadx_hover_env.py
@@ -95,17 +95,25 @@ def compute_state(self) -> None:
- previous_action (vector of 4 values)
- auxiliary information (vector of 4 values)
"""
ang_vel, ang_pos, lin_vel, lin_pos, quarternion = super().compute_attitude()
ang_vel, ang_pos, lin_vel, lin_pos, quaternion = super().compute_attitude()
aux_state = super().compute_auxiliary()

# combine everything
if self.angle_representation == 0:
self.state = np.array(
[*ang_vel, *ang_pos, *lin_vel, *lin_pos, *self.action, *aux_state]
self.state = np.concatenate(
[
ang_vel,
ang_pos,
lin_vel,
lin_pos,
self.action,
aux_state,
],
axis=-1,
)
elif self.angle_representation == 1:
self.state = np.array(
[*ang_vel, *quarternion, *lin_vel, *lin_pos, *self.action, *aux_state]
self.state = np.concatenate(
[ang_vel, quaternion, lin_vel, lin_pos, self.action, aux_state], axis=-1
)

def compute_term_trunc_reward(self) -> None:
189 changes: 189 additions & 0 deletions PyFlyt/gym_envs/quadx_envs/quadx_pole_balance_env.py
@@ -0,0 +1,189 @@
"""QuadX Pole Balance Environment."""
from __future__ import annotations

from typing import Any, Literal

import numpy as np
from gymnasium import spaces

from PyFlyt.gym_envs.quadx_envs.quadx_base_env import QuadXBaseEnv
from PyFlyt.gym_envs.utils.pole_handler import PoleHandler


class QuadXPoleBalanceEnv(QuadXBaseEnv):
"""Simple Hover Environment with the additional goal of keeping a pole upright.
Actions are vp, vq, vr, T, ie: angular rates and thrust.
The target is to not crash and not let the pole hit the ground for the longest time possible.
Args:
----
sparse_reward (bool): whether to use sparse rewards or not.
flight_mode (int): the flight mode of the UAV
flight_dome_size (float): size of the allowable flying area.
max_duration_seconds (float): maximum simulation time of the environment.
angle_representation (Literal["euler", "quaternion"]): can be "euler" or "quaternion".
agent_hz (int): looprate of the agent to environment interaction.
render_mode (None | Literal["human", "rgb_array"]): render_mode
render_resolution (tuple[int, int]): render_resolution.
"""

def __init__(
self,
sparse_reward: bool = False,
flight_mode: int = 0,
flight_dome_size: float = 3.0,
max_duration_seconds: float = 20.0,
angle_representation: Literal["euler", "quaternion"] = "quaternion",
agent_hz: int = 40,
render_mode: None | Literal["human", "rgb_array"] = None,
render_resolution: tuple[int, int] = (480, 480),
):
"""__init__.
Args:
----
sparse_reward (bool): whether to use sparse rewards or not.
flight_mode (int): the flight mode of the UAV
flight_dome_size (float): size of the allowable flying area.
max_duration_seconds (float): maximum simulation time of the environment.
angle_representation (Literal["euler", "quaternion"]): can be "euler" or "quaternion".
agent_hz (int): looprate of the agent to environment interaction.
render_mode (None | Literal["human", "rgb_array"]): render_mode
render_resolution (tuple[int, int]): render_resolution.
"""
super().__init__(
flight_mode=flight_mode,
flight_dome_size=flight_dome_size,
max_duration_seconds=max_duration_seconds,
angle_representation=angle_representation,
agent_hz=agent_hz,
render_mode=render_mode,
render_resolution=render_resolution,
)
# init the pole
self.pole = PoleHandler()

"""GYMNASIUM STUFF"""
# Define observation space
self.observation_space = spaces.Box(
low=-np.inf,
high=np.inf,
shape=(
self.combined_space.shape[0] + self.pole.observation_space.shape[0],
),
dtype=np.float64,
)

""" ENVIRONMENT CONSTANTS """
self.sparse_reward = sparse_reward

def reset(
self, *, seed: None | int = None, options: None | dict[str, Any] = dict()
) -> tuple[np.ndarray, dict[str, Any]]:
"""reset.
Args:
----
seed: seed to pass to the base environment.
options: None
"""
super().begin_reset(
seed,
options,
drone_options={"drone_model": "primitive_drone"},
)
self.pole.reset(p=self.env, start_location=np.array([0.0, 0.0, 1.55]))
super().end_reset(seed, options)

return self.state, self.info

def compute_state(self) -> None:
"""Computes the state of the current timestep.
This returns the observation.
- ang_vel (vector of 3 values)
- ang_pos (vector of 3/4 values)
- lin_vel (vector of 3 values)
- lin_pos (vector of 3 values)
- previous_action (vector of 4 values)
- auxiliary information (vector of 4 values)
- 12 values for the pole's positions relative to self:
------ top position XYZ
------ bottom position XYZ
------ top velocity XYZ
------ bottom velocity XYZ
- auxiliary information (vector of 4 values)
"""
# compute attitude of self
ang_vel, ang_pos, lin_vel, lin_pos, quaternion = super().compute_attitude()
aux_state = super().compute_auxiliary()
rotation = (
np.array(self.env.getMatrixFromQuaternion(quaternion)).reshape(3, 3).T
)

# compute the pole's states
(
pole_top_pos,
pole_top_vel,
pole_bot_pos,
pole_bot_vel,
) = self.pole.compute_state(
rotation=rotation,
uav_lin_pos=lin_pos,
uav_lin_vel=lin_vel,
)

# combine everything
if self.angle_representation == 0:
self.state = np.concatenate(
[
ang_vel,
ang_pos,
lin_vel,
lin_pos,
self.action,
aux_state,
pole_top_pos,
pole_bot_pos,
pole_top_vel,
pole_bot_vel,
],
axis=-1,
)
elif self.angle_representation == 1:
self.state = np.concatenate(
[
ang_vel,
quaternion,
lin_vel,
lin_pos,
self.action,
aux_state,
pole_top_pos,
pole_bot_pos,
pole_top_vel,
pole_bot_vel,
],
axis=-1,
)

def compute_term_trunc_reward(self) -> None:
"""Computes the termination, truncation, and reward of the current timestep."""
super().compute_base_term_trunc_reward()

if not self.sparse_reward:
# distance from 0, 0, 1 hover point
linear_distance = np.linalg.norm(
self.env.state(0)[-1] - np.array([0.0, 0.0, 1.0])
)

# how far are we from 0 roll pitch
angular_distance = np.linalg.norm(self.env.state(0)[1][:2])

self.reward -= linear_distance + angular_distance
self.reward -= self.pole.leaningness
self.reward += 1.0
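Putting the pieces of the new file together, here is a rough rollout sketch (illustrative, not from this commit); it assumes the Gymnasium-style reset()/step() API provided by QuadXBaseEnv and uses only constructor arguments shown in the diff above:

```python
# Illustrative sketch: random-action rollout of the new QuadXPoleBalanceEnv.
from PyFlyt.gym_envs.quadx_envs.quadx_pole_balance_env import QuadXPoleBalanceEnv

env = QuadXPoleBalanceEnv(sparse_reward=False, angle_representation="quaternion")
obs, info = env.reset(seed=0)

episode_reward = 0.0
terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()  # random vp, vq, vr, T command
    obs, reward, terminated, truncated, info = env.step(action)
    episode_reward += reward

print(f"obs size: {obs.shape[0]}, episode reward: {episode_reward:.1f}")
env.close()
```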