
Commit bdc9ae1

louixp and qgallouedec authored
Implemented methods to save and restore PyBullet states. (#33)
* Implemented methods to save and restore PyBullet states.
* Fixed typos.
* Added docs for save_state() and remove_state().
* Make save and restore state docs visible in index.
* Added unit tests for save and restore states.
* Added unit test for remove state.
* Fixed save and restore test logic.
* isort and black
* Test for desired goal consistency during state saving and restoring.
* Save and restore task goal.
* Run linting.
* `p` to `self.physics_client`
* fix docstring style
* Update version

Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 2f634e2 commit bdc9ae1

File tree

7 files changed: +109 −2 lines changed


docs/conf.py (+1 −1)

@@ -22,7 +22,7 @@
 author = 'Quentin Gallouédec'
 
 # The full version, including alpha/beta/rc tags
-release = 'v2.0.3'
+release = 'v2.0.4'
 
 
 # -- General configuration ---------------------------------------------------

docs/index.rst (+1 −0)

@@ -19,6 +19,7 @@ Welcome to panda-gym's documentation!
 
    usage/environments
    usage/manual_control
+   usage/save_restore_state
    usage/train_with_sb3
 
 .. toctree::

docs/usage/save_restore_state.rst (new file, +31)

@@ -0,0 +1,31 @@
+.. _save_restore_states:
+
+Save and Restore States
+==============
+
+It is possible to save a state of the entire simulation environment. This is useful if your application requires lookahead search. Below is an example of a greedy random search.
+
+.. code-block:: python
+
+    import gym
+    import panda_gym
+
+    env = gym.make("PandaReachDense-v2", render=True)
+    obs = env.reset()
+
+    while True:
+        state_id = env.save_state()
+        best_action = None
+        rew = best_rew = env.task.compute_reward(
+            obs["achieved_goal"], obs["desired_goal"], None)
+
+        while rew <= best_rew:
+            env.restore_state(state_id)
+            a = env.action_space.sample()
+            _, rew, _, _ = env.step(a)
+
+        env.restore_state(state_id)
+        obs, _, _, _ = env.step(a)
+        env.remove_state(state_id)
+
+    env.close()

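A note on the example above: each call to `save_state()` creates a new in-memory PyBullet snapshot, which is why the loop calls `remove_state(state_id)` once the chosen action has been replayed; without it, snapshots would accumulate for as long as the search runs.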
panda_gym/envs/core.py (+14 −0)

@@ -227,6 +227,7 @@ def __init__(self, robot: PyBulletRobot, task: Task) -> None:
         )
         self.action_space = self.robot.action_space
         self.compute_reward = self.task.compute_reward
+        self._saved_goal = dict()
 
     def _get_obs(self) -> Dict[str, np.ndarray]:
         robot_obs = self.robot.get_obs()  # robot state
@@ -246,6 +247,19 @@ def reset(self, seed: Optional[int] = None) -> Dict[str, np.ndarray]:
         self.task.reset()
         return self._get_obs()
 
+    def save_state(self) -> int:
+        state_id = self.sim.save_state()
+        self._saved_goal[state_id] = self.task.goal
+        return state_id
+
+    def restore_state(self, state_id: int) -> None:
+        self.sim.restore_state(state_id)
+        self.task.goal = self._saved_goal[state_id]
+
+    def remove_state(self, state_id: int) -> None:
+        self._saved_goal.pop(state_id)
+        self.sim.remove_state(state_id)
+
     def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]:
         self.robot.set_action(action)
         self.sim.step()

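At the environment level (diff above), `save_state()` snapshots both the PyBullet simulation and the current task goal (stored in `self._saved_goal`), so restoring a state also restores the desired goal. A minimal sketch of the intended round trip, using the `PandaReach-v2` id that the tests below use:

    import gym
    import panda_gym

    env = gym.make("PandaReach-v2")
    obs = env.reset()

    state_id = env.save_state()           # snapshot the simulation and the task goal
    env.step(env.action_space.sample())   # perturb the simulation
    env.restore_state(state_id)           # rewind: robot state and desired goal match the snapshot
    env.remove_state(state_id)            # free the snapshot once it is no longer needed
    env.close()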
panda_gym/pybullet.py (+25 −0)

@@ -57,6 +57,31 @@ def close(self) -> None:
         """Close the simulation."""
         self.physics_client.disconnect()
 
+    def save_state(self) -> int:
+        """Save the current simulation state.
+
+        Returns:
+            int: A state id assigned by PyBullet, which is the first non-negative
+            integer available for indexing.
+        """
+        return self.physics_client.saveState()
+
+    def restore_state(self, state_id: int) -> None:
+        """Restore a simulation state.
+
+        Args:
+            state_id: The simulation state id returned by save_state().
+        """
+        self.physics_client.restoreState(state_id)
+
+    def remove_state(self, state_id: int) -> None:
+        """Remove a simulation state. This will make this state_id available again for returning in save_state().
+
+        Args:
+            state_id: The simulation state id returned by save_state().
+        """
+        self.physics_client.removeState(state_id)
+
     def render(
         self,
         mode: str = "human",

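The wrapper methods above delegate to the PyBullet client's `saveState()`, `restoreState()` and `removeState()`, which manage in-memory snapshots. A minimal sketch of those underlying calls outside panda-gym (the `r2d2.urdf` body from `pybullet_data` is only a convenient example asset):

    import pybullet as p
    import pybullet_data

    p.connect(p.DIRECT)
    p.setAdditionalSearchPath(pybullet_data.getDataPath())
    p.loadURDF("r2d2.urdf", [0.0, 0.0, 1.0])

    state_id = p.saveState()      # in-memory snapshot; returns an integer id
    for _ in range(100):
        p.stepSimulation()        # let the body fall for a while
    p.restoreState(state_id)      # rewind the simulation to the snapshot
    p.removeState(state_id)       # free the snapshot (its id may be reused)
    p.disconnect()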
panda_gym/version.txt (+1 −1)

@@ -1 +1 @@
-2.0.3
+2.0.4

test/save_and_restore_test.py (new file, +36)

@@ -0,0 +1,36 @@
+import gym
+import numpy as np
+import pybullet
+import pytest
+
+import panda_gym
+
+
+def test_save_and_restore_state():
+    env = gym.make("PandaReach-v2")
+    env.reset()
+
+    state_id = env.save_state()
+
+    # Perform the action
+    action = env.action_space.sample()
+    next_obs1, reward, done, info = env.step(action)
+
+    # Restore and perform the same action
+    env.reset()
+    env.restore_state(state_id)
+    next_obs2, reward, done, info = env.step(action)
+
+    # The observations in both cases should be equals
+    assert np.all(next_obs1["achieved_goal"] == next_obs2["achieved_goal"])
+    assert np.all(next_obs1["observation"] == next_obs2["observation"])
+    assert np.all(next_obs1["desired_goal"] == next_obs2["desired_goal"])
+
+
+def test_remove_state():
+    env = gym.make("PandaReach-v2")
+    env.reset()
+    state_id = env.save_state()
+    env.remove_state(state_id)
+    with pytest.raises(pybullet.error):
+        env.restore_state(state_id)

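The first test checks that restoring a saved state and replaying the same action reproduces the observation, achieved goal and desired goal; the second checks that a removed state id can no longer be restored. They would typically be run with pytest, e.g. `pytest test/save_and_restore_test.py`.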