Files
genetic-algorithms/ga/run_pipeline.py
2026-04-04 16:00:25 +03:00

154 lines
4.5 KiB
Python

"""Pipeline evaluation adapter.
Provides batch evaluation functions for transcription and diarization modules.
Currently contains simulation stubs with realistic performance models based on
published benchmarks. Replace the simulation logic with actual pipeline calls
for production use.
"""
import hashlib
# Baseline word error rate per transcription model, in percent (the unit of
# the ``wer`` field returned by ``evaluate_transcription_batch``).
TRANSCRIPTION_BASE_WER: dict[str, float] = {
    "whisper-large-v3": 7.8,
    "whisper-medium": 13.5,
    "faster-whisper-large-v3": 7.6,
    "gigaam-ctc": 6.8,
    "gigaam-rnnt": 5.4,
}
# Baseline processing time per transcription model, before beam-size scaling
# and noise; the evaluator reports the resulting ``time`` in minutes.
TRANSCRIPTION_BASE_TIME: dict[str, float] = {
    "whisper-large-v3": 4.2,
    "whisper-medium": 2.8,
    "faster-whisper-large-v3": 2.2,
    "gigaam-ctc": 1.5,
    "gigaam-rnnt": 3.5,
}
# Models whose simulated WER/time respond to ``beam_size``. A frozenset so
# this shared module-level constant cannot be mutated by accident.
WHISPER_MODELS: frozenset[str] = frozenset(
    {"whisper-large-v3", "whisper-medium", "faster-whisper-large-v3"}
)
# Additive WER change (percentage points) per beam size; 5 is the reference.
BEAM_SIZE_WER_DELTA: dict[int, float] = {
    1: 1.2, 3: 0.4, 5: 0.0, 7: -0.1, 10: -0.15,
}
# Multiplicative time factor per beam size; 5 is the reference (factor 1.0).
BEAM_SIZE_TIME_FACTOR: dict[int, float] = {
    1: 0.6, 3: 0.8, 5: 1.0, 7: 1.15, 10: 1.4,
}
# Additive WER change (percentage points) per VAD threshold; 0.5 is the
# reference. Keys are matched by exact float equality.
VAD_WER_DELTA: dict[float, float] = {
    0.3: 0.8, 0.4: 0.2, 0.5: 0.0, 0.6: 0.3, 0.7: 1.0,
}
# Baseline diarization error rate per diarization model, in percent (the
# unit of the ``der`` field returned by ``evaluate_diarization_batch``).
DIARIZATION_BASE_DER: dict[str, float] = {
    "pyannote-3.1": 24.0,
    "pyannote-community-1": 20.5,
    "sortformer": 18.8,
}
# Baseline processing time per diarization model, before noise; reported as
# ``time`` in minutes.
DIARIZATION_BASE_TIME: dict[str, float] = {
    "pyannote-3.1": 2.5,
    "pyannote-community-1": 2.8,
    "sortformer": 3.8,
}
# Additive DER change (percentage points) keyed by ``min_speech_duration``;
# 0.5 is the reference.
MIN_SPEECH_DER_DELTA: dict[float, float] = {
    0.25: 1.5, 0.5: 0.0, 0.75: 0.3, 1.0: 1.2, 1.5: 3.0,
}
# Additive DER change (percentage points) keyed by ``clustering_threshold``;
# 0.6 is the reference.
CLUSTERING_DER_DELTA: dict[float, float] = {
    0.3: 3.0, 0.45: 0.8, 0.6: 0.0, 0.75: 0.5, 0.9: 2.5,
}
# Additive DER change (percentage points) keyed by ``vad_threshold``;
# 0.5 is the reference.
VAD_DER_DELTA: dict[float, float] = {
    0.3: 1.0, 0.4: 0.3, 0.5: 0.0, 0.6: 0.5, 0.7: 1.5,
}
def _deterministic_noise(seed_str: str, amplitude: float = 0.3) -> float:
    """Return a reproducible pseudo-random offset in ``[-amplitude, amplitude)``.

    The MD5 digest of ``seed_str`` is reduced to one of 10000 evenly spaced
    buckets, so the same seed always yields the same value across runs and
    processes (unlike the salted built-in ``hash``).

    Args:
        seed_str: string that determines the offset.
        amplitude: half-width of the output interval.

    Returns:
        A float in ``[-amplitude, amplitude)`` with a resolution of
        ``2 * amplitude / 10000``.
    """
    # MD5 is used only as a stable mixing function, not for security;
    # ``usedforsecurity=False`` keeps it available on FIPS-restricted builds.
    digest = hashlib.md5(seed_str.encode(), usedforsecurity=False).hexdigest()
    bucket = int(digest, 16) % 10000
    return bucket / 10000 * 2 * amplitude - amplitude
