Skip to content

Commit d7d38a3

Browse files
committed
A* node: use dataclass (ruff)
1 parent 658db0c commit d7d38a3

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

pytools/graph.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]:
118118

119119
# {{{ a_star
120120

121+
@dataclass(frozen=True)
122+
class _AStarNode(Generic[NodeT]):
123+
state: NodeT
124+
parent: _AStarNode[NodeT] | None
125+
path_cost: float | int
126+
127+
121128
def a_star(
122129
initial_state: NodeT, goal_state: NodeT, neighbor_map: GraphT[NodeT],
123130
estimate_remaining_cost: Callable[[NodeT], float] | None = None,
@@ -135,19 +142,11 @@ def estimate_remaining_cost(x: NodeT) -> float:
135142
return 1
136143
return 0
137144

138-
class AStarNode:
139-
__slots__ = ["parent", "path_cost", "state"]
140-
141-
def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None:
142-
self.state = state
143-
self.parent = parent
144-
self.path_cost = path_cost
145-
146145
inf = float("inf")
147146
init_remcost = estimate_remaining_cost(initial_state)
148147
assert init_remcost != inf
149148

150-
queue = [(init_remcost, AStarNode(initial_state, parent=None, path_cost=0))]
149+
queue = [(init_remcost, _AStarNode(initial_state, parent=None, path_cost=0))]
151150
visited_states = set()
152151

153152
while queue:
@@ -156,7 +155,7 @@ def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None:
156155

157156
if top.state == goal_state:
158157
result = []
159-
it: AStarNode | None = top
158+
it: _AStarNode[NodeT] | None = top
160159
while it is not None:
161160
result.append(it.state)
162161
it = it.parent
@@ -174,7 +173,7 @@ def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None:
174173
estimated_path_cost = top.path_cost+step_cost+remaining_cost
175174
heappush(queue,
176175
(estimated_path_cost,
177-
AStarNode(state, top, path_cost=top.path_cost + step_cost)))
176+
_AStarNode(state, top, path_cost=top.path_cost + step_cost)))
178177

179178
raise RuntimeError("no solution")
180179

0 commit comments

Comments
 (0)