Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing observation bug after resetting envs #53

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ase/env/tasks/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, head
self._rigid_body_vel = rigid_body_state_reshaped[..., :self.num_bodies, 7:10]
self._rigid_body_ang_vel = rigid_body_state_reshaped[..., :self.num_bodies, 10:13]

self._initial_humanoid_rigid_body_states = rigid_body_state_reshaped[..., :self.num_bodies].clone()
self._initial_humanoid_rigid_body_states[..., 7:13] = 0

contact_force_tensor = gymtorch.wrap_tensor(contact_force_tensor)
self._contact_forces = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., :self.num_bodies, :]

Expand Down Expand Up @@ -641,7 +644,7 @@ def compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel
flat_local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot)
local_body_rot_obs = flat_local_body_rot_obs.reshape(body_rot.shape[0], body_rot.shape[1] * flat_local_body_rot_obs.shape[1])

if (local_root_obs):
if (not local_root_obs):
root_rot_obs = torch_utils.quat_to_tan_norm(root_rot)
local_body_rot_obs[..., 0:6] = root_rot_obs

Expand Down
41 changes: 35 additions & 6 deletions ase/env/tasks/humanoid_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from isaacgym import gymapi
from isaacgym import gymtorch

from env.tasks.humanoid import Humanoid, dof_to_obs
from env.tasks.humanoid import Humanoid, dof_to_obs, compute_humanoid_observations_max
from utils import gym_util
from utils.motion_lib import MotionLib
from isaacgym.torch_utils import *
Expand Down Expand Up @@ -73,6 +73,8 @@ def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, head

self._amp_obs_demo_buf = None

self._kinematic_humanoid_rigid_body_states = torch.zeros((self.num_envs, self.num_bodies, 13), device=self.device, dtype=torch.float)

return

def post_physics_step(self):
Expand Down Expand Up @@ -176,6 +178,15 @@ def _reset_actors(self, env_ids):
self._reset_hybrid_state_init(env_ids)
else:
assert(False), "Unsupported state initialization strategy: {:s}".format(str(self._state_init))

if (len(self._reset_default_env_ids) > 0):
self._kinematic_humanoid_rigid_body_states[self._reset_default_env_ids] = self._initial_humanoid_rigid_body_states[self._reset_default_env_ids]

if (len(self._reset_ref_env_ids) > 0):
body_pos, body_rot, body_vel, body_ang_vel \
= self._motion_lib.get_motion_state_max(self._reset_ref_motion_ids, self._reset_ref_motion_times)
self._kinematic_humanoid_rigid_body_states[self._reset_ref_env_ids] = torch.cat((body_pos, body_rot, body_vel, body_ang_vel), dim=-1)

return

def _reset_default(self, env_ids):
Expand Down Expand Up @@ -283,8 +294,8 @@ def _update_hist_amp_obs(self, env_ids=None):
return

def _compute_amp_observations(self, env_ids=None):
key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :]
if (env_ids is None):
key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :]
self._curr_amp_obs_buf[:] = build_amp_observations(self._rigid_body_pos[:, 0, :],
self._rigid_body_rot[:, 0, :],
self._rigid_body_vel[:, 0, :],
Expand All @@ -293,15 +304,33 @@ def _compute_amp_observations(self, env_ids=None):
self._local_root_obs, self._root_height_obs,
self._dof_obs_size, self._dof_offsets)
else:
self._curr_amp_obs_buf[env_ids] = build_amp_observations(self._rigid_body_pos[env_ids][:, 0, :],
self._rigid_body_rot[env_ids][:, 0, :],
self._rigid_body_vel[env_ids][:, 0, :],
self._rigid_body_ang_vel[env_ids][:, 0, :],
kinematic_rigid_body_pos = self._kinematic_humanoid_rigid_body_states[:, :, 0:3]
key_body_pos = kinematic_rigid_body_pos[:, self._key_body_ids, :]
self._curr_amp_obs_buf[env_ids] = build_amp_observations(self._kinematic_humanoid_rigid_body_states[env_ids, 0, 0:3],
self._kinematic_humanoid_rigid_body_states[env_ids, 0, 3:7],
self._kinematic_humanoid_rigid_body_states[env_ids, 0, 7:10],
self._kinematic_humanoid_rigid_body_states[env_ids, 0, 10:13],
self._dof_pos[env_ids], self._dof_vel[env_ids], key_body_pos[env_ids],
self._local_root_obs, self._root_height_obs,
self._dof_obs_size, self._dof_offsets)
return

def _compute_humanoid_obs(self, env_ids=None):
if (env_ids is None):
body_pos = self._rigid_body_pos
body_rot = self._rigid_body_rot
body_vel = self._rigid_body_vel
body_ang_vel = self._rigid_body_ang_vel
else:
body_pos = self._kinematic_humanoid_rigid_body_states[env_ids, :, 0:3]
body_rot = self._kinematic_humanoid_rigid_body_states[env_ids, :, 3:7]
body_vel = self._kinematic_humanoid_rigid_body_states[env_ids, :, 7:10]
body_ang_vel = self._kinematic_humanoid_rigid_body_states[env_ids, :, 10:13]

obs = compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, self._local_root_obs,
self._root_height_obs)

