From cabd6e87fa8bcb4d8099762ccecdf2301edea838 Mon Sep 17 00:00:00 2001 From: liangpan99 Date: Wed, 9 Aug 2023 22:59:22 +0800 Subject: [PATCH 1/2] fix obs bug after resetting envs --- ase/env/tasks/humanoid.py | 5 +++- ase/env/tasks/humanoid_amp.py | 41 ++++++++++++++++++++++++----- ase/env/tasks/humanoid_amp_getup.py | 6 +++++ ase/utils/motion_lib.py | 36 +++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 7 deletions(-) diff --git a/ase/env/tasks/humanoid.py b/ase/env/tasks/humanoid.py index 17d36101..8a67e0e8 100644 --- a/ase/env/tasks/humanoid.py +++ b/ase/env/tasks/humanoid.py @@ -116,6 +116,9 @@ def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, head self._rigid_body_vel = rigid_body_state_reshaped[..., :self.num_bodies, 7:10] self._rigid_body_ang_vel = rigid_body_state_reshaped[..., :self.num_bodies, 10:13] + self._initial_humanoid_rigid_body_states = rigid_body_state_reshaped[..., :self.num_bodies].clone() + self._initial_humanoid_rigid_body_states[..., 7:13] = 0 + contact_force_tensor = gymtorch.wrap_tensor(contact_force_tensor) self._contact_forces = contact_force_tensor.view(self.num_envs, bodies_per_env, 3)[..., :self.num_bodies, :] @@ -641,7 +644,7 @@ def compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel flat_local_body_rot_obs = torch_utils.quat_to_tan_norm(flat_local_body_rot) local_body_rot_obs = flat_local_body_rot_obs.reshape(body_rot.shape[0], body_rot.shape[1] * flat_local_body_rot_obs.shape[1]) - if (local_root_obs): + if (not local_root_obs): root_rot_obs = torch_utils.quat_to_tan_norm(root_rot) local_body_rot_obs[..., 0:6] = root_rot_obs diff --git a/ase/env/tasks/humanoid_amp.py b/ase/env/tasks/humanoid_amp.py index 09246c17..b69453b8 100644 --- a/ase/env/tasks/humanoid_amp.py +++ b/ase/env/tasks/humanoid_amp.py @@ -33,7 +33,7 @@ from isaacgym import gymapi from isaacgym import gymtorch -from env.tasks.humanoid import Humanoid, dof_to_obs +from env.tasks.humanoid import Humanoid, dof_to_obs, compute_humanoid_observations_max from utils import gym_util from utils.motion_lib import MotionLib from isaacgym.torch_utils import * @@ -73,6 +73,8 @@ def __init__(self, cfg, sim_params, physics_engine, device_type, device_id, head self._amp_obs_demo_buf = None + self._kinematic_humanoid_rigid_body_states = torch.zeros((self.num_envs, self.num_bodies, 13), device=self.device, dtype=torch.float) + return def post_physics_step(self): @@ -176,6 +178,15 @@ def _reset_actors(self, env_ids): self._reset_hybrid_state_init(env_ids) else: assert(False), "Unsupported state initialization strategy: {:s}".format(str(self._state_init)) + + if (len(self._reset_default_env_ids) > 0): + self._kinematic_humanoid_rigid_body_states[self._reset_default_env_ids] = self._initial_humanoid_rigid_body_states[self._reset_default_env_ids] + + if (len(self._reset_ref_env_ids) > 0): + body_pos, body_rot, body_vel, body_ang_vel \ + = self._motion_lib.get_motion_state_max(self._reset_ref_motion_ids, self._reset_ref_motion_times) + self._kinematic_humanoid_rigid_body_states[self._reset_ref_env_ids] = torch.cat((body_pos, body_rot, body_vel, body_ang_vel), dim=-1) + return def _reset_default(self, env_ids): @@ -283,8 +294,8 @@ def _update_hist_amp_obs(self, env_ids=None): return def _compute_amp_observations(self, env_ids=None): - key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :] if (env_ids is None): + key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :] self._curr_amp_obs_buf[:] = build_amp_observations(self._rigid_body_pos[:, 0, :], self._rigid_body_rot[:, 0, :], self._rigid_body_vel[:, 0, :], @@ -293,15 +304,33 @@ def _compute_amp_observations(self, env_ids=None): self._local_root_obs, self._root_height_obs, self._dof_obs_size, self._dof_offsets) else: - self._curr_amp_obs_buf[env_ids] = build_amp_observations(self._rigid_body_pos[env_ids][:, 0, :], - self._rigid_body_rot[env_ids][:, 0, :], - self._rigid_body_vel[env_ids][:, 0, :], - self._rigid_body_ang_vel[env_ids][:, 0, :], + kinematic_rigid_body_pos = self._kinematic_humanoid_rigid_body_states[:, :, 0:3] + key_body_pos = kinematic_rigid_body_pos[:, self._key_body_ids, :] + self._curr_amp_obs_buf[env_ids] = build_amp_observations(self._kinematic_humanoid_rigid_body_states[env_ids, 0, 0:3], + self._kinematic_humanoid_rigid_body_states[env_ids, 0, 3:7], + self._kinematic_humanoid_rigid_body_states[env_ids, 0, 7:10], + self._kinematic_humanoid_rigid_body_states[env_ids, 0, 10:13], self._dof_pos[env_ids], self._dof_vel[env_ids], key_body_pos[env_ids], self._local_root_obs, self._root_height_obs, self._dof_obs_size, self._dof_offsets) return + def _compute_humanoid_obs(self, env_ids=None): + if (env_ids is None): + body_pos = self._rigid_body_pos + body_rot = self._rigid_body_rot + body_vel = self._rigid_body_vel + body_ang_vel = self._rigid_body_ang_vel + else: + body_pos = self._kinematic_humanoid_rigid_body_states[env_ids, :, 0:3] + body_rot = self._kinematic_humanoid_rigid_body_states[env_ids, :, 3:7] + body_vel = self._kinematic_humanoid_rigid_body_states[env_ids, :, 7:10] + body_ang_vel = self._kinematic_humanoid_rigid_body_states[env_ids, :, 10:13] + + obs = compute_humanoid_observations_max(body_pos, body_rot, body_vel, body_ang_vel, self._local_root_obs, + self._root_height_obs) + + return obs ##################################################################### ###=========================jit functions=========================### diff --git a/ase/env/tasks/humanoid_amp_getup.py b/ase/env/tasks/humanoid_amp_getup.py index f0a0c653..05795f7b 100644 --- a/ase/env/tasks/humanoid_amp_getup.py +++ b/ase/env/tasks/humanoid_amp_getup.py @@ -101,6 +101,11 @@ def _generate_fall_states(self): self._fall_dof_pos = self._dof_pos.clone() self._fall_dof_vel = torch.zeros_like(self._dof_vel, device=self.device, dtype=torch.float) + rigid_body_state = gymtorch.wrap_tensor(self.gym.acquire_rigid_body_state_tensor(self.sim)) + bodies_per_env = rigid_body_state.shape[0] // self.num_envs + rigid_body_state_reshaped = rigid_body_state.view(self.num_envs, bodies_per_env, 13) + self._fall_rigid_body_states = rigid_body_state_reshaped[..., :self.num_bodies].clone() + return def _reset_actors(self, env_ids): @@ -121,6 +126,7 @@ def _reset_actors(self, env_ids): fall_ids = nonrecovery_ids[fall_mask] if (len(fall_ids) > 0): self._reset_fall_episode(fall_ids) + self._kinematic_humanoid_rigid_body_states[fall_ids] = self._fall_rigid_body_states[fall_ids] nonfall_ids = nonrecovery_ids[torch.logical_not(fall_mask)] diff --git a/ase/utils/motion_lib.py b/ase/utils/motion_lib.py index c625a383..736c9a30 100644 --- a/ase/utils/motion_lib.py +++ b/ase/utils/motion_lib.py @@ -107,6 +107,9 @@ def __init__(self, motion_file, dof_body_ids, dof_offsets, self.gravs = torch.cat([m.global_root_angular_velocity for m in motions], dim=0).float() self.dvs = torch.cat([m.dof_vels for m in motions], dim=0).float() + self.gvs = torch.cat([m.global_velocity for m in motions], dim=0).float() + self.gavs = torch.cat([m.global_angular_velocity for m in motions], dim=0).float() + lengths = self._motion_num_frames lengths_shifted = lengths.roll(1) lengths_shifted[0] = 0 @@ -199,6 +202,39 @@ def get_motion_state(self, motion_ids, motion_times): return root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos + def get_motion_state_max(self, motion_ids, motion_times): + n = len(motion_ids) + num_bodies = self._get_num_bodies() + num_key_bodies = self._key_body_ids.shape[0] + + motion_len = self._motion_lengths[motion_ids] + num_frames = self._motion_num_frames[motion_ids] + dt = self._motion_dt[motion_ids] + + frame_idx0, frame_idx1, blend = self._calc_frame_blend(motion_times, motion_len, num_frames, dt) + + f0l = frame_idx0 + self.length_starts[motion_ids] + f1l = frame_idx1 + self.length_starts[motion_ids] + + body_pos0 = self.gts[f0l] + body_pos1 = self.gts[f1l] + body_rot0 = self.grs[f0l] + body_rot1 = self.grs[f1l] + body_vel = self.gvs[f0l] + body_ang_vel = self.gavs[f0l] + + vals = [body_pos0, body_pos1, body_rot0, body_rot1, body_vel, body_ang_vel] + for v in vals: + assert v.dtype != torch.float64 + + blend = blend.unsqueeze(-1) + blend_exp = blend.unsqueeze(-1) + + body_pos = (1.0 - blend_exp) * body_pos0 + blend_exp * body_pos1 + body_rot = torch_utils.slerp(body_rot0, body_rot1, blend_exp) + + return body_pos, body_rot, body_vel, body_ang_vel + def _load_motions(self, motion_file): self._motions = [] self._motion_lengths = [] From f4c2ccf220d5ad9409e0a28f2a650ea15cc99568 Mon Sep 17 00:00:00 2001 From: liangpan99 Date: Mon, 4 Sep 2023 02:00:43 +0800 Subject: [PATCH 2/2] interpolate in reduced coordinate --- ase/utils/motion_lib.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/ase/utils/motion_lib.py b/ase/utils/motion_lib.py index 736c9a30..bcf3b56a 100644 --- a/ase/utils/motion_lib.py +++ b/ase/utils/motion_lib.py @@ -30,7 +30,7 @@ import os import yaml -from poselib.poselib.skeleton.skeleton3d import SkeletonMotion +from poselib.poselib.skeleton.skeleton3d import SkeletonMotion, SkeletonState from poselib.poselib.core.rotation3d import * from isaacgym.torch_utils import * @@ -110,6 +110,8 @@ def __init__(self, motion_file, dof_body_ids, dof_offsets, self.gvs = torch.cat([m.global_velocity for m in motions], dim=0).float() self.gavs = torch.cat([m.global_angular_velocity for m in motions], dim=0).float() + self.skeleton_tree = motions[0].skeleton_tree + lengths = self._motion_num_frames lengths_shifted = lengths.roll(1) lengths_shifted[0] = 0 @@ -216,23 +218,38 @@ def get_motion_state_max(self, motion_ids, motion_times): f0l = frame_idx0 + self.length_starts[motion_ids] f1l = frame_idx1 + self.length_starts[motion_ids] - body_pos0 = self.gts[f0l] - body_pos1 = self.gts[f1l] - body_rot0 = self.grs[f0l] - body_rot1 = self.grs[f1l] + root_pos0 = self.gts[f0l, 0] + root_pos1 = self.gts[f1l, 0] + local_rot0 = self.lrs[f0l] + local_rot1 = self.lrs[f1l] + + # velocities of rigid bodies are still incorrect, + # because the relationship between joint rotations and rigid body vel is nonlinear. body_vel = self.gvs[f0l] + body_ang_vel = self.gavs[f0l] - vals = [body_pos0, body_pos1, body_rot0, body_rot1, body_vel, body_ang_vel] + vals = [root_pos0, root_pos1, local_rot0, local_rot1, body_vel, body_ang_vel] for v in vals: assert v.dtype != torch.float64 blend = blend.unsqueeze(-1) blend_exp = blend.unsqueeze(-1) - body_pos = (1.0 - blend_exp) * body_pos0 + blend_exp * body_pos1 - body_rot = torch_utils.slerp(body_rot0, body_rot1, blend_exp) - + # interpolate in reduced coordinate + root_pos = (1.0 - blend) * root_pos0 + blend * root_pos1 + local_rot = torch_utils.slerp(local_rot0, local_rot1, blend_exp) + + # transform to maximal coordinate + new_sk_state = SkeletonState.from_rotation_and_root_translation( + self.skeleton_tree, + local_rot, + root_pos, + is_local=True + ).global_repr() + body_pos = new_sk_state.global_translation + body_rot = new_sk_state.global_rotation + return body_pos, body_rot, body_vel, body_ang_vel def _load_motions(self, motion_file):