This commit is contained in:
2026-04-04 16:00:25 +03:00
parent 41c8d97094
commit abb4b31e60
8 changed files with 928 additions and 33 deletions

5
ga/.gitignore vendored Normal file
View File

@@ -0,0 +1,5 @@
*
!**/
!.gitignore
!*.py

361
ga/ga.py Normal file
View File

@@ -0,0 +1,361 @@
#!/usr/bin/env python3
"""Genetic algorithm for optimizing meeting transcription+diarization pipeline.
Searches over a mixed discrete configuration space of transcription and
diarization models and their parameters. Uses module-level caching and batch
scheduling grouped by model to minimize redundant computations.
"""
import json
import random
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import run_pipeline
# ---------------------------------------------------------------------------
# Configuration space
# ---------------------------------------------------------------------------
# Transcription model identifiers the search can choose from.
TRANSCRIPTION_MODELS = [
"whisper-large-v3",
"whisper-medium",
"faster-whisper-large-v3",
"gigaam-ctc",
"gigaam-rnnt",
]
# Candidate beam sizes. Only meaningful for models in WHISPER_MODELS:
# Chromosome.transcription_key substitutes 1 for every other model.
BEAM_SIZES = [1, 3, 5, 7, 10]
# Candidate VAD thresholds. The chosen value is part of both the
# transcription and the diarization cache keys.
VAD_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]
# Diarization model identifiers the search can choose from.
DIARIZATION_MODELS = ["pyannote-3.1", "pyannote-community-1", "sortformer"]
# Candidate values for the diarization "min_speech_duration" parameter.
MIN_SPEECH_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5]
# Candidate values for the diarization "clustering_threshold" parameter.
CLUSTERING_THRESHOLDS = [0.3, 0.45, 0.6, 0.75, 0.9]
# Transcription models for which the beam_size gene is taken into account.
WHISPER_MODELS = {"whisper-large-v3", "whisper-medium", "faster-whisper-large-v3"}
# Ordered gene table: position i of Chromosome.genes is an index into the
# value list of GENES[i], so the order here defines the chromosome layout.
GENES: list[tuple[str, list]] = [
("transcription_model", TRANSCRIPTION_MODELS),
("beam_size", BEAM_SIZES),
("vad_threshold", VAD_THRESHOLDS),
("diarization_model", DIARIZATION_MODELS),
("min_speech_duration", MIN_SPEECH_DURATIONS),
("clustering_threshold", CLUSTERING_THRESHOLDS),
]
# ---------------------------------------------------------------------------
# GA hyper-parameters
# ---------------------------------------------------------------------------
POPULATION_SIZE = 15  # chromosomes per generation
NUM_GENERATIONS = 25  # generations recorded by run_ga
TOURNAMENT_SIZE = 3  # chromosomes sampled per tournament_select call
MUTATION_PROB = 0.15  # per-gene mutation probability used by mutate
ELITE_COUNT = 2  # best chromosomes carried over unchanged each generation
# Weights of the three penalty terms in compute_fitness; the fitness is the
# negated weighted sum, so a higher fitness means a lower combined penalty.
ALPHA = 0.4 # WER weight
BETA = 0.4 # DER weight
GAMMA = 0.2 # time weight
# ---------------------------------------------------------------------------
# Chromosome
# ---------------------------------------------------------------------------
@dataclass
class Chromosome:
genes: list[int]
fitness: float | None = None
wer: float | None = None
der: float | None = None
time_min: float | None = None
def to_config(self) -> dict:
return {
name: values[self.genes[i]] for i, (name, values) in enumerate(GENES)
}
def transcription_key(self) -> tuple:
cfg = self.to_config()
model = cfg["transcription_model"]
beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
return (model, beam, cfg["vad_threshold"])
def diarization_key(self) -> tuple:
cfg = self.to_config()
return (
cfg["diarization_model"],
cfg["min_speech_duration"],
cfg["clustering_threshold"],
cfg["vad_threshold"],
)
def copy(self) -> "Chromosome":
return Chromosome(genes=self.genes.copy())
# ---------------------------------------------------------------------------
# Module-level cache
# ---------------------------------------------------------------------------
class Cache:
def __init__(self, cache_dir: Path):
self.cache_dir = cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.transcription: dict[str, dict] = {}
self.diarization: dict[str, dict] = {}
self._load()
def _load(self):
for name in ("transcription", "diarization"):
path = self.cache_dir / f"{name}.json"
if path.exists():
setattr(self, name, json.loads(path.read_text()))
def save(self):
for name in ("transcription", "diarization"):
path = self.cache_dir / f"{name}.json"
path.write_text(json.dumps(getattr(self, name), indent=2))
def get_transcription(self, key: tuple) -> dict | None:
return self.transcription.get(str(key))
def set_transcription(self, key: tuple, result: dict):
self.transcription[str(key)] = result
def get_diarization(self, key: tuple) -> dict | None:
return self.diarization.get(str(key))
def set_diarization(self, key: tuple, result: dict):
self.diarization[str(key)] = result
# ---------------------------------------------------------------------------
# GA operators
# ---------------------------------------------------------------------------
def random_chromosome() -> Chromosome:
    """Build a chromosome with a uniformly random index for every gene."""
    indices: list[int] = []
    for _name, choices in GENES:
        indices.append(random.randint(0, len(choices) - 1))
    return Chromosome(genes=indices)
