#!/usr/bin/env python3
"""Genetic algorithm for optimizing a meeting transcription + diarization pipeline.

Searches over a mixed discrete configuration space of transcription and
diarization models and their parameters. Uses module-level caching and batch
scheduling grouped by model to minimize redundant computations.
"""
|
|
|
|
import json
|
|
import random
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import run_pipeline
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Configuration space
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Candidate values for every searchable pipeline parameter. A chromosome stores
# one index into each list, so the order of the values is part of the encoding.
TRANSCRIPTION_MODELS = [
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
    "gigaam-ctc",
    "gigaam-rnnt",
]
BEAM_SIZES = [1, 3, 5, 7, 10]
VAD_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]

DIARIZATION_MODELS = [
    "pyannote-3.1",
    "pyannote-community-1",
    "sortformer",
]
MIN_SPEECH_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5]
CLUSTERING_THRESHOLDS = [0.3, 0.45, 0.6, 0.75, 0.9]

# Transcription models whose beam_size gene is honoured; for any other model
# the code in this file pins the beam size to 1.
WHISPER_MODELS = {
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
}

# Gene layout: position i of a chromosome indexes into the i-th value list.
GENES: list[tuple[str, list]] = [
    ("transcription_model", TRANSCRIPTION_MODELS),
    ("beam_size", BEAM_SIZES),
    ("vad_threshold", VAD_THRESHOLDS),
    ("diarization_model", DIARIZATION_MODELS),
    ("min_speech_duration", MIN_SPEECH_DURATIONS),
    ("clustering_threshold", CLUSTERING_THRESHOLDS),
]
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GA hyper-parameters
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Evolution settings.
POPULATION_SIZE = 15  # chromosomes per generation
NUM_GENERATIONS = 25  # populations recorded, counting the initial random one
TOURNAMENT_SIZE = 3  # candidates sampled for each tournament selection
MUTATION_PROB = 0.15  # per-gene probability of mutating
ELITE_COUNT = 2  # best chromosomes carried into the next generation as-is

# Weights of the three cost terms combined by compute_fitness.
ALPHA = 0.4  # WER weight
BETA = 0.4  # DER weight
GAMMA = 0.2  # time weight
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Chromosome
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
|
|
class Chromosome:
|
|
genes: list[int]
|
|
fitness: float | None = None
|
|
wer: float | None = None
|
|
der: float | None = None
|
|
time_min: float | None = None
|
|
|
|
def to_config(self) -> dict:
|
|
return {
|
|
name: values[self.genes[i]] for i, (name, values) in enumerate(GENES)
|
|
}
|
|
|
|
def transcription_key(self) -> tuple:
|
|
cfg = self.to_config()
|
|
model = cfg["transcription_model"]
|
|
beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
|
|
return (model, beam, cfg["vad_threshold"])
|
|
|
|
def diarization_key(self) -> tuple:
|
|
cfg = self.to_config()
|
|
return (
|
|
cfg["diarization_model"],
|
|
cfg["min_speech_duration"],
|
|
cfg["clustering_threshold"],
|
|
cfg["vad_threshold"],
|
|
)
|
|
|
|
def copy(self) -> "Chromosome":
|
|
return Chromosome(genes=self.genes.copy())
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Module-level cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class Cache:
|
|
def __init__(self, cache_dir: Path):
|
|
self.cache_dir = cache_dir
|
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
self.transcription: dict[str, dict] = {}
|
|
self.diarization: dict[str, dict] = {}
|
|
self._load()
|
|
|
|
def _load(self):
|
|
for name in ("transcription", "diarization"):
|
|
path = self.cache_dir / f"{name}.json"
|
|
if path.exists():
|
|
setattr(self, name, json.loads(path.read_text()))
|
|
|
|
def save(self):
|
|
for name in ("transcription", "diarization"):
|
|
path = self.cache_dir / f"{name}.json"
|
|
path.write_text(json.dumps(getattr(self, name), indent=2))
|
|
|
|
def get_transcription(self, key: tuple) -> dict | None:
|
|
return self.transcription.get(str(key))
|
|
|
|
def set_transcription(self, key: tuple, result: dict):
|
|
self.transcription[str(key)] = result
|
|
|
|
def get_diarization(self, key: tuple) -> dict | None:
|
|
return self.diarization.get(str(key))
|
|
|
|
def set_diarization(self, key: tuple, result: dict):
|
|
self.diarization[str(key)] = result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# GA operators
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def random_chromosome() -> Chromosome:
    """Build a chromosome with an independently random index for every gene."""
    indices = []
    for _, choices in GENES:
        indices.append(random.randint(0, len(choices) - 1))
    return Chromosome(genes=indices)
|
|
|
|
|
|
def tournament_select(population: list[Chromosome]) -> Chromosome:
    """Pick the fittest of a random sample of the population.

    Up to ``TOURNAMENT_SIZE`` distinct chromosomes are sampled; when the
    population is smaller than that, all of it takes part instead of the
    sampling call failing.

    Args:
        population: chromosomes whose ``fitness`` has already been set.

    Returns:
        The sampled chromosome with the highest ``fitness``.

    Raises:
        ValueError: if ``population`` is empty.
    """
    if not population:
        raise ValueError("cannot run a tournament on an empty population")
    # random.sample raises when asked for more items than are available.
    contenders = random.sample(population, min(TOURNAMENT_SIZE, len(population)))
    return max(contenders, key=lambda c: c.fitness)
