383 lines
14 KiB
Python
383 lines
14 KiB
Python
import os
|
||
import random
|
||
import shutil
|
||
import time
|
||
from copy import deepcopy
|
||
from dataclasses import asdict, dataclass
|
||
from typing import Callable
|
||
|
||
import graphviz
|
||
import numpy as np
|
||
from matplotlib import pyplot as plt
|
||
|
||
from .chromosome import Chromosome
|
||
from .node import Node
|
||
from .types import Fitnesses, Population
|
||
|
||
type FitnessFn = Callable[[Chromosome], float]
|
||
|
||
type InitializePopulationFn = Callable[[int], Population]
|
||
type CrossoverFn = Callable[[Chromosome, Chromosome], tuple[Chromosome, Chromosome]]
|
||
type MutationFn = Callable[[Chromosome], Chromosome]
|
||
type SelectionFn = Callable[[Population, Fitnesses], Population]
|
||
|
||
|
||
@dataclass(frozen=True)
class GARunConfig:
    """Immutable settings for one run of the genetic algorithm."""

    fitness_func: FitnessFn
    crossover_fn: CrossoverFn
    mutation_fn: MutationFn
    selection_fn: SelectionFn
    init_population: Population
    # Crossover probability.
    pc: float
    # Mutation probability.
    pm: float
    # Maximum number of generations.
    max_generations: int
    # How many of the best individuals are carried unchanged into the next
    # generation.
    elitism: int = 0
    # Stop when the best result keeps repeating this many times.
    max_best_repetitions: int | None = None
    # Seed for the random number generators.
    seed: int | None = None
    # If True, search for a minimum instead of a maximum.
    minimize: bool = True
    # Numbers of the generations whose plots are saved.
    save_generations: list[int] | None = None
    # Folder the plots are written to.
    results_dir: str = "results"
    # Average-fitness threshold that stops the run.
    fitness_avg_threshold: float | None = None
    # Stop once the best fitness becomes better than this value.
    best_value_threshold: float | None = None
    # Log every generation.
    log_every_generation: bool = False

    def save(self, filename: str = "GARunConfig.txt"):
        """Write the config into ``results_dir`` as ``name: value`` lines.

        The directory is created when missing; an existing file with the
        same name is overwritten.
        """
        os.makedirs(self.results_dir, exist_ok=True)
        target = os.path.join(self.results_dir, filename)

        with open(target, "w", encoding="utf-8") as out:
            out.writelines(
                f"{key}: {value}\n" for key, value in asdict(self).items()
            )
|
||
|
||
|
||
@dataclass(frozen=True)
class Generation:
    """Immutable record describing one generation of a run."""

    # Generation counter (``genetic_algorithm`` starts counting at 1).
    number: int
    # Best individual of the generation and its fitness value.
    best: Chromosome
    best_fitness: float
    # Mean fitness over the whole population.
    avg_fitness: float
    # Full population and its fitness values; ``genetic_algorithm`` in this
    # file stores empty placeholders here.
    population: Population
    fitnesses: Fitnesses
|
||
|
||
|
||
@dataclass(frozen=True)
class GARunResult:
    """Outcome of a finished genetic-algorithm run."""

    # Number of generations that were evaluated.
    generations_count: int
    # The generation that produced the best fitness of the whole run.
    best_generation: Generation
    # One record per evaluated generation, in order.
    history: list[Generation]
    # Wall-clock duration of the run in milliseconds.
    time_ms: float

    def save(self, path: str, filename: str = "GARunResult.txt"):
        """Write a short text summary of the run to ``path/filename``.

        The summary has three lines: ``generations_count``, a condensed
        ``best_generation`` line (number, best fitness, best individual) and
        ``time_ms``. ``history`` is intentionally left out.

        Args:
            path: Directory to write into; created when missing.
            filename: Name of the summary file inside ``path``.
        """
        os.makedirs(path, exist_ok=True)
        target = os.path.join(path, filename)

        # Convert only the best generation. ``asdict(self)`` would recurse
        # into ``history`` and deep-copy every stored individual just to
        # throw the result away.
        best = asdict(self.best_generation)
        summary = (
            f"generations_count: {self.generations_count}\n"
            f"best_generation: Number: {best['number']}, "
            f"Best Fitness: {best['best_fitness']}, Best: {best['best']}\n"
            f"time_ms: {self.time_ms}\n"
        )

        with open(target, "w", encoding="utf-8") as out:
            out.write(summary)