def tournament_select(
    population: "list[Chromosome]", tournament_size: "int | None" = None
) -> "Chromosome":
    """Pick the fittest of a random sample drawn from ``population``.

    Args:
        population: evaluated chromosomes (``fitness`` must be set).
        tournament_size: number of chromosomes to sample; defaults to
            ``TOURNAMENT_SIZE``. The value is clamped to the range
            ``1..len(population)`` so a population smaller than the
            tournament no longer makes ``random.sample`` fail.

    Returns:
        The sampled chromosome with the highest fitness.

    Raises:
        ValueError: if ``population`` is empty.
    """
    if not population:
        raise ValueError("cannot run a tournament on an empty population")
    requested = TOURNAMENT_SIZE if tournament_size is None else tournament_size
    sample_size = max(1, min(requested, len(population)))
    contestants = random.sample(population, sample_size)
    return max(contestants, key=lambda c: c.fitness)
def crossover(p1: Chromosome, p2: Chromosome) -> Chromosome:
    """Create a child that takes each gene at random from one of two parents."""
    child_genes: list[int] = []
    for from_first, from_second in zip(p1.genes, p2.genes):
        inherited = random.choice([from_first, from_second])
        child_genes.append(inherited)
    return Chromosome(genes=child_genes)
def mutate(chrom: Chromosome) -> Chromosome:
    """Return a mutated copy of ``chrom``; the input is left untouched.

    Every gene mutates independently with probability ``MUTATION_PROB``. A
    mutating gene with more than two possible values moves one index up or
    down (clamped to the valid range) 70% of the time; otherwise it is
    redrawn uniformly from all of its indices.
    """
    mutated = chrom.genes.copy()
    for pos, (_name, choices) in enumerate(GENES):
        if not random.random() < MUTATION_PROB:
            continue
        last_index = len(choices) - 1
        take_neighbour_step = len(choices) > 2 and random.random() < 0.7
        if take_neighbour_step:
            step = random.choice([-1, 1])
            shifted = mutated[pos] + step
            mutated[pos] = max(0, min(last_index, shifted))
        else:
            mutated[pos] = random.randint(0, last_index)
    return Chromosome(genes=mutated)
def compute_fitness(wer: float, der: float, time_min: float) -> float:
    """Return the negated weighted sum of WER, DER and time (higher is better)."""
    weighted_penalty = ALPHA * wer + BETA * der + GAMMA * time_min
    return -weighted_penalty
# ---------------------------------------------------------------------------
# Batch scheduler
# ---------------------------------------------------------------------------
def _evaluate_pending(pending, evaluate_batch, store_result, audio_paths) -> int:
    """Run one batch call per model and store each result; return the count."""
    stored = 0
    for model, items in pending.items():
        configs = [config for _, config in items]
        results = evaluate_batch(model, configs, audio_paths)
        # Fail loudly before touching the cache: pairing by position is only
        # valid when the batch returned exactly one result per config.
        if len(results) != len(items):
            raise ValueError(
                f"{model}: expected {len(items)} results, got {len(results)}"
            )
        for (key, _), result in zip(items, results):
            store_result(key, result)
        stored += len(items)
    return stored


