@@ -118,6 +118,13 @@ def reverse_graph(graph: GraphT[NodeT]) -> GraphT[NodeT]:
118
118
119
119
# {{{ a_star
120
120
121
+ @dataclass (frozen = True )
122
+ class _AStarNode (Generic [NodeT ]):
123
+ state : NodeT
124
+ parent : _AStarNode [NodeT ] | None
125
+ path_cost : float | int
126
+
127
+
121
128
def a_star (
122
129
initial_state : NodeT , goal_state : NodeT , neighbor_map : GraphT [NodeT ],
123
130
estimate_remaining_cost : Callable [[NodeT ], float ] | None = None ,
@@ -135,19 +142,11 @@ def estimate_remaining_cost(x: NodeT) -> float:
135
142
return 1
136
143
return 0
137
144
138
- class AStarNode :
139
- __slots__ = ["parent" , "path_cost" , "state" ]
140
-
141
- def __init__ (self , state : NodeT , parent : Any , path_cost : float ) -> None :
142
- self .state = state
143
- self .parent = parent
144
- self .path_cost = path_cost
145
-
146
145
inf = float ("inf" )
147
146
init_remcost = estimate_remaining_cost (initial_state )
148
147
assert init_remcost != inf
149
148
150
- queue = [(init_remcost , AStarNode (initial_state , parent = None , path_cost = 0 ))]
149
+ queue = [(init_remcost , _AStarNode (initial_state , parent = None , path_cost = 0 ))]
151
150
visited_states = set ()
152
151
153
152
while queue :
@@ -156,7 +155,7 @@ def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None:
156
155
157
156
if top .state == goal_state :
158
157
result = []
159
- it : AStarNode | None = top
158
+ it : _AStarNode [ NodeT ] | None = top
160
159
while it is not None :
161
160
result .append (it .state )
162
161
it = it .parent
@@ -174,7 +173,7 @@ def __init__(self, state: NodeT, parent: Any, path_cost: float) -> None:
174
173
estimated_path_cost = top .path_cost + step_cost + remaining_cost
175
174
heappush (queue ,
176
175
(estimated_path_cost ,
177
- AStarNode (state , top , path_cost = top .path_cost + step_cost )))
176
+ _AStarNode (state , top , path_cost = top .path_cost + step_cost )))
178
177
179
178
raise RuntimeError ("no solution" )
180
179
0 commit comments