Files
genetic-algorithms/lab4/gp/node.py
2025-11-07 00:11:02 +03:00

119 lines
4.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import random
from typing import Sequence
from .primitive import Primitive
from .types import Context, Value
class Node:
def __init__(self, value: Primitive):
self.value = value
self.parent: Node | None = None
self.children: list[Node] = []
def add_child(self, child: Node) -> None:
self.children.append(child)
child.parent = self
def remove_child(self, child: Node) -> None:
self.children.remove(child)
child.parent = None
def replace_child(self, old_child: Node, new_child: Node) -> None:
self.children[self.children.index(old_child)] = new_child
old_child.parent = None
new_child.parent = self
def remove_children(self) -> None:
for child in self.children:
child.parent = None
self.children = []
def copy_subtree(self) -> Node:
node = Node(self.value)
for child in self.children:
node.add_child(child.copy_subtree())
return node
def list_nodes(self) -> list[Node]:
"""Список всех узлов поддерева, начиная с текущего (aka depth-first-search)."""
nodes: list[Node] = [self]
for child in self.children:
nodes.extend(child.list_nodes())
return nodes
def prune(self, terminals: Sequence[Primitive], max_depth: int) -> None:
"""Усечение поддерева до заданной глубины.
Заменяет операции на глубине max_depth на случайные терминалы.
"""
def prune_recursive(node: Node, current_depth: int) -> None:
if node.value.arity == 0: # Терминалы остаются без изменений
return
if current_depth >= max_depth:
node.remove_children()
node.value = random.choice(terminals)
return
for child in node.children:
prune_recursive(child, current_depth + 1)
prune_recursive(self, 1)
def get_depth(self) -> int:
"""Вычисляет глубину поддерева, начиная с текущего узла."""
return (
max(child.get_depth() for child in self.children) + 1
if self.children
else 1
)
def get_size(self) -> int:
"""Вычисляет размер поддерева, начиная с текущего узла."""
return sum(child.get_size() for child in self.children) + 1
def get_level(self) -> int:
"""Вычисляет уровень узла в дереве (расстояние от корня). Корень имеет уровень 1."""
return self.parent.get_level() + 1 if self.parent else 1
def eval(self, context: Context) -> Value:
return self.value.eval(
[child.eval(context) for child in self.children], context
)
def __str__(self) -> str:
"""Рекурсивный перевод древовидного вида формулы в строку в инфиксной форме."""
if self.value.arity == 0:
return self.value.name
if self.value.arity == 2:
return f"({self.children[0]} {self.value.name} {self.children[1]})"
return f"{self.value.name}({', '.join(str(child) for child in self.children)})"
def to_str_tree(self, prefix="", is_last: bool = True) -> str:
"""Строковое представление древовидной структуры."""
lines = prefix + ("└── " if is_last else "├── ") + self.value.name + "\n"
child_prefix = prefix + (" " if is_last else "")
for i, child in enumerate(self.children):
is_child_last = i == len(self.children) - 1
lines += child.to_str_tree(child_prefix, is_child_last)
return lines
def swap_subtrees(a: Node, b: Node) -> None:
if a.parent is None or b.parent is None:
raise ValueError("Нельзя обменять корни деревьев")
# Сохраняем ссылки на родителей
a_parent = a.parent
b_parent = b.parent
i = a_parent.children.index(a)
j = b_parent.children.index(b)
a_parent.children[i], b_parent.children[j] = b, a
a.parent, b.parent = b_parent, a_parent