га
This commit is contained in:
154
ga/generate_plots.py
Normal file
154
ga/generate_plots.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate plots from GA history for the course work report."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Global Matplotlib defaults applied to every figure this script produces.
# NOTE(review): the font is presumably pinned so the Cyrillic labels render
# the same on every machine -- confirm.
matplotlib.rcParams["font.family"] = "DejaVu Sans"
# Draw a faint grid on all axes.
matplotlib.rcParams["axes.grid"] = True
matplotlib.rcParams["grid.alpha"] = 0.3
|
||||
|
||||
# Maps model identifiers found in the GA configs to the labels shown on plot
# axes. An embedded "\n" splits a long label across two lines. Identifiers
# missing from this table are displayed unchanged by the lookup code.
MODEL_DISPLAY: dict[str, str] = {
    # Whisper family
    "whisper-large-v3": "Whisper large-v3",
    "whisper-medium": "Whisper medium",
    "faster-whisper-large-v3": "Faster-Whisper\nlarge-v3",
    # GigaAM family
    "gigaam-ctc": "GigaAM-CTC",
    "gigaam-rnnt": "GigaAM-RNN-T",
    # pyannote / Sortformer
    "pyannote-3.1": "pyannote 3.1",
    "pyannote-community-1": "pyannote\nCommunity-1",
    "sortformer": "Sortformer",
}
|
||||
|
||||
|
||||
def main():
    """Load the GA run history and render all report figures.

    Reads ``history.json`` from the directory containing this script and
    writes the convergence, WER/DER scatter and model-frequency plots into
    ``report/img`` one level above that directory (created if missing).

    Raises:
        FileNotFoundError: if ``history.json`` does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if the JSON lacks the ``history`` or ``all_configs`` key.
    """
    script_dir = Path(__file__).parent
    history_path = script_dir / "history.json"
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which can break on non-ASCII content.
    data = json.loads(history_path.read_text(encoding="utf-8"))

    history = data["history"]
    all_configs = data["all_configs"]

    img_dir = script_dir.parent / "report" / "img"
    img_dir.mkdir(parents=True, exist_ok=True)

    plot_convergence(history, img_dir)
    plot_wer_der_scatter(all_configs, img_dir)
    plot_model_frequency(all_configs, img_dir)
|
||||
|
||||
|
||||
def plot_convergence(history: list[dict], img_dir: Path):
    """Plot the best and mean objective value per generation.

    Args:
        history: per-generation records; each must provide the keys
            ``generation``, ``best_fitness`` and ``mean_fitness``.
        img_dir: directory that receives ``convergence.png``.
    """
    out_path = img_dir / "convergence.png"

    generations = [record["generation"] for record in history]
    # Stored fitness values are negated before plotting so the curves match
    # the y-axis label (an error value where lower is better).
    best_curve, mean_curve = (
        [-record[key] for record in history]
        for key in ("best_fitness", "mean_fitness")
    )

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(
        generations,
        best_curve,
        "b-o",
        markersize=4,
        linewidth=1.5,
        label="Лучшая особь",
    )
    ax.plot(
        generations,
        mean_curve,
        "r--s",
        markersize=3,
        linewidth=1.2,
        label="Среднее по популяции",
    )

    ax.set_xlabel("Поколение", fontsize=12)
    ax.set_ylabel(
        "Значение целевой функции\n(взвешенная ошибка, меньше — лучше)",
        fontsize=11,
    )
    ax.legend(fontsize=11)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
|
||||
|
||||
|
||||
def _pareto_front(configs: list[dict]) -> list[dict]:
    """Return the configs not dominated in (wer, der), ordered by rising wer."""
    front: list[dict] = []
    # Sorting by (wer, der) makes the lowest-DER config come first among
    # configs with equal WER, so a dominated tie can never enter the front.
    for cfg in sorted(configs, key=lambda c: (c["wer"], c["der"])):
        if not front or cfg["der"] < front[-1]["der"]:
            front.append(cfg)
    return front


