This commit is contained in:
2026-04-04 16:00:25 +03:00
parent 41c8d97094
commit abb4b31e60
8 changed files with 928 additions and 33 deletions

361
ga/ga.py Normal file
View File

@@ -0,0 +1,361 @@
#!/usr/bin/env python3
"""Genetic algorithm for optimizing meeting transcription+diarization pipeline.
Searches over a mixed discrete configuration space of transcription and
diarization models and their parameters. Uses module-level caching and batch
scheduling grouped by model to minimize redundant computations.
"""
import json
import random
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
import run_pipeline
# ---------------------------------------------------------------------------
# Configuration space
# ---------------------------------------------------------------------------
# Candidate values for every tunable parameter of the pipeline. A chromosome
# stores one index into each of these lists.
TRANSCRIPTION_MODELS = [
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
    "gigaam-ctc",
    "gigaam-rnnt",
]
BEAM_SIZES = [1, 3, 5, 7, 10]
VAD_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]

DIARIZATION_MODELS = [
    "pyannote-3.1",
    "pyannote-community-1",
    "sortformer",
]
MIN_SPEECH_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5]
CLUSTERING_THRESHOLDS = [0.3, 0.45, 0.6, 0.75, 0.9]

# Transcription models for which the beam_size gene is honoured; for any
# other model the beam size is pinned to 1 when keys and configs are built.
WHISPER_MODELS = {
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
}

# Ordered gene table: Chromosome.genes[i] is an index into GENES[i][1].
GENES: list[tuple[str, list]] = [
    ("transcription_model", TRANSCRIPTION_MODELS),
    ("beam_size", BEAM_SIZES),
    ("vad_threshold", VAD_THRESHOLDS),
    ("diarization_model", DIARIZATION_MODELS),
    ("min_speech_duration", MIN_SPEECH_DURATIONS),
    ("clustering_threshold", CLUSTERING_THRESHOLDS),
]
# ---------------------------------------------------------------------------
# GA hyper-parameters
# ---------------------------------------------------------------------------
# Number of chromosomes in every generation.
POPULATION_SIZE = 15
# Number of generations run_ga reports, counting the initial random population.
NUM_GENERATIONS = 25
# Number of chromosomes sampled for each tournament selection.
TOURNAMENT_SIZE = 3
# Per-gene probability that mutate() changes the gene.
MUTATION_PROB = 0.15
# Number of fittest chromosomes carried over into the next generation.
ELITE_COUNT = 2
# Weights of the cost terms in compute_fitness(); fitness is the negated
# weighted sum, so a larger weight penalizes its term more strongly.
ALPHA = 0.4 # WER weight
BETA = 0.4 # DER weight
GAMMA = 0.2 # time weight
# ---------------------------------------------------------------------------
# Chromosome
# ---------------------------------------------------------------------------
@dataclass
class Chromosome:
genes: list[int]
fitness: float | None = None
wer: float | None = None
der: float | None = None
time_min: float | None = None
def to_config(self) -> dict:
return {
name: values[self.genes[i]] for i, (name, values) in enumerate(GENES)
}
def transcription_key(self) -> tuple:
cfg = self.to_config()
model = cfg["transcription_model"]
beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
return (model, beam, cfg["vad_threshold"])
def diarization_key(self) -> tuple:
cfg = self.to_config()
return (
cfg["diarization_model"],
cfg["min_speech_duration"],
cfg["clustering_threshold"],
cfg["vad_threshold"],
)
def copy(self) -> "Chromosome":
return Chromosome(genes=self.genes.copy())
# ---------------------------------------------------------------------------
# Module-level cache
# ---------------------------------------------------------------------------
class Cache:
def __init__(self, cache_dir: Path):
self.cache_dir = cache_dir
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.transcription: dict[str, dict] = {}
self.diarization: dict[str, dict] = {}
self._load()
def _load(self):
for name in ("transcription", "diarization"):
path = self.cache_dir / f"{name}.json"
if path.exists():
setattr(self, name, json.loads(path.read_text()))
def save(self):
for name in ("transcription", "diarization"):
path = self.cache_dir / f"{name}.json"
path.write_text(json.dumps(getattr(self, name), indent=2))
def get_transcription(self, key: tuple) -> dict | None:
return self.transcription.get(str(key))
def set_transcription(self, key: tuple, result: dict):
self.transcription[str(key)] = result
def get_diarization(self, key: tuple) -> dict | None:
return self.diarization.get(str(key))
def set_diarization(self, key: tuple, result: dict):
self.diarization[str(key)] = result
# ---------------------------------------------------------------------------
# GA operators
# ---------------------------------------------------------------------------
def random_chromosome() -> Chromosome:
    """Create a chromosome whose genes are random valid indices into ``GENES``."""
    indices = []
    for _name, values in GENES:
        last_index = len(values) - 1
        indices.append(random.randint(0, last_index))
    return Chromosome(genes=indices)
