Skip to content

Commit fb89740

Browse files
authored
Merge pull request #175 from clintonsteiner/NREL#166_automatically_download_road_network_if_doesnt_exist
#166 automatically download road network if doesnt exist
2 parents 28e7677 + 9a1cae1 commit fb89740

File tree

4 files changed

+78
-14
lines changed

4 files changed

+78
-14
lines changed

nrel/hive/initialization/initialize_simulation.py

+32-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import os
34
import csv
45
import functools as ft
56
import logging
@@ -37,7 +38,6 @@
3738

3839
log = logging.getLogger(__name__)
3940

40-
4141
# All initialization functions must adhere to the following type signature.
4242
# These functions are called to initialize the simulation state and the environment
4343
# and are abstracted such that an external init function can be added to enable custom
@@ -227,7 +227,10 @@ def default_init_functions() -> Iterable[InitFunction]:
227227

228228

229229
def osm_init_function(
230-
config: HiveConfig, simulation_state: SimulationState, environment: Environment
230+
config: HiveConfig,
231+
simulation_state: SimulationState,
232+
environment: Environment,
233+
cache_dir=Path.home(),
231234
) -> Tuple[SimulationState, Environment]:
232235
"""
233236
Initialize an OSMRoadNetwork and add to the simulation
@@ -240,14 +243,34 @@ def osm_init_function(
240243
241244
:raises Exception: from IOErrors parsing the road network
242245
"""
243-
if config.input_config.road_network_file is None:
244-
raise IOError("Must supply a road network file when using the osm_network")
245246

246-
road_network = OSMRoadNetwork.from_file(
247-
sim_h3_resolution=config.sim.sim_h3_resolution,
248-
road_network_file=Path(config.input_config.road_network_file),
249-
default_speed_kmph=config.network.default_speed_kmph,
250-
)
247+
if config.input_config.road_network_file:
248+
road_network = OSMRoadNetwork.from_file(
249+
sim_h3_resolution=config.sim.sim_h3_resolution,
250+
road_network_file=config.input_config.road_network_file,
251+
default_speed_kmph=config.network.default_speed_kmph,
252+
)
253+
elif config.input_config.geofence_file:
254+
try:
255+
import geopandas
256+
except ImportError as e:
257+
raise ImportError(
258+
"Must have geopandas installed if you want to load from geofence file"
259+
) from e
260+
261+
dataframe = geopandas.read_file(config.input_config.geofence_file)
262+
polygon_union = dataframe["geometry"].unary_union
263+
264+
road_network = OSMRoadNetwork.from_polygon(
265+
sim_h3_resolution=config.sim.sim_h3_resolution,
266+
default_speed_kmph=config.network.default_speed_kmph,
267+
polygon=polygon_union,
268+
cache_dir=cache_dir,
269+
)
270+
else:
271+
raise IOError(
272+
"Must supply either a road network or geofence file when using the osm_network"
273+
)
251274

252275
sim_w_osm = simulation_state._replace(road_network=road_network)
253276

nrel/hive/model/roadnetwork/osm/osm_builders.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
def osm_graph_from_polygon(polygon):
1+
from typing import TYPE_CHECKING
2+
3+
4+
def osm_graph_from_polygon(polygon, cache_folder):
25
"""
36
builds a OSM networkx graph using a shapely polygon and the osmnx package
47
"""
@@ -10,6 +13,7 @@ def osm_graph_from_polygon(polygon):
1013
) from e
1114

1215
ox.settings.all_oneway = True
16+
ox.settings.cache_folder = cache_folder
1317

1418
G = ox.graph_from_polygon(polygon, network_type="drive")
1519

nrel/hive/model/roadnetwork/osm/osm_roadnetwork.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def from_polygon(
9999
polygon,
100100
sim_h3_resolution: H3Resolution = 15,
101101
default_speed_kmph: Kmph = 40.0,
102+
cache_dir=Path.home(),
102103
) -> OSMRoadNetwork:
103104
"""
104105
Build an OSMRoadNetwork from a shapely polygon
@@ -107,7 +108,7 @@ def from_polygon(
107108
:param sim_h3_resolution: The h3 resolution of the simulation
108109
:param default_speed_kmph: The network will fill in missing speed values with this
109110
"""
110-
graph = osm_graph_from_polygon(polygon)
111+
graph = osm_graph_from_polygon(polygon, cache_dir)
111112
return OSMRoadNetwork(graph, sim_h3_resolution, default_speed_kmph)
112113

113114
@classmethod

tests/test_initialize_simulation.py

+39-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from unittest import TestCase
1+
import unittest
22

33
from nrel.hive.config.network import Network
4-
from nrel.hive.initialization.initialize_simulation import initialize, default_init_functions
4+
from nrel.hive.initialization.load import load_simulation
5+
from nrel.hive.initialization.initialize_simulation import (
6+
initialize,
7+
default_init_functions,
8+
osm_init_function,
9+
)
510
from nrel.hive.initialization.initialize_simulation_with_sampling import (
611
initialize_simulation_with_sampling,
712
)
@@ -10,7 +15,7 @@
1015
import nrel.hive.state.simulation_state.simulation_state_ops as sso
1116

1217

13-
class TestInitializeSimulation(TestCase):
18+
class TestInitializeSimulation(unittest.TestCase):
1419
def test_initialize_simulation(self):
1520
conf = mock_config().suppress_logging()
1621

@@ -79,3 +84,34 @@ def test_initialize_simulation_with_sampling(self):
7984
self.assertEqual(len(sim.vehicles), 20, "should have loaded 20 vehicles")
8085
self.assertEqual(len(sim.stations), 4, "should have loaded 4 stations")
8186
self.assertEqual(len(sim.bases), 2, "should have loaded 2 bases")
87+
88+
@unittest.skip("makes API call to OpenStreetMaps via osmnx")
89+
def test_initialize_simulation_load_from_geofence_file(self):
90+
"""Move this to long running if it gets to be out of hand"""
91+
conf = (
92+
mock_config()
93+
._replace(
94+
network=Network(
95+
network_type="osm_network",
96+
default_speed_kmph=40,
97+
)
98+
)
99+
.suppress_logging()
100+
)
101+
102+
new_input = conf.input_config._replace(
103+
geofence_file=Path(
104+
resource_filename(
105+
"nrel.hive.resources.scenarios.denver_downtown.geofence",
106+
"downtown_denver.geojson",
107+
)
108+
),
109+
road_network_file=None,
110+
)
111+
112+
conf = conf._replace(input_config=new_input)
113+
114+
runnerPayload = load_simulation(
115+
config=conf,
116+
)
117+
self.assertIsNotNone(runnerPayload.s.road_network)

0 commit comments

Comments
 (0)