га
This commit is contained in:
361
ga/ga.py
Normal file
361
ga/ga.py
Normal file
@@ -0,0 +1,361 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Genetic algorithm for optimizing meeting transcription+diarization pipeline.
|
||||
|
||||
Searches over a mixed discrete configuration space of transcription and
|
||||
diarization models and their parameters. Uses module-level caching and batch
|
||||
scheduling grouped by model to minimize redundant computations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import run_pipeline
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration space
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TRANSCRIPTION_MODELS = [
|
||||
"whisper-large-v3",
|
||||
"whisper-medium",
|
||||
"faster-whisper-large-v3",
|
||||
"gigaam-ctc",
|
||||
"gigaam-rnnt",
|
||||
]
|
||||
BEAM_SIZES = [1, 3, 5, 7, 10]
|
||||
VAD_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
|
||||
DIARIZATION_MODELS = ["pyannote-3.1", "pyannote-community-1", "sortformer"]
|
||||
MIN_SPEECH_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5]
|
||||
CLUSTERING_THRESHOLDS = [0.3, 0.45, 0.6, 0.75, 0.9]
|
||||
|
||||
WHISPER_MODELS = {"whisper-large-v3", "whisper-medium", "faster-whisper-large-v3"}
|
||||
|
||||
GENES: list[tuple[str, list]] = [
|
||||
("transcription_model", TRANSCRIPTION_MODELS),
|
||||
("beam_size", BEAM_SIZES),
|
||||
("vad_threshold", VAD_THRESHOLDS),
|
||||
("diarization_model", DIARIZATION_MODELS),
|
||||
("min_speech_duration", MIN_SPEECH_DURATIONS),
|
||||
("clustering_threshold", CLUSTERING_THRESHOLDS),
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GA hyper-parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
POPULATION_SIZE = 15
|
||||
NUM_GENERATIONS = 25
|
||||
TOURNAMENT_SIZE = 3
|
||||
MUTATION_PROB = 0.15
|
||||
ELITE_COUNT = 2
|
||||
|
||||
ALPHA = 0.4 # WER weight
|
||||
BETA = 0.4 # DER weight
|
||||
GAMMA = 0.2 # time weight
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chromosome
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chromosome:
|
||||
genes: list[int]
|
||||
fitness: float | None = None
|
||||
wer: float | None = None
|
||||
der: float | None = None
|
||||
time_min: float | None = None
|
||||
|
||||
def to_config(self) -> dict:
|
||||
return {
|
||||
name: values[self.genes[i]] for i, (name, values) in enumerate(GENES)
|
||||
}
|
||||
|
||||
def transcription_key(self) -> tuple:
|
||||
cfg = self.to_config()
|
||||
model = cfg["transcription_model"]
|
||||
beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
|
||||
return (model, beam, cfg["vad_threshold"])
|
||||
|
||||
def diarization_key(self) -> tuple:
|
||||
cfg = self.to_config()
|
||||
return (
|
||||
cfg["diarization_model"],
|
||||
cfg["min_speech_duration"],
|
||||
cfg["clustering_threshold"],
|
||||
cfg["vad_threshold"],
|
||||
)
|
||||
|
||||
def copy(self) -> "Chromosome":
|
||||
return Chromosome(genes=self.genes.copy())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Cache:
|
||||
def __init__(self, cache_dir: Path):
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.transcription: dict[str, dict] = {}
|
||||
self.diarization: dict[str, dict] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
for name in ("transcription", "diarization"):
|
||||
path = self.cache_dir / f"{name}.json"
|
||||
if path.exists():
|
||||
setattr(self, name, json.loads(path.read_text()))
|
||||
|
||||
def save(self):
|
||||
for name in ("transcription", "diarization"):
|
||||
path = self.cache_dir / f"{name}.json"
|
||||
path.write_text(json.dumps(getattr(self, name), indent=2))
|
||||
|
||||
def get_transcription(self, key: tuple) -> dict | None:
|
||||
return self.transcription.get(str(key))
|
||||
|
||||
def set_transcription(self, key: tuple, result: dict):
|
||||
self.transcription[str(key)] = result
|
||||
|
||||
def get_diarization(self, key: tuple) -> dict | None:
|
||||
return self.diarization.get(str(key))
|
||||
|
||||
def set_diarization(self, key: tuple, result: dict):
|
||||
self.diarization[str(key)] = result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GA operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def random_chromosome() -> Chromosome:
|
||||
return Chromosome(genes=[random.randint(0, len(v) - 1) for _, v in GENES])
|
||||
|
||||
|
||||
def tournament_select(population: list[Chromosome]) -> Chromosome:
|
||||
candidates = random.sample(population, TOURNAMENT_SIZE)
|
||||
return max(candidates, key=lambda c: c.fitness)
|
||||
|
||||
|
||||
def crossover(p1: Chromosome, p2: Chromosome) -> Chromosome:
|
||||
return Chromosome(
|
||||
genes=[random.choice([g1, g2]) for g1, g2 in zip(p1.genes, p2.genes)]
|
||||
)
|
||||
|
||||
|
||||
def mutate(chrom: Chromosome) -> Chromosome:
|
||||
genes = chrom.genes.copy()
|
||||
for i, (_, values) in enumerate(GENES):
|
||||
if random.random() < MUTATION_PROB:
|
||||
if len(values) > 2 and random.random() < 0.7:
|
||||
delta = random.choice([-1, 1])
|
||||
genes[i] = max(0, min(len(values) - 1, genes[i] + delta))
|
||||
else:
|
||||
genes[i] = random.randint(0, len(values) - 1)
|
||||
return Chromosome(genes=genes)
|
||||
|
||||
|
||||
def compute_fitness(wer: float, der: float, time_min: float) -> float:
|
||||
return -(ALPHA * wer + BETA * der + GAMMA * time_min)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch scheduler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def schedule_evaluations(
|
||||
population: list[Chromosome], cache: Cache, audio_paths: list[str]
|
||||
) -> int:
|
||||
"""Evaluate chromosomes using cache and batching by model.
|
||||
|
||||
1. Collect unique uncached transcription and diarization configs.
|
||||
2. Group them by model so the pipeline loads each model only once.
|
||||
3. Store results in cache and assemble fitness values.
|
||||
|
||||
Returns the number of new (uncached) module evaluations performed.
|
||||
"""
|
||||
uncached_t: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
|
||||
uncached_d: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
|
||||
seen_t: set[str] = set()
|
||||
seen_d: set[str] = set()
|
||||
|
||||
for chrom in population:
|
||||
cfg = chrom.to_config()
|
||||
|
||||
t_key = chrom.transcription_key()
|
||||
t_key_s = str(t_key)
|
||||
if cache.get_transcription(t_key) is None and t_key_s not in seen_t:
|
||||
seen_t.add(t_key_s)
|
||||
model = cfg["transcription_model"]
|
||||
beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
|
||||
uncached_t[model].append(
|
||||
(t_key, {"beam_size": beam, "vad_threshold": cfg["vad_threshold"]})
|
||||
)
|
||||
|
||||
d_key = chrom.diarization_key()
|
||||
d_key_s = str(d_key)
|
||||
if cache.get_diarization(d_key) is None and d_key_s not in seen_d:
|
||||
seen_d.add(d_key_s)
|
||||
uncached_d[cfg["diarization_model"]].append(
|
||||
(
|
||||
d_key,
|
||||
{
|
||||
"min_speech_duration": cfg["min_speech_duration"],
|
||||
"clustering_threshold": cfg["clustering_threshold"],
|
||||
"vad_threshold": cfg["vad_threshold"],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
new_evals = 0
|
||||
|
||||
for model, items in uncached_t.items():
|
||||
configs = [c for _, c in items]
|
||||
results = run_pipeline.evaluate_transcription_batch(
|
||||
model, configs, audio_paths
|
||||
)
|
||||
for (key, _), result in zip(items, results):
|
||||
cache.set_transcription(key, result)
|
||||
new_evals += 1
|
||||
|
||||
for model, items in uncached_d.items():
|
||||
configs = [c for _, c in items]
|
||||
results = run_pipeline.evaluate_diarization_batch(
|
||||
model, configs, audio_paths
|
||||
)
|
||||
for (key, _), result in zip(items, results):
|
||||
cache.set_diarization(key, result)
|
||||
new_evals += 1
|
||||
|
||||
if new_evals > 0:
|
||||
cache.save()
|
||||
|
||||
for chrom in population:
|
||||
t_res = cache.get_transcription(chrom.transcription_key())
|
||||
d_res = cache.get_diarization(chrom.diarization_key())
|
||||
chrom.wer = t_res["wer"]
|
||||
chrom.der = d_res["der"]
|
||||
chrom.time_min = t_res["time"] + d_res["time"]
|
||||
chrom.fitness = compute_fitness(chrom.wer, chrom.der, chrom.time_min)
|
||||
|
||||
return new_evals
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main GA loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_ga(audio_paths: list[str] | None = None, seed: int = 42) -> list[dict]:
|
||||
random.seed(seed)
|
||||
if audio_paths is None:
|
||||
audio_paths = []
|
||||
|
||||
cache = Cache(Path(__file__).parent / "cache")
|
||||
history: list[dict] = []
|
||||
all_configs: list[dict] = []
|
||||
seen_genes: set[tuple[int, ...]] = set()
|
||||
total_evals = 0
|
||||
|
||||
population = [random_chromosome() for _ in range(POPULATION_SIZE)]
|
||||
new_evals = schedule_evaluations(population, cache, audio_paths)
|
||||
total_evals += new_evals
|
||||
|
||||
for gen in range(NUM_GENERATIONS):
|
||||
population.sort(key=lambda c: c.fitness, reverse=True)
|
||||
|
||||
for chrom in population:
|
||||
key = tuple(chrom.genes)
|
||||
if key not in seen_genes:
|
||||
seen_genes.add(key)
|
||||
all_configs.append(
|
||||
{
|
||||
"config": chrom.to_config(),
|
||||
"wer": chrom.wer,
|
||||
"der": chrom.der,
|
||||
"time": chrom.time_min,
|
||||
"fitness": chrom.fitness,
|
||||
"generation": gen,
|
||||
}
|
||||
)
|
||||
|
||||
best = population[0]
|
||||
mean_fit = sum(c.fitness for c in population) / len(population)
|
||||
|
||||
history.append(
|
||||
{
|
||||
"generation": gen,
|
||||
"best_fitness": round(best.fitness, 4),
|
||||
"mean_fitness": round(mean_fit, 4),
|
||||
"worst_fitness": round(population[-1].fitness, 4),
|
||||
"best_config": best.to_config(),
|
||||
"best_wer": best.wer,
|
||||
"best_der": best.der,
|
||||
"best_time": best.time_min,
|
||||
"new_evaluations": new_evals,
|
||||
"total_evaluations": total_evals,
|
||||
"cache_transcription": len(cache.transcription),
|
||||
"cache_diarization": len(cache.diarization),
|
||||
}
|
||||
)
|
||||
|
||||
print(
|
||||
f"Gen {gen:3d} | best={best.fitness:.3f} mean={mean_fit:.3f} | "
|
||||
f"WER={best.wer:.1f}% DER={best.der:.1f}% | "
|
||||
f"new={new_evals} cache_t={len(cache.transcription)} "
|
||||
f"cache_d={len(cache.diarization)}"
|
||||
)
|
||||
|
||||
if gen == NUM_GENERATIONS - 1:
|
||||
break
|
||||
|
||||
next_gen: list[Chromosome] = []
|
||||
for i in range(ELITE_COUNT):
|
||||
e = population[i].copy()
|
||||
e.fitness = population[i].fitness
|
||||
e.wer = population[i].wer
|
||||
e.der = population[i].der
|
||||
e.time_min = population[i].time_min
|
||||
next_gen.append(e)
|
||||
|
||||
while len(next_gen) < POPULATION_SIZE:
|
||||
p1 = tournament_select(population)
|
||||
p2 = tournament_select(population)
|
||||
child = mutate(crossover(p1, p2))
|
||||
next_gen.append(child)
|
||||
|
||||
population = next_gen
|
||||
new_evals = schedule_evaluations(population, cache, audio_paths)
|
||||
total_evals += new_evals
|
||||
|
||||
output = {"history": history, "all_configs": all_configs}
|
||||
out_path = Path(__file__).parent / "history.json"
|
||||
out_path.write_text(json.dumps(output, indent=2, ensure_ascii=False))
|
||||
print(f"\nResults saved to {out_path}")
|
||||
|
||||
population.sort(key=lambda c: c.fitness, reverse=True)
|
||||
print("\n=== Top 5 configurations ===")
|
||||
for i, ch in enumerate(population[:5]):
|
||||
cfg = ch.to_config()
|
||||
print(
|
||||
f"\n#{i + 1}: fitness={ch.fitness:.3f} "
|
||||
f"WER={ch.wer:.2f}% DER={ch.der:.2f}% time={ch.time_min:.2f}min"
|
||||
)
|
||||
for k, v in cfg.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
return history
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_ga()
|
||||
Reference in New Issue
Block a user