Skip to content

Commit 7a55743

Browse files
Ilya KostrikovIlya Kostrikov
Ilya Kostrikov
authored and
Ilya Kostrikov
committed
Replace mujoco_py with dm_control.
1 parent 2af8162 commit 7a55743

File tree

5 files changed

+15
-55
lines changed

5 files changed

+15
-55
lines changed

gym/envs/mujoco/mujoco_env.py

+12-50
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
import gym
1010

1111
try:
12-
import mujoco_py
12+
import dm_control.mujoco as dm_mujoco
1313
except ImportError as e:
1414
raise error.DependencyNotInstalled(
15-
"{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(
15+
"{}. (HINT: you need to install dm_control)".format(
1616
e
1717
)
1818
)
@@ -51,8 +51,8 @@ def __init__(self, model_path, frame_skip):
5151
if not path.exists(fullpath):
5252
raise OSError(f"File {fullpath} does not exist")
5353
self.frame_skip = frame_skip
54-
self.model = mujoco_py.load_model_from_path(fullpath)
55-
self.sim = mujoco_py.MjSim(self.model)
54+
self.sim = dm_mujoco.Physics.from_xml_path(fullpath)
55+
self.model = self.sim.model
5656
self.data = self.sim.data
5757
self.viewer = None
5858
self._viewers = {}
@@ -111,11 +111,10 @@ def reset(self, seed: Optional[int] = None):
111111

112112
def set_state(self, qpos, qvel):
113113
assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,)
114-
old_state = self.sim.get_state()
115-
new_state = mujoco_py.MjSimState(
116-
old_state.time, qpos, qvel, old_state.act, old_state.udd_state
117-
)
118-
self.sim.set_state(new_state)
114+
state = self.sim.get_state()
115+
state[:self.model.nq] = qpos
116+
state[self.model.nq:self.model.nq+self.model.nv] = qvel
117+
self.sim.set_state(state)
119118
self.sim.forward()
120119

121120
@property
@@ -138,55 +137,18 @@ def render(
138137
camera_id=None,
139138
camera_name=None,
140139
):
141-
if mode == "rgb_array" or mode == "depth_array":
142-
if camera_id is not None and camera_name is not None:
143-
raise ValueError(
144-
"Both `camera_id` and `camera_name` cannot be"
145-
" specified at the same time."
146-
)
147-
148-
no_camera_specified = camera_name is None and camera_id is None
149-
if no_camera_specified:
150-
camera_name = "track"
151-
152-
if camera_id is None and camera_name in self.model._camera_name2id:
153-
camera_id = self.model.camera_name2id(camera_name)
154-
155-
self._get_viewer(mode).render(width, height, camera_id=camera_id)
156-
157140
if mode == "rgb_array":
158-
# window size used for old mujoco-py:
159-
data = self._get_viewer(mode).read_pixels(width, height, depth=False)
160-
# original image is upside-down, so flip it
161-
return data[::-1, :, :]
162-
elif mode == "depth_array":
163-
self._get_viewer(mode).render(width, height)
164-
# window size used for old mujoco-py:
165-
# Extract depth part of the read_pixels() tuple
166-
data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1]
167-
# original image is upside-down, so flip it
168-
return data[::-1, :]
169-
elif mode == "human":
170-
self._get_viewer(mode).render()
141+
camera_id = camera_id or 0
142+
return self.sim.render(height, width, camera_id)
143+
else:
144+
raise NotImplemented
171145

172146
def close(self):
173147
if self.viewer is not None:
174148
# self.viewer.finish()
175149
self.viewer = None
176150
self._viewers = {}
177151

178-
def _get_viewer(self, mode):
179-
self.viewer = self._viewers.get(mode)
180-
if self.viewer is None:
181-
if mode == "human":
182-
self.viewer = mujoco_py.MjViewer(self.sim)
183-
elif mode == "rgb_array" or mode == "depth_array":
184-
self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1)
185-
186-
self.viewer_setup()
187-
self._viewers[mode] = self.viewer
188-
return self.viewer
189-
190152
def get_body_com(self, body_name):
191153
return self.data.get_body_xpos(body_name)
192154

gym/envs/mujoco/pusher.py

-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
from gym import utils
33
from gym.envs.mujoco import mujoco_env
44

5-
import mujoco_py
6-
75

86
class PusherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
97
def __init__(self):

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
ale-py~=0.7
22
opencv-python>=3.0
33
box2d-py==2.3.5
4-
mujoco_py>=1.50, <2.0
4+
dm_control>=0.0.403778684
55
scipy>=1.4.1
66
numpy>=1.18.0
77
pyglet>=1.4.0

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
"accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"],
1515
"box2d": ["box2d-py==2.3.5", "pyglet>=1.4.0"],
1616
"classic_control": ["pyglet>=1.4.0"],
17-
"mujoco": ["mujoco_py>=1.50, <2.0"],
17+
"mujoco": ["dm_control>=0.0.403778684"],
1818
"toy_text": ["scipy>=1.4.1"],
1919
"other": ["lz4>=3.1.0", "opencv-python>=3.0"],
2020
}

tests/envs/spec_list.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
skip_mujoco = not (os.environ.get("MUJOCO_KEY"))
1212
if not skip_mujoco:
1313
try:
14-
import mujoco_py
14+
import dm_control
1515
except ImportError:
1616
skip_mujoco = True
1717

0 commit comments

Comments
 (0)