
Commit

update examples
Co-Authored-By: Roger Creus <31919499+roger-creus@users.noreply.github.com>
yuanmingqi and roger-creus committed May 15, 2024
1 parent 1b84a34 commit 3cf0e46
Showing 4 changed files with 52 additions and 41 deletions.
examples/rlexplore/4 mixed_intrinsic_rewards.ipynb (86 changes: 47 additions & 39 deletions)
@@ -16,15 +16,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gymnasium as gym\n",
"import numpy as np\n",
"import torch as th\n",
"\n",
"from rllte.env.utils import Gymnasium2Torch\n",
"import sys\n",
"sys.path.append('../../')\n",
"from rllte.env import make_atari_env\n",
"from rllte.xplore.reward import Fabric, RE3, ICM\n",
"from rllte.agent import PPO"
]
@@ -33,82 +33,90 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**Create a fake Atari environment with image observations**"
"**Use the `Fabric` class to create a mixed intrinsic reward**"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class FakeAtari(gym.Env):\n",
" def __init__(self):\n",
" self.action_space = gym.spaces.Discrete(7)\n",
" self.observation_space = gym.spaces.Box(low=0, high=1, shape=(4, 84, 84))\n",
" self.count = 0\n",
"class TwoMixed(Fabric):\n",
" def __init__(self, m1, m2):\n",
" super().__init__(m1, m2)\n",
" \n",
" def compute(self, samples, sync):\n",
" rwd1, rwd2 = super().compute(samples, sync)\n",
"\n",
" def reset(self):\n",
" self.count = 0\n",
" return self.observation_space.sample(), {}\n",
" return rwd1 + rwd2\n",
"\n",
" def step(self, action):\n",
" self.count += 1\n",
" if self.count > 100 and np.random.rand() < 0.1:\n",
" term = trunc = True\n",
" else:\n",
" term = trunc = False\n",
" return self.observation_space.sample(), 0, term, trunc, {}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Use the `Fabric` class to create a mixed intrinsic reward**"
"class ThreeMixed(Fabric):\n",
" def __init__(self, m1, m2, m3):\n",
" super().__init__(m1, m2, m3)\n",
" \n",
" def compute(self, samples, sync):\n",
" rwd1, rwd2, rw3 = super().compute(samples, sync)\n",
"\n",
" return (rwd1 + rwd2) * rw3"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# set the parameters\n",
"device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
"n_steps = 128\n",
"n_envs = 8\n",
"# create the vectorized environments\n",
"envs = gym.vector.AsyncVectorEnv([FakeAtari for _ in range(n_envs)])\n",
"# wrap the environments to convert the observations to torch tensors\n",
"envs = Gymnasium2Torch(envs, device)\n",
"envs = make_atari_env(env_id='BreakoutNoFrameskip-v4', device=device, num_envs=8)\n",
"# create two intrinsic reward functions\n",
"irs1 = ICM(envs, device)\n",
"irs2 = RE3(envs, device)\n",
"# create the mixed intrinsic reward function\n",
"irs = Fabric([irs1, irs2])"
"irs = TwoMixed(irs1, irs2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# start the training\n",
"device = 'cuda' if th.cuda.is_available() else 'cpu'\n",
"# create the PPO agent\n",
"agent = PPO(envs, device=device)\n",
"# set the intrinsic reward module\n",
"agent.set(reward=irs)\n",
"# train the agent\n",
"agent.train(10000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "marllib",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
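
For readers skimming the diff, the updated notebook boils down to the following script: a `Fabric` subclass that sums two intrinsic rewards (ICM and RE3), attached to a PPO agent trained on vectorized Atari environments. This is a hedged consolidation of the new cells above, not the notebook itself; it assumes `rllte` is importable (the notebook instead appends '../../' to `sys.path`), and all call signatures are used exactly as they appear in the diff.

```python
# Consolidated sketch of the updated example; package paths and call signatures
# are taken from the notebook diff above, assuming `rllte` is importable.
import torch as th

from rllte.env import make_atari_env
from rllte.xplore.reward import Fabric, RE3, ICM
from rllte.agent import PPO


class TwoMixed(Fabric):
    """Mix two intrinsic rewards by summing them."""

    def __init__(self, m1, m2):
        super().__init__(m1, m2)

    def compute(self, samples, sync):
        # Fabric.compute returns one reward tensor per wrapped module
        rwd1, rwd2 = super().compute(samples, sync)
        return rwd1 + rwd2


if __name__ == "__main__":
    device = 'cuda' if th.cuda.is_available() else 'cpu'
    # create the vectorized Atari environments
    envs = make_atari_env(env_id='BreakoutNoFrameskip-v4', device=device, num_envs=8)
    # create two intrinsic reward functions and mix them
    irs = TwoMixed(ICM(envs, device), RE3(envs, device))
    # create the PPO agent, attach the mixed reward, and train
    agent = PPO(envs, device=device)
    agent.set(reward=irs)
    agent.train(10000)
```

The `ThreeMixed` variant in the notebook works the same way, except that `super().compute` returns three reward tensors, which the example combines as `(rwd1 + rwd2) * rw3`.
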
rllte/common/prototype/base_reward.py (2 changes: 1 addition & 1 deletion)
@@ -163,7 +163,7 @@ def init_normalization(self) -> None:
self.obs_norm.update(all_next_obs)
all_next_obs = []
except:
# for the normal vectorized environments
# for the normal vectorized environments, old gym output
_ = self.envs.reset()
if self.obs_norm_type == "rms":
all_next_obs = []
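
The updated comment names the API difference the `except` branch is handling: Gymnasium-style vectorized environments return an `(observations, infos)` tuple from `reset()`, while old Gym-style ones return only the observations. A minimal, hypothetical helper sketching that distinction (nothing here is rllte API):

```python
# Hypothetical helper illustrating the two reset() conventions the try/except distinguishes.
def reset_observations(envs):
    out = envs.reset()
    if isinstance(out, tuple) and len(out) == 2:
        # Gymnasium-style vectorized env: reset() returns (observations, infos)
        obs, _infos = out
    else:
        # old Gym-style vectorized env: reset() returns observations only
        obs = out
    return obs
```
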
rllte/common/prototype/on_policy_agent.py (2 changes: 1 addition & 1 deletion)
@@ -162,7 +162,7 @@ def train(
"terminateds": self.storage.terminateds,
"truncateds": self.storage.truncateds,
"next_observations": self.storage.observations[1:], # type: ignore
}
}, sync=True
)
# just plus the intrinsic rewards to the extrinsic rewards
self.storage.rewards += intrinsic_rewards.to(self.device)
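
The only functional change here is that the agent now passes `sync=True` explicitly when computing intrinsic rewards over the collected rollout. Below is a hypothetical standalone rendering of that call site; only the trailing dictionary keys, `sync=True`, and the final addition of rewards appear in the diff, while the leading keys and the helper's name and signature are assumptions for illustration.

```python
def compute_mixed_rewards(irs, storage, device):
    """Hedged sketch of how the agent invokes the reward module after this change."""
    intrinsic_rewards = irs.compute(
        samples={
            "observations": storage.observations[:-1],       # assumed key
            "actions": storage.actions,                       # assumed key
            "rewards": storage.rewards,                       # assumed key
            "terminateds": storage.terminateds,
            "truncateds": storage.truncateds,
            "next_observations": storage.observations[1:],
        },
        sync=True,  # the explicit flag added at the call site by this commit
    )
    # the intrinsic rewards are simply added to the extrinsic rewards in the rollout buffer
    storage.rewards += intrinsic_rewards.to(device)
    return storage.rewards
```
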
rllte/xplore/reward/fabric.py (3 changes: 3 additions & 0 deletions)
@@ -40,11 +40,14 @@ class Fabric(BaseReward):
"""

def __init__(self, *rewards: BaseReward) -> None:
assert len(rewards) >= 2, "At least two intrinsic rewards are needed!"
for rwd in rewards:
assert isinstance(
rwd, BaseReward
), "The input rewards must be the instance of BaseReward!"
self.rewards = list(rewards)
self.device = self.rewards[0].device


def watch(
self,
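
The new guards make `Fabric` fail fast when it is constructed from fewer than two modules or from objects that are not `BaseReward` instances. A short usage sketch, reusing the environment setup from the notebook above (the commented-out lines show the failure modes; the error text is taken from the asserts):

```python
# Behavior of the new constructor guards; envs/device setup as in the notebook above.
import torch as th

from rllte.env import make_atari_env
from rllte.xplore.reward import Fabric, RE3, ICM

device = 'cuda' if th.cuda.is_available() else 'cpu'
envs = make_atari_env(env_id='BreakoutNoFrameskip-v4', device=device, num_envs=8)
irs1, irs2 = ICM(envs, device), RE3(envs, device)

irs = Fabric(irs1, irs2)   # OK: two BaseReward instances passed directly
# Fabric(irs1)             # AssertionError: "At least two intrinsic rewards are needed!"
# Fabric(irs1, "re3")      # AssertionError: "The input rewards must be the instance of BaseReward!"
```

Since a plain Python list is neither two arguments nor a `BaseReward`, the earlier `Fabric([irs1, irs2])` style would now trip these asserts, which is presumably why the notebook switched to passing the modules directly through a `Fabric` subclass.
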