|
||
|
||
|
||
def crossover(
    population: Population,
    pc: float,
    crossover_fn: CrossoverFn,
) -> Population:
    """Apply the crossover operator to randomly formed pairs.

    The population is copied and shuffled, then consecutive individuals are
    paired up. Each pair is replaced by the result of ``crossover_fn`` with
    probability ``pc``; otherwise both parents pass through unchanged.

    For an odd-sized population the last individual is paired with the first
    one of the shuffled copy, so that individual may take part in crossover
    twice. The result is trimmed back to the input size.
    """
    # Shuffle a copy so the caller's list keeps its order.
    mating_pool = population.copy()
    random.shuffle(mating_pool)

    size = len(mating_pool)
    offspring = []

    for left in range(0, size, 2):
        first = mating_pool[left]
        # Wraps around to index 0 when ``left`` is the last index.
        second = mating_pool[(left + 1) % size]
        if np.random.random() <= pc:
            first, second = crossover_fn(first, second)
        offspring += [first, second]

    # An odd-sized pool yields one extra child; drop it.
    return offspring[:size]
|
||
|
||
|
||
def mutation(
    population: Population, pm: float, gen_num: int, mutation_fn: MutationFn
) -> Population:
    """Mutate each individual independently with probability ``pm``.

    Returns a new list in the same order; individuals that are not picked
    are passed through as the same objects. ``gen_num`` is accepted but not
    used by this implementation.
    """
    return [
        mutation_fn(individual) if np.random.random() <= pm else individual
        for individual in population
    ]
|
||
|
||
|
||
def clear_results_directory(results_dir: str) -> None:
    """Reset ``results_dir`` to an empty directory before an experiment.

    Anything already at that path is removed recursively, then the directory
    is created again.
    """
    has_stale_results = os.path.exists(results_dir)
    if has_stale_results:
        shutil.rmtree(results_dir)

    os.makedirs(results_dir, exist_ok=True)
|
||
|
||
|
||
def eval_population(population: Population, fitness_func: FitnessFn) -> Fitnesses:
    """Score every chromosome with ``fitness_func``.

    Returns a NumPy array whose i-th entry is the fitness of
    ``population[i]``.
    """
    scores = list(map(fitness_func, population))
    return np.array(scores)
|
||
|
||
|
||
def render_tree_to_graphviz(
    node: Node, graph: graphviz.Digraph, node_id: str = "0"
) -> None:
    """Recursively add ``node`` and its subtree to a graphviz graph.

    Each tree node becomes a graph node labelled with ``node.value.name``.
    Ids encode the path from the root: child ``i`` of node ``"0"`` gets the
    id ``"0_i"``, so they are unique within the tree.
    """
    graph.node(node_id, label=node.value.name)

    for position, subtree in enumerate(node.children):
        subtree_id = f"{node_id}_{position}"
        # Emit the child's whole subtree first, then connect it to its parent.
        render_tree_to_graphviz(subtree, graph, subtree_id)
        graph.edge(node_id, subtree_id)
