Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented methods to save and restore PyBullet states. #33

Merged
merged 15 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'Quentin Gallouédec'

# The full version, including alpha/beta/rc tags
release = 'v2.0.3'
release = 'v2.0.4'


# -- General configuration ---------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Welcome to panda-gym's documentation!

usage/environments
usage/manual_control
usage/save_restore_state
usage/train_with_sb3

.. toctree::
Expand Down
31 changes: 31 additions & 0 deletions docs/usage/save_restore_state.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
.. _save_restore_states:

Save and Restore States
=======================

It is possible to save a state of the entire simulation environment. This is useful if your application requires lookahead search. Below is an example of a greedy random search.

.. code-block:: python

import gym
import panda_gym

env = gym.make("PandaReachDense-v2", render=True)
obs = env.reset()

while True:
state_id = env.save_state()
best_action = None
rew = best_rew = env.task.compute_reward(
obs["achieved_goal"], obs["desired_goal"], None)

while rew <= best_rew:
env.restore_state(state_id)
a = env.action_space.sample()
_, rew, _, _ = env.step(a)

env.restore_state(state_id)
obs, _, _, _ = env.step(a)
env.remove_state(state_id)

env.close()
14 changes: 14 additions & 0 deletions panda_gym/envs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def __init__(self, robot: PyBulletRobot, task: Task) -> None:
)
self.action_space = self.robot.action_space
self.compute_reward = self.task.compute_reward
self._saved_goal = dict()

def _get_obs(self) -> Dict[str, np.ndarray]:
robot_obs = self.robot.get_obs() # robot state
Expand All @@ -246,6 +247,19 @@ def reset(self, seed: Optional[int] = None) -> Dict[str, np.ndarray]:
self.task.reset()
return self._get_obs()

def save_state(self) -> int:
    """Snapshot the whole simulation and remember the current task goal.

    Returns:
        int: The state id assigned by PyBullet, usable with
        :meth:`restore_state` and :meth:`remove_state`.
    """
    new_id = self.sim.save_state()
    # The goal lives outside the physics engine, so it must be kept alongside
    # the PyBullet snapshot to make a restore complete.
    self._saved_goal[new_id] = self.task.goal
    return new_id

def restore_state(self, state_id: int) -> None:
    """Restore the simulation and the task goal saved under ``state_id``.

    Args:
        state_id: State id returned by :meth:`save_state`.

    Raises:
        KeyError: If ``state_id`` was not produced by :meth:`save_state`
            or has already been removed.
    """
    # Look up the goal BEFORE touching the simulator: if state_id is unknown
    # we fail fast instead of rolling back the physics state and then raising
    # with the old goal still in place (a half-restored environment).
    goal = self._saved_goal[state_id]
    self.sim.restore_state(state_id)
    self.task.goal = goal

def remove_state(self, state_id: int) -> None:
    """Delete a previously saved state and its remembered goal.

    Args:
        state_id: State id returned by :meth:`save_state`.
    """
    # Drop the bookkeeping entry first, then free the snapshot so PyBullet
    # may recycle the id for a later save_state() call.
    del self._saved_goal[state_id]
    self.sim.remove_state(state_id)

def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
self.robot.set_action(action)
self.sim.step()
Expand Down
25 changes: 25 additions & 0 deletions panda_gym/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,31 @@ def close(self) -> None:
"""Close the simulation."""
self.physics_client.disconnect()

def save_state(self) -> int:
    """Take an in-memory snapshot of the current simulation.

    Returns:
        int: The state id chosen by PyBullet — the first non-negative
        integer not already in use.
    """
    state_id = self.physics_client.saveState()
    return state_id

def restore_state(self, state_id: int) -> None:
    """Reset the simulation to a previously saved snapshot.

    Args:
        state_id: The id that :meth:`save_state` returned.
    """
    self.physics_client.restoreState(state_id)

def remove_state(self, state_id: int) -> None:
    """Discard a saved snapshot, freeing its id for reuse by save_state().

    Args:
        state_id: The id that :meth:`save_state` returned.
    """
    self.physics_client.removeState(state_id)

def render(
self,
mode: str = "human",
Expand Down
2 changes: 1 addition & 1 deletion panda_gym/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.3
2.0.4
36 changes: 36 additions & 0 deletions test/save_and_restore_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import gym
import numpy as np
import pybullet
import pytest

import panda_gym


def test_save_and_restore_state():
    """Replaying the same action from a restored state must reproduce the observation."""
    env = gym.make("PandaReach-v2")
    env.reset()

    state_id = env.save_state()

    # Step once with a sampled action.
    action = env.action_space.sample()
    obs_first, _, _, _ = env.step(action)

    # Reset, jump back to the saved state, and replay the exact same action.
    env.reset()
    env.restore_state(state_id)
    obs_second, _, _, _ = env.step(action)

    # Both rollouts started from an identical state, so every observation
    # component must be equal.
    for key in ("achieved_goal", "observation", "desired_goal"):
        assert np.all(obs_first[key] == obs_second[key])


def test_remove_state():
    """Restoring a removed state must raise a PyBullet error."""
    env = gym.make("PandaReach-v2")
    env.reset()
    saved_id = env.save_state()
    env.remove_state(saved_id)
    # The snapshot is gone, so PyBullet should reject the restore.
    with pytest.raises(pybullet.error):
        env.restore_state(saved_id)