From efd692169feabf92e605afd005fc6ed6f819d2f3 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Mon, 20 Jun 2022 09:59:56 +0200 Subject: [PATCH] SISGraph for_all_nodes more efficient single-threaded --- sisyphus/graph.py | 110 +++++++++++++++++++++++++--------------------- 1 file changed, 59 insertions(+), 51 deletions(-) diff --git a/sisyphus/graph.py b/sisyphus/graph.py index 6f91ed1..4669b45 100644 --- a/sisyphus/graph.py +++ b/sisyphus/graph.py @@ -489,69 +489,77 @@ def for_all_nodes(self, f, nodes=None, bottom_up=False): if path.creator: nodes.append(path.creator) - visited = {} - finished = 0 - if gs.GRAPH_WORKER == 1: - # Run in main thread if only one graph worker is given anyway - def runner(job): - """ - :param Job job: - """ - # make sure all inputs are updated + visited_set = set() + visited_list = [] + queue = list(reversed(nodes)) + while queue: + job = queue.pop(-1) + if id(job) in visited_set: + continue + visited_set.add(id(job)) job._sis_runnable() if bottom_up: - for path in job._sis_inputs: - if path.creator: - runner(path.creator) - f(job) + # execute in reverse order at the end + visited_list.append(job) else: res = f(job) # Stop if function has a not None but false return value - if res is None or res: - for path in job._sis_inputs: - if path.creator: - runner(path.creator) + if res is not None and not res: + continue - else: - pool_lock = threading.Lock() - finished_lock = threading.Lock() - pool = self.pool - - # recursive function to run through tree - def runner(job): - """ - :param Job job: - """ - sis_id = job._sis_id() - with pool_lock: - if sis_id not in visited: - visited[sis_id] = pool.apply_async( - tools.default_handle_exception_interrupt_main_thread(runner_helper), (job,)) - - def runner_helper(job): - """ - :param Job job: - """ - # make sure all inputs are updated - job._sis_runnable() - nonlocal finished + for path in job._sis_inputs: + if path.creator: + if id(path.creator) not in visited_set: + queue.append(path.creator) - if bottom_up: + if bottom_up: + for job in reversed(visited_list): + f(job) + + return visited_set + + visited = {} + finished = 0 + + pool_lock = threading.Lock() + finished_lock = threading.Lock() + pool = self.pool + + # recursive function to run through tree + def runner(job): + """ + :param Job job: + """ + sis_id = job._sis_id() + with pool_lock: + if sis_id not in visited: + visited[sis_id] = pool.apply_async( + tools.default_handle_exception_interrupt_main_thread(runner_helper), (job,)) + + def runner_helper(job): + """ + :param Job job: + """ + # make sure all inputs are updated + job._sis_runnable() + nonlocal finished + + if bottom_up: + for path in job._sis_inputs: + if path.creator: + runner(path.creator) + f(job) + else: + res = f(job) + # Stop if function has a not None but false return value + if res is None or res: for path in job._sis_inputs: if path.creator: runner(path.creator) - f(job) - else: - res = f(job) - # Stop if function has a not None but false return value - if res is None or res: - for path in job._sis_inputs: - if path.creator: - runner(path.creator) - with finished_lock: - finished += 1 + with finished_lock: + finished += 1 for node in nodes: runner(node)