@@ -264,11 +264,12 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
264264 print (f"{ GREEN } Step 1: Node Selection{ RESET } " )
265265 await self .websocket_step_start (step = 1 , step_name = "node_selection" , websocket = websocket )
266266 selected_node = await self .node_selection (self .root_node , websocket )
267- tree_data = self ._get_tree_data ()
268- if websocket :
269- await self .websocket_tree_update (type = "tree_update_node_selection" , websocket = websocket , tree_data = tree_data )
270- else :
271- print_entire_tree (self .root_node )
267+ # await self.websocket_node_selection(selected_node, websocket=websocket)
268+ # tree_data = self._get_tree_data()
269+ # if websocket:
270+ # await self.websocket_tree_update(type="tree_update_node_selection", websocket=websocket, tree_data=tree_data)
271+ # else:
272+ # print_entire_tree(self.root_node)
272273
273274 if selected_node is None :
274275 logger .warning ("All paths lead to terminal nodes. Ending search." )
@@ -338,10 +339,28 @@ async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
338339 for node in path :
339340 old_value = node .value
340341 node .visits += 1
341- node .value = (node .value * (node .visits - 1 ) + score ) / node .visits
342+ node .value += (score - node .value ) / node .visits
343+ # consiste with lats backpropagation
344+ #node.value = (node.value * (node.visits - 1) + score) / node.visits
342345 print (f"Node { node .action } :" )
343346 print (f" Visits: { node .visits } " )
344347 print (f" Value: { old_value :.3f} -> { node .value :.3f} " )
348+ # add websocket information, just use websocket here
349+ # if websocket:
350+ # await websocket.send_json({
351+ # "type": "backpropagation",
352+ # "node_id": id(node),
353+ # "node_parent_id": id(node.parent),
354+ # "node_action": node.action,
355+ # "node_value": node.value,
356+ # "node_visits": node.visits,
357+ # "node_old_value": old_value,
358+ # "node_description": node.natural_language_description,
359+ # })
360+
361+ tree_data = self ._get_tree_data ()
362+ print_entire_tree (self .root_node )
363+ print (tree_data )
345364 if websocket :
346365 await self .websocket_tree_update (type = "tree_update_node_backpropagation" , websocket = websocket , tree_data = tree_data )
347366 else :
0 commit comments