def tournament_select(population: list[Chromosome]) -> Chromosome:
    """Pick the fittest of ``TOURNAMENT_SIZE`` randomly sampled chromosomes.

    The sample size is capped at the population size so that a population
    smaller than ``TOURNAMENT_SIZE`` still yields a winner instead of making
    ``random.sample`` fail.

    Args:
        population: Evaluated chromosomes (``fitness`` must be set).

    Returns:
        The sampled chromosome with the highest fitness.

    Raises:
        ValueError: If ``population`` is empty.
    """
    if not population:
        raise ValueError("cannot run a tournament on an empty population")
    sample_size = min(TOURNAMENT_SIZE, len(population))
    contenders = random.sample(population, sample_size)
    return max(contenders, key=lambda c: c.fitness)
def crossover(p1: Chromosome, p2: Chromosome) -> Chromosome:
    """Build a child whose every gene is taken from one parent at random.

    Genes are paired position by position; the child carries no results.
    """
    child_genes = []
    for gene_pair in zip(p1.genes, p2.genes):
        child_genes.append(random.choice(gene_pair))
    return Chromosome(genes=child_genes)
def mutate(chrom: Chromosome) -> Chromosome:
    """Return a mutated copy of ``chrom``; the input is left untouched.

    Each gene mutates with probability ``MUTATION_PROB``. A mutating gene
    with more than two possible values moves one step up or down (clamped to
    the valid range) 70% of the time; otherwise it is redrawn at random from
    all valid indices.
    """
    mutated = list(chrom.genes)
    for position, (_name, values) in enumerate(GENES):
        if random.random() >= MUTATION_PROB:
            continue
        top = len(values) - 1
        # The second draw only happens when a step mutation is possible.
        take_step = len(values) > 2 and random.random() < 0.7
        if take_step:
            shifted = mutated[position] + random.choice([-1, 1])
            mutated[position] = max(0, min(top, shifted))
        else:
            mutated[position] = random.randint(0, top)
    return Chromosome(genes=mutated)
def compute_fitness(wer: float, der: float, time_min: float) -> float:
    """Return the fitness: the negated weighted sum of WER, DER and time.

    The weights are ``ALPHA``, ``BETA`` and ``GAMMA``; a smaller weighted
    cost therefore means a higher fitness.
    """
    weighted_cost = ALPHA * wer + BETA * der + GAMMA * time_min
    return -weighted_cost