def schedule_evaluations(
    population: list[Chromosome], cache: Cache, audio_paths: list[str]
) -> int:
    """Evaluate chromosomes using cache and batching by model.

    1. Collect unique uncached transcription and diarization configs.
    2. Group them by model so the pipeline is called once per model.
    3. Store results in cache and assemble fitness values.

    Every chromosome in ``population`` gets ``wer``, ``der``, ``time_min``
    and ``fitness`` assigned in place. The cache is saved to disk only when
    at least one new evaluation was performed.

    NOTE(review): cache keys do not include ``audio_paths``, so cached
    results are reused whatever audio files are passed - confirm that this
    is intended.

    Returns:
        The number of new (uncached) module evaluations performed.

    Raises:
        ValueError: if a batch call returns a different number of results
            than the number of configs it was given.
    """
    pending_t: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    pending_d: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    queued_t: set[tuple] = set()
    queued_d: set[tuple] = set()

    for chrom in population:
        cfg = chrom.to_config()

        t_key = chrom.transcription_key()
        if t_key not in queued_t and cache.get_transcription(t_key) is None:
            queued_t.add(t_key)
            model = cfg["transcription_model"]
            # Same normalisation as Chromosome.transcription_key.
            beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
            pending_t[model].append(
                (t_key, {"beam_size": beam, "vad_threshold": cfg["vad_threshold"]})
            )

        d_key = chrom.diarization_key()
        if d_key not in queued_d and cache.get_diarization(d_key) is None:
            queued_d.add(d_key)
            pending_d[cfg["diarization_model"]].append(
                (
                    d_key,
                    {
                        "min_speech_duration": cfg["min_speech_duration"],
                        "clustering_threshold": cfg["clustering_threshold"],
                        "vad_threshold": cfg["vad_threshold"],
                    },
                )
            )

    new_evals = _evaluate_pending(
        pending_t,
        run_pipeline.evaluate_transcription_batch,
        cache.set_transcription,
        audio_paths,
    )
    new_evals += _evaluate_pending(
        pending_d,
        run_pipeline.evaluate_diarization_batch,
        cache.set_diarization,
        audio_paths,
    )
    if new_evals > 0:
        cache.save()

    # Every key is in the cache by now, either from a previous run or from
    # the batches above.
    for chrom in population:
        t_res = cache.get_transcription(chrom.transcription_key())
        d_res = cache.get_diarization(chrom.diarization_key())
        chrom.wer = t_res["wer"]
        chrom.der = d_res["der"]
        chrom.time_min = t_res["time"] + d_res["time"]
        chrom.fitness = compute_fitness(chrom.wer, chrom.der, chrom.time_min)
    return new_evals
# ---------------------------------------------------------------------------
# Main GA loop
# ---------------------------------------------------------------------------
def _next_generation(ranked: list[Chromosome]) -> list[Chromosome]:
    """Breed the next population from one that is sorted best-first."""
    next_gen: list[Chromosome] = []
    # Elitism: the best chromosomes survive unchanged, metrics included.
    for elite in ranked[:ELITE_COUNT]:
        survivor = elite.copy()
        survivor.fitness = elite.fitness
        survivor.wer = elite.wer
        survivor.der = elite.der
        survivor.time_min = elite.time_min
        next_gen.append(survivor)
    # Fill the remaining slots with mutated offspring of tournament winners.
    while len(next_gen) < POPULATION_SIZE:
        p1 = tournament_select(ranked)
        p2 = tournament_select(ranked)
        next_gen.append(mutate(crossover(p1, p2)))
    return next_gen