def plot_wer_der_scatter(all_configs: list[dict], img_dir: Path):
    """Scatter every evaluated config in the WER/DER plane.

    Points are coloured by fitness, the highest-fitness config is marked
    with a star, and the Pareto front (both metrics minimised) is drawn as a
    dashed line when it contains more than one point.

    Args:
        all_configs: evaluated configs; each must provide the keys ``wer``,
            ``der`` and ``fitness``.
        img_dir: directory that receives ``wer_der_scatter.png``.

    Raises:
        ValueError: if ``all_configs`` is empty.
    """
    if not all_configs:
        # Fail before a figure is created so nothing is left open.
        raise ValueError("all_configs is empty: nothing to plot")

    out_path = img_dir / "wer_der_scatter.png"
    wers = [c["wer"] for c in all_configs]
    ders = [c["der"] for c in all_configs]
    fits = [c["fitness"] for c in all_configs]

    fig, ax = plt.subplots(figsize=(7, 5.5))
    sc = ax.scatter(
        wers,
        ders,
        c=fits,
        cmap="RdYlGn",
        alpha=0.7,
        edgecolors="gray",
        linewidth=0.5,
        s=40,
    )

    # Highlight the config with the highest fitness.
    best = max(all_configs, key=lambda c: c["fitness"])
    ax.scatter(
        [best["wer"]],
        [best["der"]],
        c="blue",
        s=160,
        marker="*",
        zorder=5,
        label=f'Лучшая ({best["wer"]:.1f}%, {best["der"]:.1f}%)',
    )

    pareto = _pareto_front(all_configs)
    # A single-point front would not produce a visible line.
    if len(pareto) > 1:
        ax.plot(
            [c["wer"] for c in pareto],
            [c["der"] for c in pareto],
            "k--",
            alpha=0.5,
            linewidth=1.2,
            label="Парето-фронт",
        )

    ax.set_xlabel("WER, %", fontsize=12)
    ax.set_ylabel("DER, %", fontsize=12)
    ax.legend(fontsize=11)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label("Фитнес", fontsize=11)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
|
||||
|
||||
|
||||
def plot_model_frequency(all_configs: list[dict], img_dir: Path):
    """Draw how often each model appears among the top configs by fitness.

    Takes up to 20 configs with the highest fitness and renders two
    horizontal bar charts side by side: one counting ``transcription_model``
    values, one counting ``diarization_model`` values, each with the most
    frequent model at the top.

    Args:
        all_configs: evaluated configs; each must provide ``fitness`` and a
            ``config`` mapping with both model keys.
        img_dir: directory that receives ``model_frequency.png``.
    """
    out_path = img_dir / "model_frequency.png"
    top_n = min(20, len(all_configs))
    ranked = sorted(all_configs, key=lambda c: c["fitness"], reverse=True)
    leaders = ranked[:top_n]

    # One (config key, bar colour, panel title) triple per subplot, left to right.
    panels = (
        ("transcription_model", "steelblue", "Модели транскрибации"),
        ("diarization_model", "coral", "Модели диаризации"),
    )

    # Count occurrences per model for both keys in a single pass.
    tallies: dict[str, dict[str, int]] = {key: {} for key, _, _ in panels}
    for cfg in leaders:
        for key, tally in tallies.items():
            model = cfg["config"][key]
            tally[model] = tally.get(model, 0) + 1

    fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
    for ax, (key, color, title) in zip(axes, panels):
        tally = tallies[key]
        # Stable sort: models with equal counts keep first-seen order.
        names = sorted(tally, key=tally.get, reverse=True)
        labels = [MODEL_DISPLAY.get(name, name) for name in names]
        values = [tally[name] for name in names]

        ax.barh(labels, values, color=color)
        ax.set_xlabel(f"Количество в топ-{top_n}", fontsize=11)
        ax.set_title(title, fontsize=12)
        # Put the most frequent model at the top of the chart.
        ax.invert_yaxis()

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
|
||||
|
||||
|
||||
# Generate the figures only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user