# ---------------------------------------------------------------------------
# Batch scheduler
# ---------------------------------------------------------------------------
def _run_model_batches(
    pending: dict[str, list[tuple[tuple, dict]]],
    evaluate_batch,
    store,
    cache: Cache,
    audio_paths: list[str],
) -> int:
    """Evaluate pending configs one model at a time and cache the results.

    Args:
        pending: Maps a model name to ``(cache key, config)`` pairs.
        evaluate_batch: Callable ``(model, configs, audio_paths)`` returning
            one result per config, in the same order.
        store: Callable ``(key, result)`` that puts a result into the cache.
        cache: Cache that is saved after every model batch.
        audio_paths: Audio files forwarded to ``evaluate_batch``.

    Returns:
        The number of results stored.

    Raises:
        ValueError: If a batch returns a different number of results than
            configs, since keys and results could no longer be paired.
    """
    stored = 0
    for model, items in pending.items():
        configs = [config for _, config in items]
        results = list(evaluate_batch(model, configs, audio_paths))
        if len(results) != len(items):
            raise ValueError(
                f"{model}: expected {len(items)} results, got {len(results)}"
            )
        for (key, _), result in zip(items, results):
            store(key, result)
        stored += len(items)
        # Persist right away so a failure in a later batch does not throw
        # away the results that were just computed.
        cache.save()
    return stored


def schedule_evaluations(
    population: list[Chromosome], cache: Cache, audio_paths: list[str]
) -> int:
    """Evaluate chromosomes using the cache and batching by model.

    1. Collect the unique uncached transcription and diarization configs.
    2. Group them by model and run one batch per model.
    3. Store the results in the cache and fill ``wer``, ``der``,
       ``time_min`` and ``fitness`` on every chromosome.

    NOTE: cache keys do not include ``audio_paths``, so results cached for
    one set of audio files are reused for any other set. Clear the cache
    directory when the audio set changes.

    Args:
        population: Chromosomes to evaluate; updated in place.
        cache: Result cache; saved after every model batch.
        audio_paths: Audio files forwarded to the pipeline batches.

    Returns:
        The number of new (uncached) module evaluations performed.

    Raises:
        ValueError: If a pipeline batch returns the wrong number of results.
    """
    pending_t: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    pending_d: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    queued_t: set[str] = set()
    queued_d: set[str] = set()

    for chrom in population:
        t_key = chrom.transcription_key()
        t_key_s = str(t_key)
        if t_key_s not in queued_t and cache.get_transcription(t_key) is None:
            queued_t.add(t_key_s)
            # The key already holds the normalized beam size, so reuse it
            # instead of repeating the normalization here.
            t_model, beam, vad = t_key
            pending_t[t_model].append(
                (t_key, {"beam_size": beam, "vad_threshold": vad})
            )

        d_key = chrom.diarization_key()
        d_key_s = str(d_key)
        if d_key_s not in queued_d and cache.get_diarization(d_key) is None:
            queued_d.add(d_key_s)
            d_model, min_speech, clustering, vad = d_key
            pending_d[d_model].append(
                (
                    d_key,
                    {
                        "min_speech_duration": min_speech,
                        "clustering_threshold": clustering,
                        "vad_threshold": vad,
                    },
                )
            )

    new_evals = _run_model_batches(
        pending_t,
        run_pipeline.evaluate_transcription_batch,
        cache.set_transcription,
        cache,
        audio_paths,
    )
    new_evals += _run_model_batches(
        pending_d,
        run_pipeline.evaluate_diarization_batch,
        cache.set_diarization,
        cache,
        audio_paths,
    )

    for chrom in population:
        t_res = cache.get_transcription(chrom.transcription_key())
        d_res = cache.get_diarization(chrom.diarization_key())
        chrom.wer = t_res["wer"]
        chrom.der = d_res["der"]
        chrom.time_min = t_res["time"] + d_res["time"]
        chrom.fitness = compute_fitness(chrom.wer, chrom.der, chrom.time_min)
    return new_evals