def run_ga(audio_paths: list[str] | None = None, seed: int = 42) -> list[dict]:
    """Run the genetic search and write ``history.json`` next to this file.

    Args:
        audio_paths: audio files forwarded to the evaluation batches; an
            empty list is used when omitted.
        seed: seed for the global ``random`` generator.

    Returns:
        One statistics dict per generation. The same list is stored under
        ``"history"`` in ``history.json``, together with every distinct
        configuration seen (``"all_configs"``).
    """
    random.seed(seed)
    if audio_paths is None:
        audio_paths = []
    cache = Cache(Path(__file__).parent / "cache")
    history: list[dict] = []
    all_configs: list[dict] = []
    seen_genes: set[tuple[int, ...]] = set()
    total_evals = 0

    population = [random_chromosome() for _ in range(POPULATION_SIZE)]
    new_evals = schedule_evaluations(population, cache, audio_paths)
    total_evals += new_evals

    for gen in range(NUM_GENERATIONS):
        population.sort(key=lambda c: c.fitness, reverse=True)

        # Record each distinct gene vector once, at the generation in which
        # it first appears.
        for chrom in population:
            key = tuple(chrom.genes)
            if key not in seen_genes:
                seen_genes.add(key)
                all_configs.append(
                    {
                        "config": chrom.to_config(),
                        "wer": chrom.wer,
                        "der": chrom.der,
                        "time": chrom.time_min,
                        "fitness": chrom.fitness,
                        "generation": gen,
                    }
                )

        best = population[0]
        mean_fit = sum(c.fitness for c in population) / len(population)
        history.append(
            {
                "generation": gen,
                "best_fitness": round(best.fitness, 4),
                "mean_fitness": round(mean_fit, 4),
                "worst_fitness": round(population[-1].fitness, 4),
                "best_config": best.to_config(),
                "best_wer": best.wer,
                "best_der": best.der,
                "best_time": best.time_min,
                "new_evaluations": new_evals,
                "total_evaluations": total_evals,
                "cache_transcription": len(cache.transcription),
                "cache_diarization": len(cache.diarization),
            }
        )
        print(
            f"Gen {gen:3d} | best={best.fitness:.3f} mean={mean_fit:.3f} | "
            f"WER={best.wer:.1f}% DER={best.der:.1f}% | "
            f"new={new_evals} cache_t={len(cache.transcription)} "
            f"cache_d={len(cache.diarization)}"
        )

        # The last generation is only recorded, not bred from.
        if gen == NUM_GENERATIONS - 1:
            break

        population = _next_generation(population)
        new_evals = schedule_evaluations(population, cache, audio_paths)
        total_evals += new_evals

    output = {"history": history, "all_configs": all_configs}
    out_path = Path(__file__).parent / "history.json"
    # ensure_ascii=False may emit non-ASCII characters, so the file encoding
    # must not depend on the platform locale.
    out_path.write_text(
        json.dumps(output, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nResults saved to {out_path}")

    population.sort(key=lambda c: c.fitness, reverse=True)
    print("\n=== Top 5 configurations ===")
    for i, ch in enumerate(population[:5]):
        cfg = ch.to_config()
        print(
            f"\n#{i + 1}: fitness={ch.fitness:.3f} "
            f"WER={ch.wer:.2f}% DER={ch.der:.2f}% time={ch.time_min:.2f}min"
        )
        for k, v in cfg.items():
            print(f" {k}: {v}")
    return history
# Run the search with default arguments when executed as a script.
if __name__ == "__main__":
    run_ga()

154
ga/generate_plots.py Normal file
View File

@@ -0,0 +1,154 @@
#!/usr/bin/env python3
"""Generate plots from GA history for the course work report."""
import json
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
# Global matplotlib style applied to every figure produced by this script.
# NOTE(review): the font is presumably chosen because the plot labels are
# Cyrillic - confirm it is available on the target machine.
matplotlib.rcParams.update(
{
"font.family": "DejaVu Sans",
"axes.grid": True,
"grid.alpha": 0.3,
}
)
# Human-readable labels for model identifiers. plot_model_frequency falls
# back to the raw identifier for any model missing here; "\n" breaks long
# labels over two lines.
MODEL_DISPLAY = {
"whisper-large-v3": "Whisper large-v3",
"whisper-medium": "Whisper medium",
"faster-whisper-large-v3": "Faster-Whisper\nlarge-v3",
"gigaam-ctc": "GigaAM-CTC",
"gigaam-rnnt": "GigaAM-RNN-T",
"pyannote-3.1": "pyannote 3.1",
"pyannote-community-1": "pyannote\nCommunity-1",
"sortformer": "Sortformer",
}
def main():
    """Read ``history.json`` next to this script and render all report plots.

    Images are written to ``../report/img`` relative to this file; the
    directory is created when missing.
    """
    history_path = Path(__file__).parent / "history.json"
    # JSON interchange files are UTF-8; do not depend on the platform locale.
    data = json.loads(history_path.read_text(encoding="utf-8"))
    history = data["history"]
    all_configs = data["all_configs"]

    img_dir = Path(__file__).parent.parent / "report" / "img"
    img_dir.mkdir(parents=True, exist_ok=True)

    plot_convergence(history, img_dir)
    plot_wer_der_scatter(all_configs, img_dir)
    plot_model_frequency(all_configs, img_dir)
def plot_convergence(history: list[dict], img_dir: Path):
    """Plot best and mean objective value per generation.

    Fitness values are negated so the curves show the weighted error
    (lower is better). The figure is saved as ``convergence.png``.
    """
    generations = []
    best_error = []
    mean_error = []
    for record in history:
        generations.append(record["generation"])
        best_error.append(-record["best_fitness"])
        mean_error.append(-record["mean_fitness"])

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(
        generations,
        best_error,
        "b-o",
        markersize=4,
        linewidth=1.5,
        label="Лучшая особь",
    )
    ax.plot(
        generations,
        mean_error,
        "r--s",
        markersize=3,
        linewidth=1.2,
        label="Среднее по популяции",
    )
    ax.set_xlabel("Поколение", fontsize=12)
    ax.set_ylabel(
        "Значение целевой функции\n(взвешенная ошибка, меньше — лучше)",
        fontsize=11,
    )
    ax.legend(fontsize=11)
    fig.tight_layout()

    out_path = img_dir / "convergence.png"
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
def plot_wer_der_scatter(all_configs: list[dict], img_dir: Path):
    """Scatter all evaluated configs in the WER/DER plane.

    Points are coloured by fitness, the best config is marked with a star,
    and the WER/DER Pareto front is drawn when it has more than one point.
    The figure is saved as ``wer_der_scatter.png``. Nothing is drawn or
    saved when ``all_configs`` is empty.
    """
    out_path = img_dir / "wer_der_scatter.png"
    if not all_configs:
        # max() below has nothing to choose from; skip instead of crashing.
        print(f"No configurations to plot, skipped {out_path}")
        return

    wers = [c["wer"] for c in all_configs]
    ders = [c["der"] for c in all_configs]
    fits = [c["fitness"] for c in all_configs]

    fig, ax = plt.subplots(figsize=(7, 5.5))
    sc = ax.scatter(
        wers,
        ders,
        c=fits,
        cmap="RdYlGn",
        alpha=0.7,
        edgecolors="gray",
        linewidth=0.5,
        s=40,
    )

    best = max(all_configs, key=lambda c: c["fitness"])
    ax.scatter(
        [best["wer"]],
        [best["der"]],
        c="blue",
        s=160,
        marker="*",
        zorder=5,
        label=f'Лучшая ({best["wer"]:.1f}%, {best["der"]:.1f}%)',
    )

    # Sweep in order of increasing WER and keep every point that strictly
    # improves DER. DER is the tie-breaker of the sort so that, among points
    # with equal WER, only the one with the lowest DER reaches the front.
    pareto: list[dict] = []
    for c in sorted(all_configs, key=lambda c: (c["wer"], c["der"])):
        if not pareto or c["der"] < pareto[-1]["der"]:
            pareto.append(c)
    if len(pareto) > 1:
        ax.plot(
            [c["wer"] for c in pareto],
            [c["der"] for c in pareto],
            "k--",
            alpha=0.5,
            linewidth=1.2,
            label="Парето-фронт",
        )

    ax.set_xlabel("WER, %", fontsize=12)
    ax.set_ylabel("DER, %", fontsize=12)
    ax.legend(fontsize=11)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label("Фитнес", fontsize=11)
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
def plot_model_frequency(all_configs: list[dict], img_dir: Path):
    """Plot how often each model occurs among the top-fitness configs.

    The (at most 20) configs with the highest fitness are counted, with one
    horizontal bar chart for transcription models and one for diarization
    models. The figure is saved as ``model_frequency.png``.
    """
    top_n = min(20, len(all_configs))
    ranked = sorted(all_configs, key=lambda c: c["fitness"], reverse=True)
    leaders = ranked[:top_n]

    t_counts: dict[str, int] = {}
    d_counts: dict[str, int] = {}
    tallies = (
        (t_counts, "transcription_model"),
        (d_counts, "diarization_model"),
    )
    for entry in leaders:
        for counts, gene in tallies:
            model = entry["config"][gene]
            counts[model] = counts.get(model, 0) + 1

    fig, (left, right) = plt.subplots(1, 2, figsize=(11, 4.5))
    panels = (
        (left, t_counts, "steelblue", "Модели транскрибации"),
        (right, d_counts, "coral", "Модели диаризации"),
    )
    for ax, counts, color, title in panels:
        # Most frequent model first; invert_yaxis puts it at the top.
        names = sorted(counts, key=lambda n: counts[n], reverse=True)
        labels = [MODEL_DISPLAY.get(n, n) for n in names]
        values = [counts[n] for n in names]
        ax.barh(labels, values, color=color)
        ax.set_xlabel(f"Количество в топ-{top_n}", fontsize=11)
        ax.set_title(title, fontsize=12)
        ax.invert_yaxis()

    fig.tight_layout()
    out_path = img_dir / "model_frequency.png"
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
# Render all plots when executed as a script.
if __name__ == "__main__":
    main()

153
ga/run_pipeline.py Normal file
View File

@@ -0,0 +1,153 @@
"""Pipeline evaluation adapter.
Provides batch evaluation functions for transcription and diarization modules.
Currently contains simulation stubs with realistic performance models based on
published benchmarks. Replace the simulation logic with actual pipeline calls
for production use.
"""
import hashlib
# Simulation tables (see the module docstring: these are stubs, not
# measurements made by this code). All lookups below use the config value
# itself as the dict key, so configs must pass exactly these values.

# Starting WER (%) per transcription model, before config adjustments.
TRANSCRIPTION_BASE_WER: dict[str, float] = {
"whisper-large-v3": 7.8,
"whisper-medium": 13.5,
"faster-whisper-large-v3": 7.6,
"gigaam-ctc": 6.8,
"gigaam-rnnt": 5.4,
}
# Starting processing time (minutes) per transcription model.
TRANSCRIPTION_BASE_TIME: dict[str, float] = {
"whisper-large-v3": 4.2,
"whisper-medium": 2.8,
"faster-whisper-large-v3": 2.2,
"gigaam-ctc": 1.5,
"gigaam-rnnt": 3.5,
}
# Models for which beam_size influences the simulated WER and time.
WHISPER_MODELS = {"whisper-large-v3", "whisper-medium", "faster-whisper-large-v3"}
# Added to the WER, keyed by beam size (Whisper models only).
BEAM_SIZE_WER_DELTA = {1: 1.2, 3: 0.4, 5: 0.0, 7: -0.1, 10: -0.15}
# Multiplies the base time, keyed by beam size (Whisper models only).
BEAM_SIZE_TIME_FACTOR = {1: 0.6, 3: 0.8, 5: 1.0, 7: 1.15, 10: 1.4}
# Added to the WER, keyed by VAD threshold (all transcription models).
VAD_WER_DELTA = {0.3: 0.8, 0.4: 0.2, 0.5: 0.0, 0.6: 0.3, 0.7: 1.0}
# Starting DER (%) per diarization model, before config adjustments.
DIARIZATION_BASE_DER: dict[str, float] = {
"pyannote-3.1": 24.0,
"pyannote-community-1": 20.5,
"sortformer": 18.8,
}
# Starting processing time (minutes) per diarization model.
DIARIZATION_BASE_TIME: dict[str, float] = {
"pyannote-3.1": 2.5,
"pyannote-community-1": 2.8,
"sortformer": 3.8,
}
# Added to the DER, keyed by min_speech_duration.
MIN_SPEECH_DER_DELTA = {0.25: 1.5, 0.5: 0.0, 0.75: 0.3, 1.0: 1.2, 1.5: 3.0}
# Added to the DER, keyed by clustering_threshold.
CLUSTERING_DER_DELTA = {0.3: 3.0, 0.45: 0.8, 0.6: 0.0, 0.75: 0.5, 0.9: 2.5}
# Added to the DER, keyed by VAD threshold.
VAD_DER_DELTA = {0.3: 1.0, 0.4: 0.3, 0.5: 0.0, 0.6: 0.5, 0.7: 1.5}
def _deterministic_noise(seed_str: str, amplitude: float = 0.3) -> float:
    """Map ``seed_str`` to a repeatable value in ``[-amplitude, amplitude)``.

    The value is derived from the MD5 digest of the string, so the same
    string always yields the same noise.
    """
    digest = hashlib.md5(seed_str.encode()).hexdigest()
    bucket = int(digest, 16) % 10000  # 0..9999
    return bucket / 10000 * 2 * amplitude - amplitude
def evaluate_transcription_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Return simulated transcription metrics for each config of one model.

    This is a simulation stub: no model is loaded and ``audio_paths`` is not
    read. For a production implementation, load the model once here and
    iterate over the configs.

    Args:
        model_name: name of the transcription model
        configs: list of dicts, each with keys ``beam_size``, ``vad_threshold``
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``wer`` (%) and ``time`` (minutes), one per
        config and in the same order
    """
    model_wer = TRANSCRIPTION_BASE_WER[model_name]
    model_minutes = TRANSCRIPTION_BASE_TIME[model_name]
    beam_matters = model_name in WHISPER_MODELS

    def simulate(cfg: dict) -> dict:
        beam = cfg["beam_size"]
        vad = cfg["vad_threshold"]

        wer = model_wer
        if beam_matters:
            wer += BEAM_SIZE_WER_DELTA[beam]
        wer += VAD_WER_DELTA[vad]
        # Extra penalty for a wide beam combined with an extreme VAD setting.
        if beam_matters and vad in (0.3, 0.7) and beam >= 7:
            wer += 0.4
        wer = max(1.0, wer + _deterministic_noise(f"t_{model_name}_{beam}_{vad}"))

        minutes = model_minutes
        if beam_matters:
            minutes *= BEAM_SIZE_TIME_FACTOR[beam]
        minutes += _deterministic_noise(f"tt_{model_name}_{beam}_{vad}", 0.1)
        minutes = max(0.5, minutes)

        return {"wer": round(wer, 2), "time": round(minutes, 2)}

    return [simulate(cfg) for cfg in configs]
def evaluate_diarization_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Return simulated diarization metrics for each config of one model.

    This is a simulation stub: no model is loaded and ``audio_paths`` is not
    read. For a production implementation, load the model once here and
    iterate over the configs.

    Args:
        model_name: name of the diarization model
        configs: list of dicts with ``min_speech_duration``,
            ``clustering_threshold``, ``vad_threshold``
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``der`` (%) and ``time`` (minutes), one per
        config and in the same order
    """
    model_der = DIARIZATION_BASE_DER[model_name]
    model_minutes = DIARIZATION_BASE_TIME[model_name]

    def simulate(cfg: dict) -> dict:
        msd = cfg["min_speech_duration"]
        ct = cfg["clustering_threshold"]
        vad = cfg["vad_threshold"]

        der = (
            model_der
            + MIN_SPEECH_DER_DELTA[msd]
            + CLUSTERING_DER_DELTA[ct]
            + VAD_DER_DELTA[vad]
        )
        # Extra penalties for the two extreme parameter combinations.
        if vad <= 0.3 and msd <= 0.25:
            der += 1.2
        if ct >= 0.9 and msd >= 1.5:
            der += 0.8
        der = max(5.0, der + _deterministic_noise(f"d_{model_name}_{msd}_{ct}_{vad}"))

        jitter = _deterministic_noise(f"dt_{model_name}_{msd}_{ct}_{vad}", 0.15)
        minutes = max(0.5, model_minutes + jitter)

        return {"der": round(der, 2), "time": round(minutes, 2)}

    return [simulate(cfg) for cfg in configs]