|
||
|
||
|
||
def save_generation(
    generation: Generation, history: list[Generation], config: GARunConfig
) -> None:
    """Render the best chromosome of ``generation`` as a tree image.

    The image is written to ``config.results_dir`` as
    ``generation_NNN.png`` (zero-padded generation number). ``history`` is
    accepted but not used here.
    """
    out_dir = config.results_dir
    os.makedirs(out_dir, exist_ok=True)

    champion = generation.best

    # Caption shown above the tree; "\\n" is a line break for graphviz.
    depth = champion.root.get_depth()
    title = (
        f"Поколение #{generation.number}\\n"
        f"Лучшая особь: {generation.best_fitness:.4f}\\n"
        f"Глубина дерева: {depth}"
    )

    # Graph that will hold the tree.
    canvas = graphviz.Digraph(comment=f"Generation {generation.number}")
    canvas.attr(rankdir="TB")  # top-to-bottom layout
    canvas.attr("node", shape="circle", style="filled", fillcolor="lightblue")
    canvas.attr(label=title, labelloc="t", fontsize="14")

    # Add the tree itself.
    render_tree_to_graphviz(champion.root, canvas)

    # Write the PNG; ``cleanup=True`` is passed so graphviz removes the
    # intermediate source file.
    target = os.path.join(out_dir, f"generation_{generation.number:03d}")
    canvas.render(target, format="png", cleanup=True)
|
||
|
||
|
||
def _is_better(candidate: float, reference: float, minimize: bool) -> bool:
    """Return True when ``candidate`` strictly beats ``reference``."""
    return bool(candidate < reference if minimize else candidate > reference)


def _pick_elites(
    population: Population, fitnesses: Fitnesses, count: int, minimize: bool
) -> list[Chromosome]:
    """Return deep copies of the ``count`` best individuals, best first."""
    if count <= 0:
        return []
    ranking = sorted(
        range(len(fitnesses)),
        key=lambda idx: fitnesses[idx],
        reverse=not minimize,
    )
    return deepcopy([population[idx] for idx in ranking[:count]])


def genetic_algorithm(config: GARunConfig) -> GARunResult:
    """Run the genetic algorithm described by ``config``.

    Each iteration evaluates the population, records a ``Generation`` in the
    history, checks the stop criteria and, unless stopping, builds the next
    population via selection -> crossover -> mutation -> elitism.

    The run stops when any of these holds:
      * ``max_generations`` generations have been evaluated;
      * the best fitness stayed exactly equal for ``max_best_repetitions``
        consecutive generations;
      * the best fitness is strictly better than ``best_value_threshold``;
      * the mean fitness is strictly better than ``fitness_avg_threshold``.

    When ``save_generations`` is set, ``results_dir`` is wiped first, the
    listed generations plus the final one are rendered, and the fitness
    history plots are written at the end.

    Args:
        config: Run settings, operators and the initial population.

    Returns:
        The run result with the full history, the best generation and the
        elapsed time in milliseconds.

    Raises:
        ValueError: If ``config.init_population`` is empty.
    """
    population = config.init_population
    # Fail early with a clear message, before touching RNG state or the
    # results directory (argmin/argmax would otherwise raise later on).
    if len(population) == 0:
        raise ValueError("init_population must contain at least one chromosome")

    if config.seed is not None:
        random.seed(config.seed)
        np.random.seed(config.seed)

    if config.save_generations:
        clear_results_directory(config.results_dir)

    start = time.perf_counter()
    history: list[Generation] = []
    best: Generation | None = None

    generation_number = 1
    best_repetitions = 0

    while True:
        # Fitness of every individual in the current population.
        fitnesses = eval_population(population, config.fitness_func)
        avg_fitness = np.mean(fitnesses)

        # Elites are carried over unchanged. The count is clamped so that an
        # out-of-range ``elitism`` cannot change the population size.
        elite_count = min(max(config.elitism, 0), len(population))
        elites = _pick_elites(population, fitnesses, elite_count, config.minimize)

        # Best individual of this generation.
        best_index = (
            int(np.argmin(fitnesses)) if config.minimize else int(np.argmax(fitnesses))
        )

        # Record the generation. The best individual is snapshotted: the
        # population objects are handed to the selection/crossover/mutation
        # callables afterwards, and those may modify chromosomes in place.
        # Population and fitnesses are stored empty to keep the history small.
        current = Generation(
            number=generation_number,
            best=deepcopy(population[best_index]),
            best_fitness=fitnesses[best_index],
            avg_fitness=float(avg_fitness),
            population=[],
            fitnesses=np.array([]),
        )
        history.append(current)

        if config.log_every_generation:
            print(
                f"Generation #{generation_number} best: {current.best_fitness},"
                f" avg: {avg_fitness}"
            )

        # Track the best generation of the whole run.
        if best is None or _is_better(
            current.best_fitness, best.best_fitness, config.minimize
        ):
            best = current

        # --- stop criteria ---
        stop_algorithm = generation_number >= config.max_generations

        if config.max_best_repetitions is not None and generation_number > 1:
            if history[-2].best_fitness == current.best_fitness:
                best_repetitions += 1
                if best_repetitions == config.max_best_repetitions:
                    stop_algorithm = True
            else:
                # The best value changed: the stagnation streak is over.
                best_repetitions = 0

        if config.best_value_threshold is not None and _is_better(
            current.best_fitness, config.best_value_threshold, config.minimize
        ):
            stop_algorithm = True

        if config.fitness_avg_threshold is not None and _is_better(
            avg_fitness, config.fitness_avg_threshold, config.minimize
        ):
            stop_algorithm = True

        # Render the requested generations and always the final one.
        if config.save_generations and (
            stop_algorithm or generation_number in config.save_generations
        ):
            save_generation(current, history, config)

        if stop_algorithm:
            break

        # Selection; the sign is flipped when minimizing so the selection
        # function always sees "higher is better" values.
        parents = config.selection_fn(
            population, -fitnesses if config.minimize else fitnesses
        )

        # Pairwise crossover.
        next_population = crossover(parents, config.pc, config.crossover_fn)

        # Mutation.
        next_population = mutation(
            next_population,
            config.pm,
            generation_number,
            config.mutation_fn,
        )

        # Make room for the elites and append them.
        population = next_population[: len(population) - elite_count] + elites

        generation_number += 1

    end = time.perf_counter()

    assert best is not None, "Best was never set"
    result = GARunResult(
        len(history),
        best,
        history,
        (end - start) * 1000.0,
    )

    # Fitness history plots are produced together with the saved generations.
    if config.save_generations:
        plot_fitness_history(result, save_dir=config.results_dir)

    return result