# ---------------------------------------------------------------------------
# Main GA loop
# ---------------------------------------------------------------------------
def run_ga(audio_paths: list[str] | None = None, seed: int = 42) -> list[dict]:
    """Run the genetic search and return the per-generation history.

    The initial population is random. Every following generation keeps the
    ``ELITE_COUNT`` fittest chromosomes and fills the rest with mutated
    crossover children of tournament-selected parents. Evaluation results are
    cached in a ``cache`` directory next to this file; the history and every
    distinct configuration seen are written to ``history.json`` next to this
    file.

    Args:
        audio_paths: Audio files handed to the evaluation batches; ``None``
            is treated as an empty list.
        seed: Seed for the module-level ``random`` generator. Note that this
            reseeds the global generator.

    Returns:
        One dict per generation with best/mean/worst fitness, the best
        configuration and its metrics, and evaluation/cache counters.
    """
    random.seed(seed)
    if audio_paths is None:
        audio_paths = []
    cache = Cache(Path(__file__).parent / "cache")
    history: list[dict] = []
    all_configs: list[dict] = []
    seen_genes: set[tuple[int, ...]] = set()
    total_evals = 0

    population = [random_chromosome() for _ in range(POPULATION_SIZE)]
    new_evals = schedule_evaluations(population, cache, audio_paths)
    total_evals += new_evals

    for gen in range(NUM_GENERATIONS):
        population.sort(key=lambda c: c.fitness, reverse=True)

        # Record each distinct gene combination the first time it shows up.
        for chrom in population:
            key = tuple(chrom.genes)
            if key not in seen_genes:
                seen_genes.add(key)
                all_configs.append(
                    {
                        "config": chrom.to_config(),
                        "wer": chrom.wer,
                        "der": chrom.der,
                        "time": chrom.time_min,
                        "fitness": chrom.fitness,
                        "generation": gen,
                    }
                )

        best = population[0]
        mean_fit = sum(c.fitness for c in population) / len(population)
        history.append(
            {
                "generation": gen,
                "best_fitness": round(best.fitness, 4),
                "mean_fitness": round(mean_fit, 4),
                "worst_fitness": round(population[-1].fitness, 4),
                "best_config": best.to_config(),
                "best_wer": best.wer,
                "best_der": best.der,
                "best_time": best.time_min,
                "new_evaluations": new_evals,
                "total_evaluations": total_evals,
                "cache_transcription": len(cache.transcription),
                "cache_diarization": len(cache.diarization),
            }
        )
        print(
            f"Gen {gen:3d} | best={best.fitness:.3f} mean={mean_fit:.3f} | "
            f"WER={best.wer:.1f}% DER={best.der:.1f}% | "
            f"new={new_evals} cache_t={len(cache.transcription)} "
            f"cache_d={len(cache.diarization)}"
        )

        # The last generation is only reported, not bred from.
        if gen == NUM_GENERATIONS - 1:
            break

        # Elitism: carry the fittest chromosomes over together with their
        # results. Slicing keeps this safe if ELITE_COUNT exceeds the
        # population size.
        next_gen: list[Chromosome] = [
            Chromosome(
                genes=elite.genes.copy(),
                fitness=elite.fitness,
                wer=elite.wer,
                der=elite.der,
                time_min=elite.time_min,
            )
            for elite in population[:ELITE_COUNT]
        ]
        while len(next_gen) < POPULATION_SIZE:
            p1 = tournament_select(population)
            p2 = tournament_select(population)
            next_gen.append(mutate(crossover(p1, p2)))

        population = next_gen
        new_evals = schedule_evaluations(population, cache, audio_paths)
        total_evals += new_evals

    output = {"history": history, "all_configs": all_configs}
    out_path = Path(__file__).parent / "history.json"
    # ensure_ascii=False can emit non-ASCII characters, so the encoding is
    # fixed to UTF-8 instead of depending on the platform locale.
    out_path.write_text(
        json.dumps(output, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nResults saved to {out_path}")

    # Needed when the loop above did not run (NUM_GENERATIONS == 0).
    population.sort(key=lambda c: c.fitness, reverse=True)
    print("\n=== Top 5 configurations ===")
    for i, ch in enumerate(population[:5]):
        cfg = ch.to_config()
        print(
            f"\n#{i + 1}: fitness={ch.fitness:.3f} "
            f"WER={ch.wer:.2f}% DER={ch.der:.2f}% time={ch.time_min:.2f}min"
        )
        for k, v in cfg.items():
            print(f" {k}: {v}")
    return history
# Script entry point: run the search with the default arguments.
if __name__ == "__main__":
    run_ga()