#!/usr/bin/env python3
"""Generate plots from GA history for the course work report.

Reads ``history.json`` located next to this script and writes three PNG
figures into ``../report/img`` (relative to this script's directory):

* ``convergence.png`` -- best and mean objective value per generation;
* ``wer_der_scatter.png`` -- WER vs. DER of every evaluated configuration;
* ``model_frequency.png`` -- how often each model appears among the
  configurations with the highest fitness.
"""

import json
from collections import Counter
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams.update(
    {
        "font.family": "DejaVu Sans",
        "axes.grid": True,
        "grid.alpha": 0.3,
    }
)

# Maps model identifiers (as stored in the configs) to the labels shown on
# the charts. Identifiers missing from this table are displayed unchanged.
MODEL_DISPLAY = {
    "whisper-large-v3": "Whisper large-v3",
    "whisper-medium": "Whisper medium",
    "faster-whisper-large-v3": "Faster-Whisper\nlarge-v3",
    "gigaam-ctc": "GigaAM-CTC",
    "gigaam-rnnt": "GigaAM-RNN-T",
    "pyannote-3.1": "pyannote 3.1",
    "pyannote-community-1": "pyannote\nCommunity-1",
    "sortformer": "Sortformer",
}


def main() -> None:
    """Load the history file and render every report figure."""
    script_dir = Path(__file__).parent
    history_path = script_dir / "history.json"
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    data = json.loads(history_path.read_text(encoding="utf-8"))
    history = data["history"]
    all_configs = data["all_configs"]

    img_dir = script_dir.parent / "report" / "img"
    img_dir.mkdir(parents=True, exist_ok=True)

    plot_convergence(history, img_dir)
    plot_wer_der_scatter(all_configs, img_dir)
    plot_model_frequency(all_configs, img_dir)


def _save_figure(fig, img_dir: Path, filename: str) -> None:
    """Tighten the layout, write *fig* to ``img_dir / filename`` and close it."""
    out_path = img_dir / filename
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")


def _pareto_front(all_configs: list[dict]) -> list[dict]:
    """Return the configs that are not dominated in (WER, DER), by rising WER.

    Both metrics are treated as "lower is better". Sorting on ``(wer, der)``
    rather than on ``wer`` alone guarantees that, among configs sharing the
    same WER, the one with the lowest DER is visited first, so a dominated
    point with an equal WER can never end up on the front.
    """
    front: list[dict] = []
    for cfg in sorted(all_configs, key=lambda c: (c["wer"], c["der"])):
        # WER is non-decreasing here, so a config belongs to the front only
        # if it strictly improves on the DER of the last accepted point.
        if not front or cfg["der"] < front[-1]["der"]:
            front.append(cfg)
    return front


def _plot_model_counts(ax, counts: Counter, *, color: str, title: str, top_n: int) -> None:
    """Draw a horizontal bar chart of model occurrence counts on *ax*.

    Bars are ordered by descending count; models with equal counts keep the
    order in which they were first encountered.
    """
    ranked = counts.most_common()
    labels = [MODEL_DISPLAY.get(name, name) for name, _ in ranked]
    values = [count for _, count in ranked]
    ax.barh(labels, values, color=color)
    ax.set_xlabel(f"Количество в топ-{top_n}", fontsize=11)
    ax.set_title(title, fontsize=12)
    # barh draws the first item at the bottom; flip so the most frequent
    # model is shown at the top.
    ax.invert_yaxis()


def plot_convergence(history: list[dict], img_dir: Path) -> None:
    """Plot the best and mean objective value per generation.

    Args:
        history: One dict per generation with the keys ``"generation"``,
            ``"best_fitness"`` and ``"mean_fitness"``.
        img_dir: Directory that receives ``convergence.png``.
    """
    gens = [h["generation"] for h in history]
    # Fitness values are negated before plotting; the axis label presents
    # the result as a weighted error where lower is better.
    best = [-h["best_fitness"] for h in history]
    mean = [-h["mean_fitness"] for h in history]

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(gens, best, "b-o", markersize=4, linewidth=1.5, label="Лучшая особь")
    ax.plot(
        gens, mean, "r--s", markersize=3, linewidth=1.2, label="Среднее по популяции"
    )
    ax.set_xlabel("Поколение", fontsize=12)
    ax.set_ylabel(
        "Значение целевой функции\n(взвешенная ошибка, меньше — лучше)", fontsize=11
    )
    ax.legend(fontsize=11)
    _save_figure(fig, img_dir, "convergence.png")


def plot_wer_der_scatter(all_configs: list[dict], img_dir: Path) -> None:
    """Scatter every configuration by WER/DER, coloured by fitness.

    The configuration with the highest fitness is marked with a star, and
    the Pareto front is drawn as a dashed line when it has at least two
    points.

    Args:
        all_configs: Dicts with the keys ``"wer"``, ``"der"`` and
            ``"fitness"``.
        img_dir: Directory that receives ``wer_der_scatter.png``.

    Raises:
        ValueError: If *all_configs* is empty (there is no best config).
    """
    wers = [c["wer"] for c in all_configs]
    ders = [c["der"] for c in all_configs]
    fits = [c["fitness"] for c in all_configs]

    fig, ax = plt.subplots(figsize=(7, 5.5))
    sc = ax.scatter(
        wers,
        ders,
        c=fits,
        cmap="RdYlGn",
        alpha=0.7,
        edgecolors="gray",
        linewidth=0.5,
        s=40,
    )

    best = max(all_configs, key=lambda c: c["fitness"])
    ax.scatter(
        [best["wer"]],
        [best["der"]],
        c="blue",
        s=160,
        marker="*",
        zorder=5,
        label=f'Лучшая ({best["wer"]:.1f}%, {best["der"]:.1f}%)',
    )

    pareto = _pareto_front(all_configs)
    if len(pareto) > 1:
        ax.plot(
            [c["wer"] for c in pareto],
            [c["der"] for c in pareto],
            "k--",
            alpha=0.5,
            linewidth=1.2,
            label="Парето-фронт",
        )

    ax.set_xlabel("WER, %", fontsize=12)
    ax.set_ylabel("DER, %", fontsize=12)
    ax.legend(fontsize=11)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label("Фитнес", fontsize=11)
    _save_figure(fig, img_dir, "wer_der_scatter.png")


def plot_model_frequency(all_configs: list[dict], img_dir: Path) -> None:
    """Plot how often each model occurs among the top-fitness configurations.

    Up to 20 configurations with the highest fitness are considered.

    Args:
        all_configs: Dicts with the key ``"fitness"`` and a ``"config"``
            dict holding ``"transcription_model"`` and
            ``"diarization_model"``.
        img_dir: Directory that receives ``model_frequency.png``.
    """
    top_n = min(20, len(all_configs))
    top = sorted(all_configs, key=lambda c: c["fitness"], reverse=True)[:top_n]

    t_counts = Counter(c["config"]["transcription_model"] for c in top)
    d_counts = Counter(c["config"]["diarization_model"] for c in top)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.5))
    _plot_model_counts(
        ax1, t_counts, color="steelblue", title="Модели транскрибации", top_n=top_n
    )
    _plot_model_counts(
        ax2, d_counts, color="coral", title="Модели диаризации", top_n=top_n
    )
    _save_figure(fig, img_dir, "model_frequency.png")


if __name__ == "__main__":
    main()