Skip to content

Commit

Permalink
Allow -1 num_scenarios (#821)
Browse files Browse the repository at this point in the history
* Allow config['num_scenarios']=-1

* Allow config['num_scenarios']=-1
  • Loading branch information
pengzhenghao authored Mar 5, 2025
1 parent 13befc8 commit c29cc37
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
5 changes: 4 additions & 1 deletion metadrive/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def __init__(self, config: dict = None):

# scenarios
self.start_index = 0
self.num_scenarios = self.config["num_scenarios"]

def _post_process_config(self, config):
"""Add more special process to merged config"""
Expand Down Expand Up @@ -690,6 +689,10 @@ def seed(self, seed=None):
def current_seed(self):
return self.engine.global_random_seed

@property
def num_scenarios(self):
return self.config["num_scenarios"]

@property
def observations(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion metadrive/envs/scenario_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
# ===== Scenario Config =====
data_directory=AssetLoader.file_path("nuscenes", unix_style=False),
start_scenario_index=0,

# Set num_scenarios=-1 to load all scenarios in the data directory.
num_scenarios=3,

sequential_seed=False, # Whether to set seed (the index of map) sequentially across episodes
worker_index=0, # Allowing multi-worker sampling with Rllib
num_workers=1, # Allowing multi-worker sampling with Rllib
Expand Down Expand Up @@ -120,7 +123,6 @@ def __init__(self, config=None):
assert self.config["sequential_seed"], \
"If using > 1 workers, you have to allow sequential_seed for consistency!"
self.start_index = self.config["start_scenario_index"]
self.num_scenarios = self.config["num_scenarios"]

def _post_process_config(self, config):
config = super(ScenarioEnv, self)._post_process_config(config)
Expand Down
24 changes: 16 additions & 8 deletions metadrive/manager/scenario_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,23 @@ def __init__(self):

# for multi-worker
self.worker_index = self.engine.global_config["worker_index"]
self.available_scenario_indices = [
i for i in range(
self.start_scenario_index + self.worker_index, self.start_scenario_index +
self.num_scenarios, self.engine.global_config["num_workers"]
)
]
self._scenarios = {}

# Read summary file first:
self.summary_dict, self.summary_lookup, self.mapping = read_dataset_summary(self.directory)
self.summary_lookup[:self.start_scenario_index] = [None] * self.start_scenario_index
end_idx = self.start_scenario_index + self.num_scenarios
self.summary_lookup[end_idx:] = [None] * (len(self.summary_lookup) - end_idx)

# sort scenario for curriculum training
self.scenario_difficulty = None
self.sort_scenarios()

if self.num_scenarios == -1:
self.num_scenarios = len(self.summary_lookup) - self.start_scenario_index
engine.global_config["num_scenarios"] = self.num_scenarios

end_idx = self.start_scenario_index + self.num_scenarios
self.summary_lookup[end_idx:] = [None] * (len(self.summary_lookup) - end_idx)

# existence check
assert self.start_scenario_index < len(self.summary_lookup), "Insufficient scenarios!"
assert self.start_scenario_index + self.num_scenarios <= len(self.summary_lookup), \
Expand All @@ -55,6 +54,15 @@ def __init__(self):
# stat
self.coverage = [0 for _ in range(self.num_scenarios)]

@property
def available_scenario_indices(self):
return list(
range(
self.start_scenario_index + self.worker_index, self.start_scenario_index + self.num_scenarios,
self.engine.global_config["num_workers"]
)
)

@property
def current_scenario_summary(self):
return self.current_scenario[SD.METADATA]
Expand Down

0 comments on commit c29cc37

Please sign in to comment.