return obs

#####################################################################
###=========================jit functions=========================###
Expand Down
6 changes: 6 additions & 0 deletions ase/env/tasks/humanoid_amp_getup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def _generate_fall_states(self):
self._fall_dof_pos = self._dof_pos.clone()
self._fall_dof_vel = torch.zeros_like(self._dof_vel, device=self.device, dtype=torch.float)

rigid_body_state = gymtorch.wrap_tensor(self.gym.acquire_rigid_body_state_tensor(self.sim))
bodies_per_env = rigid_body_state.shape[0] // self.num_envs
rigid_body_state_reshaped = rigid_body_state.view(self.num_envs, bodies_per_env, 13)
self._fall_rigid_body_states = rigid_body_state_reshaped[..., :self.num_bodies].clone()

return

def _reset_actors(self, env_ids):
Expand All @@ -121,6 +126,7 @@ def _reset_actors(self, env_ids):
fall_ids = nonrecovery_ids[fall_mask]
if (len(fall_ids) > 0):
self._reset_fall_episode(fall_ids)
self._kinematic_humanoid_rigid_body_states[fall_ids] = self._fall_rigid_body_states[fall_ids]


nonfall_ids = nonrecovery_ids[torch.logical_not(fall_mask)]
Expand Down
55 changes: 54 additions & 1 deletion ase/utils/motion_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import os
import yaml

from poselib.poselib.skeleton.skeleton3d import SkeletonMotion
from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState
from poselib.poselib.core.rotation3d import *
from isaacgym.torch_utils import *

Expand Down Expand Up @@ -107,6 +107,11 @@ def __init__(self, motion_file, dof_body_ids, dof_offsets,
self.gravs = torch.cat([m.global_root_angular_velocity for m in motions], dim=0).float()
self.dvs = torch.cat([m.dof_vels for m in motions], dim=0).float()

self.gvs = torch.cat([m.global_velocity for m in motions], dim=0).float()
self.gavs = torch.cat([m.global_angular_velocity for m in motions], dim=0).float()

self.skeleton_tree = motions[0].skeleton_tree

lengths = self._motion_num_frames
lengths_shifted = lengths.roll(1)
lengths_shifted[0] = 0
Expand Down Expand Up @@ -199,6 +204,54 @@ def get_motion_state(self, motion_ids, motion_times):

return root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos

def get_motion_state_max(self, motion_ids, motion_times):
n = len(motion_ids)
num_bodies = self._get_num_bodies()
num_key_bodies = self._key_body_ids.shape[0]

motion_len = self._motion_lengths[motion_ids]
num_frames = self._motion_num_frames[motion_ids]
dt = self._motion_dt[motion_ids]

frame_idx0, frame_idx1, blend = self._calc_frame_blend(motion_times, motion_len, num_frames, dt)

f0l = frame_idx0 + self.length_starts[motion_ids]
f1l = frame_idx1 + self.length_starts[motion_ids]

root_pos0 = self.gts[f0l, 0]
root_pos1 = self.gts[f1l, 0]
local_rot0 = self.lrs[f0l]
local_rot1 = self.lrs[f1l]

# velocities of rigid bodies are still incorrect,
# because the relationship between joint rotations and rigid body vel is nonlinear.
body_vel = self.gvs[f0l]

body_ang_vel = self.gavs[f0l]

vals = [root_pos0, root_pos1, local_rot0, local_rot1, body_vel, body_ang_vel]
for v in vals:
assert v.dtype != torch.float64

blend = blend.unsqueeze(-1)
blend_exp = blend.unsqueeze(-1)

# interpolate in reduced coordinate
root_pos = (1.0 - blend) * root_pos0 + blend * root_pos1
local_rot = torch_utils.slerp(local_rot0, local_rot1, blend_exp)

# transform to maximal coordinate
new_sk_state = SkeletonState.from_rotation_and_root_translation(
self.skeleton_tree,
local_rot,
root_pos,
is_local=True
).global_repr()
body_pos = new_sk_state.global_translation
body_rot = new_sk_state.global_rotation

return body_pos, body_rot, body_vel, body_ang_vel

def _load_motions(self, motion_file):
self._motions = []
self._motion_lengths = []
Expand Down