|
||
|
||
|
||
def _plot_fitness_series(
    generations: list[int],
    values: list[float],
    *,
    color: str,
    ylabel: str,
    title: str,
    save_path: str | None,
    saved_message: str,
) -> None:
    """Draw one fitness curve, then save it to ``save_path`` or show it."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(generations, values, linewidth=2, color=color)
    ax.set_xlabel("Поколение", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"{saved_message} {save_path}")
    else:
        plt.show()

    # Always release the figure, whether it was saved or shown.
    plt.close(fig)


def plot_fitness_history(result: GARunResult, save_dir: str | None = None) -> None:
    """Plot the best and the average fitness per generation.

    Two separate figures are produced:
      - fitness_best.png - best fitness values
      - fitness_avg.png - average fitness values

    Args:
        result: Run result whose ``history`` supplies the data points.
        save_dir: Directory for the PNG files; created when missing. When
            it is ``None`` or empty, the figures are shown instead of saved.
    """
    generations = [gen.number for gen in result.history]
    best_fitnesses = [gen.best_fitness for gen in result.history]
    avg_fitnesses = [gen.avg_fitness for gen in result.history]

    if save_dir:
        # savefig does not create missing directories.
        os.makedirs(save_dir, exist_ok=True)

    # Best-fitness plot.
    _plot_fitness_series(
        generations,
        best_fitnesses,
        color="blue",
        ylabel="Лучшее значение фитнес-функции",
        title="Лучшее значение фитнеса по поколениям",
        save_path=os.path.join(save_dir, "fitness_best.png") if save_dir else None,
        saved_message="График лучших значений сохранен в",
    )

    # Average-fitness plot.
    _plot_fitness_series(
        generations,
        avg_fitnesses,
        color="orange",
        ylabel="Среднее значение фитнес-функции",
        title="Среднее значение фитнеса по поколениям",
        save_path=os.path.join(save_dir, "fitness_avg.png") if save_dir else None,
        saved_message="График средних значений сохранен в",
    )
|