def evaluate_transcription_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Evaluate transcription for a batch of configs using the same model.

    In production, this loads the model once and iterates over configs.
    Currently returns simulated results. The simulation starts from a
    per-model base WER and time. ``beam_size`` then adjusts both, but only
    for Whisper-family models, and ``vad_threshold`` adjusts the WER.
    Finally, deterministic pseudo-noise seeded by the model name and the
    config values is added.

    Args:
        model_name: name of the transcription model; must be a key of
            ``TRANSCRIPTION_BASE_WER`` and ``TRANSCRIPTION_BASE_TIME``.
        configs: list of dicts, each with keys ``beam_size``, ``vad_threshold``.
            Values are looked up by exact equality in the delta tables, so
            they must be one of the tabulated grid values.
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``wer`` (%) and ``time`` (minutes), one per config,
        in the same order as ``configs``.

    Raises:
        KeyError: if ``model_name`` is unknown, a config lacks a required key,
            or a parameter value is not in its lookup table. For unknown
            models and values, the message names the offending value and
            the accepted ones.
    """

    def _delta(table: dict, value, label: str) -> float:
        # Same as ``table[value]``, but a miss reports which parameter was
        # off-grid and what is accepted, instead of a bare ``KeyError: 0.35``.
        try:
            return table[value]
        except KeyError:
            raise KeyError(
                f"unsupported {label} {value!r}; expected one of {sorted(table)}"
            ) from None

    try:
        base_wer = TRANSCRIPTION_BASE_WER[model_name]
        base_time = TRANSCRIPTION_BASE_TIME[model_name]
    except KeyError:
        raise KeyError(
            f"unknown transcription model {model_name!r}; "
            f"expected one of {sorted(TRANSCRIPTION_BASE_WER)}"
        ) from None
    is_whisper = model_name in WHISPER_MODELS

    results = []
    for cfg in configs:
        beam = cfg["beam_size"]
        vad = cfg["vad_threshold"]

        # WER: base + (beam delta, Whisper only) + VAD delta + interaction.
        wer = base_wer
        if is_whisper:
            wer += _delta(BEAM_SIZE_WER_DELTA, beam, "beam_size")
        wer += _delta(VAD_WER_DELTA, vad, "vad_threshold")
        if is_whisper and vad in (0.3, 0.7) and beam >= 7:
            # Simulated penalty for the extreme VAD grid values combined
            # with the widest beams.
            wer += 0.4
        noise = _deterministic_noise(f"t_{model_name}_{beam}_{vad}")
        wer = max(1.0, wer + noise)

        # Time: base, scaled by beam factor for Whisper models, plus noise.
        minutes = base_time
        if is_whisper:
            minutes *= _delta(BEAM_SIZE_TIME_FACTOR, beam, "beam_size")
        minutes += _deterministic_noise(f"tt_{model_name}_{beam}_{vad}", 0.1)
        minutes = max(0.5, minutes)

        results.append({"wer": round(wer, 2), "time": round(minutes, 2)})
    return results
def evaluate_diarization_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Evaluate diarization for a batch of configs using the same model.

    In production, this loads the model once and iterates over configs.
    Currently returns simulated results. The simulation starts from a
    per-model base DER and time. The DER is then adjusted additively by
    ``min_speech_duration``, ``clustering_threshold`` and ``vad_threshold``,
    and two interaction penalties may apply. Finally, deterministic
    pseudo-noise seeded by the model name and the config values is added.

    Args:
        model_name: name of the diarization model; must be a key of
            ``DIARIZATION_BASE_DER`` and ``DIARIZATION_BASE_TIME``.
        configs: list of dicts with ``min_speech_duration``,
            ``clustering_threshold``, ``vad_threshold``. Values are looked up
            by exact equality in the delta tables, so they must be one of
            the tabulated grid values.
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``der`` (%) and ``time`` (minutes), one per config,
        in the same order as ``configs``.

    Raises:
        KeyError: if ``model_name`` is unknown, a config lacks a required key,
            or a parameter value is not in its lookup table. For unknown
            models and values, the message names the offending value and
            the accepted ones.
    """

    def _delta(table: dict, value, label: str) -> float:
        # Same as ``table[value]``, but a miss reports which parameter was
        # off-grid and what is accepted, instead of a bare ``KeyError: 0.35``.
        try:
            return table[value]
        except KeyError:
            raise KeyError(
                f"unsupported {label} {value!r}; expected one of {sorted(table)}"
            ) from None

    try:
        base_der = DIARIZATION_BASE_DER[model_name]
        base_time = DIARIZATION_BASE_TIME[model_name]
    except KeyError:
        raise KeyError(
            f"unknown diarization model {model_name!r}; "
            f"expected one of {sorted(DIARIZATION_BASE_DER)}"
        ) from None

    results = []
    for cfg in configs:
        msd = cfg["min_speech_duration"]
        ct = cfg["clustering_threshold"]
        vad = cfg["vad_threshold"]

        # DER: base plus one additive delta per parameter (order preserved
        # so floating-point results are reproducible).
        der = base_der
        der += _delta(MIN_SPEECH_DER_DELTA, msd, "min_speech_duration")
        der += _delta(CLUSTERING_DER_DELTA, ct, "clustering_threshold")
        der += _delta(VAD_DER_DELTA, vad, "vad_threshold")
        if vad <= 0.3 and msd <= 0.25:
            # Simulated penalty: lowest VAD threshold together with the
            # shortest min_speech_duration on the grid.
            der += 1.2
        if ct >= 0.9 and msd >= 1.5:
            # Simulated penalty: highest clustering_threshold together with
            # the longest min_speech_duration on the grid.
            der += 0.8
        noise = _deterministic_noise(f"d_{model_name}_{msd}_{ct}_{vad}")
        der = max(5.0, der + noise)

        # Time does not depend on the config beyond the noise seed.
        minutes = base_time + _deterministic_noise(
            f"dt_{model_name}_{msd}_{ct}_{vad}", 0.15
        )
        minutes = max(0.5, minutes)

        results.append({"der": round(der, 2), "time": round(minutes, 2)})
    return results