|
|
|
|
|
|
def crossover(p1: Chromosome, p2: Chromosome) -> Chromosome:
    """Uniform crossover: take each gene from either parent at random.

    The child starts with unset metrics; the parents are not modified.
    """
    child_genes = []
    for from_first, from_second in zip(p1.genes, p2.genes):
        child_genes.append(random.choice([from_first, from_second]))
    return Chromosome(genes=child_genes)
|
|
|
|
|
|
def mutate(chrom: Chromosome) -> Chromosome:
    """Return a mutated copy of ``chrom``; the input is left untouched.

    Each gene mutates independently with probability ``MUTATION_PROB``. A
    mutating gene with more than two possible values takes, with probability
    0.7, a single step of -1 or +1 in index space, clamped to the valid range
    (so a step off either end leaves it unchanged). Otherwise the gene is
    redrawn from all of its indices, which may reproduce the current value.
    """
    mutated = chrom.genes.copy()
    for position, (_, choices) in enumerate(GENES):
        if random.random() >= MUTATION_PROB:
            continue
        last_index = len(choices) - 1
        take_small_step = len(choices) > 2 and random.random() < 0.7
        if take_small_step:
            step = random.choice([-1, 1])
            mutated[position] = max(0, min(last_index, mutated[position] + step))
        else:
            mutated[position] = random.randint(0, last_index)
    return Chromosome(genes=mutated)
|
|
|
|
|
|
def compute_fitness(wer: float, der: float, time_min: float) -> float:
    """Return the negated weighted cost, so that a larger fitness is better.

    The cost is ``ALPHA * wer + BETA * der + GAMMA * time_min``.
    """
    weighted_cost = ALPHA * wer + BETA * der + GAMMA * time_min
    return -weighted_cost
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Batch scheduler
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _store_batch_results(items: list[tuple[tuple, dict]], results, store) -> int:
    """Store each batch result under its key via ``store``; return the count.

    Raises ValueError before storing anything when the number of results does
    not match the number of submitted configs, because pairing them up would
    then attach results to the wrong keys or drop some silently.
    """
    results = list(results)
    if len(results) != len(items):
        raise ValueError(
            f"batch returned {len(results)} results for {len(items)} configs"
        )
    for (key, _), result in zip(items, results):
        store(key, result)
    return len(items)


def schedule_evaluations(
    population: list[Chromosome], cache: Cache, audio_paths: list[str]
) -> int:
    """Evaluate chromosomes using cache and batching by model.

    1. Collect unique uncached transcription and diarization configs.
    2. Group them by model so the pipeline is called once per model.
    3. Store results in cache and assemble fitness values.

    Every chromosome in ``population`` has its ``wer``, ``der``, ``time_min``
    and ``fitness`` set in place. The cache is written to disk whenever new
    results were stored, including when a later batch fails, so completed
    batches are not lost.

    Args:
        population: chromosomes to evaluate.
        cache: result cache that is consulted and updated.
        audio_paths: passed unchanged to the ``run_pipeline`` batch functions.

    Returns:
        The number of new (uncached) module evaluations performed.

    Raises:
        ValueError: if a batch call returns a different number of results
            than configs were submitted.
    """
    # NOTE(review): cache keys are derived from the chromosome only, not from
    # audio_paths, so results cached for one audio set are reused for another
    # — confirm this is intended.

    # model name -> [(cache key, config dict handed to the pipeline), ...]
    uncached_t: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    uncached_d: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    seen_t: set[tuple] = set()
    seen_d: set[tuple] = set()

    for chrom in population:
        cfg = chrom.to_config()

        t_key = chrom.transcription_key()
        if t_key not in seen_t and cache.get_transcription(t_key) is None:
            seen_t.add(t_key)
            # The key is (model, beam, vad_threshold) with the beam already
            # normalized, so reuse it instead of repeating that rule here.
            model, beam, vad_threshold = t_key
            uncached_t[model].append(
                (t_key, {"beam_size": beam, "vad_threshold": vad_threshold})
            )

        d_key = chrom.diarization_key()
        if d_key not in seen_d and cache.get_diarization(d_key) is None:
            seen_d.add(d_key)
            uncached_d[cfg["diarization_model"]].append(
                (
                    d_key,
                    {
                        "min_speech_duration": cfg["min_speech_duration"],
                        "clustering_threshold": cfg["clustering_threshold"],
                        "vad_threshold": cfg["vad_threshold"],
                    },
                )
            )

    new_evals = 0
    try:
        for model, items in uncached_t.items():
            configs = [c for _, c in items]
            results = run_pipeline.evaluate_transcription_batch(
                model, configs, audio_paths
            )
            new_evals += _store_batch_results(items, results, cache.set_transcription)

        for model, items in uncached_d.items():
            configs = [c for _, c in items]
            results = run_pipeline.evaluate_diarization_batch(
                model, configs, audio_paths
            )
            new_evals += _store_batch_results(items, results, cache.set_diarization)
    finally:
        # Persist whatever was computed even if a later batch raised.
        if new_evals > 0:
            cache.save()

    for chrom in population:
        t_res = cache.get_transcription(chrom.transcription_key())
        d_res = cache.get_diarization(chrom.diarization_key())
        chrom.wer = t_res["wer"]
        chrom.der = d_res["der"]
        chrom.time_min = t_res["time"] + d_res["time"]
        chrom.fitness = compute_fitness(chrom.wer, chrom.der, chrom.time_min)

    return new_evals
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main GA loop
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def run_ga(audio_paths: list[str] | None = None, seed: int = 42) -> list[dict]:
    """Run the genetic search and return the per-generation history.

    Seeds the global ``random`` generator, evaluates a random initial
    population and records ``NUM_GENERATIONS`` generations. After each
    generation except the last, the next one is bred from the ``ELITE_COUNT``
    best chromosomes (kept with their metrics) plus children produced by
    tournament selection, crossover and mutation.

    Side effects: prints one progress line per generation and the final top
    five configurations, uses a ``cache`` directory next to this file, and
    writes ``history.json`` next to this file.

    Args:
        audio_paths: passed through to the evaluation step; ``None`` is
            replaced by an empty list.
        seed: seed for the ``random`` module.

    Returns:
        One summary dict per generation, in order.
    """
    random.seed(seed)
    if audio_paths is None:
        audio_paths = []

    cache = Cache(Path(__file__).parent / "cache")
    history: list[dict] = []
    all_configs: list[dict] = []
    # Gene tuples already recorded in all_configs, so each distinct
    # configuration is listed once with the generation it first appeared in.
    seen_genes: set[tuple[int, ...]] = set()
    total_evals = 0

    population = [random_chromosome() for _ in range(POPULATION_SIZE)]
    new_evals = schedule_evaluations(population, cache, audio_paths)
    total_evals += new_evals

    for gen in range(NUM_GENERATIONS):
        population.sort(key=lambda c: c.fitness, reverse=True)

        for chrom in population:
            key = tuple(chrom.genes)
            if key not in seen_genes:
                seen_genes.add(key)
                all_configs.append(
                    {
                        "config": chrom.to_config(),
                        "wer": chrom.wer,
                        "der": chrom.der,
                        "time": chrom.time_min,
                        "fitness": chrom.fitness,
                        "generation": gen,
                    }
                )

        best = population[0]
        mean_fit = sum(c.fitness for c in population) / len(population)

        history.append(
            {
                "generation": gen,
                "best_fitness": round(best.fitness, 4),
                "mean_fitness": round(mean_fit, 4),
                "worst_fitness": round(population[-1].fitness, 4),
                "best_config": best.to_config(),
                "best_wer": best.wer,
                "best_der": best.der,
                "best_time": best.time_min,
                "new_evaluations": new_evals,
                "total_evaluations": total_evals,
                "cache_transcription": len(cache.transcription),
                "cache_diarization": len(cache.diarization),
            }
        )

        print(
            f"Gen {gen:3d} | best={best.fitness:.3f} mean={mean_fit:.3f} | "
            f"WER={best.wer:.1f}% DER={best.der:.1f}% | "
            f"new={new_evals} cache_t={len(cache.transcription)} "
            f"cache_d={len(cache.diarization)}"
        )

        # The last generation is only reported, not bred from.
        if gen == NUM_GENERATIONS - 1:
            break

        next_gen: list[Chromosome] = []
        for i in range(ELITE_COUNT):
            # copy() resets the metrics, so carry them over explicitly.
            e = population[i].copy()
            e.fitness = population[i].fitness
            e.wer = population[i].wer
            e.der = population[i].der
            e.time_min = population[i].time_min
            next_gen.append(e)

        while len(next_gen) < POPULATION_SIZE:
            p1 = tournament_select(population)
            p2 = tournament_select(population)
            child = mutate(crossover(p1, p2))
            next_gen.append(child)

        population = next_gen
        new_evals = schedule_evaluations(population, cache, audio_paths)
        total_evals += new_evals

    output = {"history": history, "all_configs": all_configs}
    out_path = Path(__file__).parent / "history.json"
    # ensure_ascii=False can emit non-ASCII characters, so the encoding must
    # not depend on the platform's locale default.
    out_path.write_text(
        json.dumps(output, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nResults saved to {out_path}")

    population.sort(key=lambda c: c.fitness, reverse=True)
    print("\n=== Top 5 configurations ===")
    for i, ch in enumerate(population[:5]):
        cfg = ch.to_config()
        print(
            f"\n#{i + 1}: fitness={ch.fitness:.3f} "
            f"WER={ch.wer:.2f}% DER={ch.der:.2f}% time={ch.time_min:.2f}min"
        )
        for k, v in cfg.items():
            print(f" {k}: {v}")

    return history
|
|
|
|
|
|
# Script entry point: run the search with the default arguments.
if __name__ == "__main__":
    run_ga()
|