Files
dataset-tg-bot/src/scenarios.py

329 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import shutil
from pathlib import Path
from src.config import DATA_DIR, DATA_PARTIAL_DIR
from src.database import (
UserState,
create_replicas,
create_scenario,
get_connection,
get_replicas_for_track,
get_scenario,
)
from src.logger import logger
# Состояния, в которых пользователь активно работает с дорожкой
_RECORDING_STATES = (
UserState.FIRST_REPLICA.value,
UserState.SHOW_REPLICA.value,
UserState.CONFIRM_RESTART.value,
UserState.CONFIRM_SAVE.value,
UserState.ASK_REPLICA_NUMBER.value,
UserState.REPEAT_REPLICA.value,
)
def load_scenario_from_json(scenario_id: str, json_data: list[dict]) -> int:
"""Загружает сценарий из JSON. Возвращает количество реплик."""
if get_scenario(scenario_id):
raise ValueError(f"Сценарий {scenario_id} уже существует")
replicas_data: list[tuple[int, int, str, str]] = []
for idx, item in enumerate(json_data):
if "text" not in item or "speaker_id" not in item or "gender" not in item:
raise ValueError(f"Некорректный формат реплики #{idx}")
replicas_data.append((item["speaker_id"], idx, item["text"], item["gender"]))
create_scenario(scenario_id)
create_replicas(scenario_id, replicas_data)
logger.info(f"Загружен сценарий {scenario_id}: {len(replicas_data)} реплик")
return len(replicas_data)
def parse_scenario_file(file_content: bytes) -> tuple[list[dict], str | None]:
"""Парсит JSON-файл сценария. Возвращает (данные, ошибка)."""
try:
data = json.loads(file_content.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
return [], f"Ошибка парсинга JSON: {e}"
if not isinstance(data, list):
return [], "Ожидается массив реплик"
if not data:
return [], "Сценарий пуст"
for idx, item in enumerate(data):
if not isinstance(item, dict):
return [], f"Реплика #{idx} должна быть объектом"
if "text" not in item:
return [], f"Реплика #{idx}: отсутствует поле 'text'"
if "speaker_id" not in item:
return [], f"Реплика #{idx}: отсутствует поле 'speaker_id'"
if not isinstance(item["speaker_id"], int):
return [], f"Реплика #{idx}: 'speaker_id' должен быть числом"
if "gender" not in item:
return [], f"Реплика #{idx}: отсутствует поле 'gender'"
if item["gender"] not in ["male", "female"]:
return (
[],
f"Реплика #{idx}: 'gender' должен быть 'male' или 'female'",
)
return data, None
def get_scenario_info(json_data: list[dict]) -> dict:
"""Возвращает информацию о сценарии для превью."""
speaker_ids = set(item["speaker_id"] for item in json_data)
return {
"total_replicas": len(json_data),
"total_tracks": len(speaker_ids),
"speaker_ids": sorted(speaker_ids),
}
def find_available_track(user_id: int) -> tuple[str, int] | None:
"""
Находит доступную дорожку для пользователя с учётом пола.
Приоритет:
1. Дорожки, которые никто не начал озвучивать
2. Дорожки, которые кто-то начал, но не закончил
3. Дорожки с готовой озвучкой (для дополнительных записей)
Пользователь не может озвучивать две разные дорожки в одном сценарии.
Учитывается пол пользователя и пол дорожек.
Возвращает (scenario_id, speaker_id) или None.
"""
with get_connection() as conn:
# Получаем пол пользователя
user_gender_row = conn.execute(
"SELECT gender FROM users WHERE id = ?", (user_id,)
).fetchone()
if not user_gender_row or not user_gender_row[0]:
return None
user_gender = user_gender_row[0]
# Сценарии, в которых пользователь уже записывает дорожку
user_scenarios = conn.execute(
"""
SELECT DISTINCT scenario_id FROM recordings WHERE user_id = ?
""",
(user_id,),
).fetchall()
user_scenario_ids = {row[0] for row in user_scenarios}
# Все дорожки (scenario_id, speaker_id) с количеством реплик и полом
# Учитываем только дорожки, где пол совпадает с полом пользователя
all_tracks = conn.execute(
"""
SELECT scenario_id, speaker_id, COUNT(*) as replica_count, gender
FROM replicas
WHERE gender = ?
GROUP BY scenario_id, speaker_id
""",
(user_gender,),
).fetchall()
# Статистика записей по дорожкам
track_stats = conn.execute(
"""
SELECT r.scenario_id, rep.speaker_id, r.user_id, COUNT(*) as recorded_count
FROM recordings r
JOIN replicas rep ON r.scenario_id = rep.scenario_id
AND r.replica_index = rep.replica_index
GROUP BY r.scenario_id, rep.speaker_id, r.user_id
"""
).fetchall()
# Словарь: (scenario_id, speaker_id) -> {user_id: recorded_count}
track_recordings: dict[tuple[str, int], dict[int, int]] = {}
for row in track_stats:
key = (row[0], row[1])
if key not in track_recordings:
track_recordings[key] = {}
track_recordings[key][row[2]] = row[3]
# Дорожки, назначенные другим пользователям через активные сессии
placeholders = ",".join("?" * len(_RECORDING_STATES))
session_rows = conn.execute(
f"""
SELECT scenario_id, speaker_id FROM user_sessions
WHERE user_id != ?
AND scenario_id IS NOT NULL
AND speaker_id IS NOT NULL
AND state IN ({placeholders})
""",
(user_id, *_RECORDING_STATES),
).fetchall()
assigned_tracks: set[tuple[str, int]] = {
(row[0], row[1]) for row in session_rows
}
# Категоризация дорожек
untouched: list[tuple[str, int]] = [] # никто не начал
in_progress: list[tuple[str, int]] = [] # начато, не закончено
completed: list[tuple[str, int]] = [] # есть готовая запись
for row in all_tracks:
scenario_id, speaker_id, replica_count = row[0], row[1], row[2]
key = (scenario_id, speaker_id)
# Пропускаем сценарии, где пользователь уже записывает другую дорожку
if scenario_id in user_scenario_ids:
# Проверяем, записывает ли он именно эту дорожку
if key in track_recordings and user_id in track_recordings[key]:
# Пользователь уже записывает эту дорожку — пропускаем
continue
# Пользователь записывает другую дорожку в этом сценарии
continue
if key not in track_recordings:
if key in assigned_tracks:
in_progress.append(key)
else:
untouched.append(key)
else:
has_complete = any(
count == replica_count for count in track_recordings[key].values()
)
if has_complete:
completed.append(key)
else:
in_progress.append(key)
# Выбираем по приоритету
if untouched:
return untouched[0]
if in_progress:
return in_progress[0]
if completed:
return completed[0]
return None
def get_user_current_track(user_id: int) -> tuple[str, int] | None:
"""Возвращает текущую дорожку пользователя (scenario_id, speaker_id) или None."""
with get_connection() as conn:
row = conn.execute(
"""
SELECT r.scenario_id, rep.speaker_id
FROM recordings r
JOIN replicas rep ON r.scenario_id = rep.scenario_id
AND r.replica_index = rep.replica_index
WHERE r.user_id = ?
GROUP BY r.scenario_id, rep.speaker_id
ORDER BY MAX(r.created_at) DESC
LIMIT 1
""",
(user_id,),
).fetchone()
if row:
return (row[0], row[1])
return None
def is_track_complete(user_id: int, scenario_id: str, speaker_id: int) -> bool:
"""Проверяет, полностью ли озвучена дорожка пользователем."""
track_replicas = get_replicas_for_track(scenario_id, speaker_id)
with get_connection() as conn:
recorded = conn.execute(
"""
SELECT COUNT(*) FROM recordings r
JOIN replicas rep ON r.scenario_id = rep.scenario_id
AND r.replica_index = rep.replica_index
WHERE r.user_id = ? AND r.scenario_id = ? AND rep.speaker_id = ?
""",
(user_id, scenario_id, speaker_id),
).fetchone()[0]
return recorded == len(track_replicas)
def get_partial_dir(scenario_id: str) -> Path:
"""Возвращает путь к папке частичных записей сценария."""
return DATA_PARTIAL_DIR / scenario_id
def get_data_dir(scenario_id: str) -> Path:
"""Возвращает путь к папке готовых записей сценария."""
return DATA_DIR / scenario_id
def get_audio_filename(replica_index: int, speaker_id: int, user_id: int) -> str:
"""Формирует имя файла для аудиозаписи."""
return f"r{replica_index:03d}_s{speaker_id:02d}_u{user_id:03d}.ogg"
def move_track_to_data(user_id: int, scenario_id: str, speaker_id: int) -> None:
"""Переносит завершённую дорожку из data_partial в data."""
partial_dir = get_partial_dir(scenario_id)
data_dir = get_data_dir(scenario_id)
data_dir.mkdir(parents=True, exist_ok=True)
track_replicas = get_replicas_for_track(scenario_id, speaker_id)
moved_count = 0
for replica in track_replicas:
filename = get_audio_filename(replica.replica_index, speaker_id, user_id)
src = partial_dir / filename
dst = data_dir / filename
if src.exists():
shutil.move(str(src), str(dst))
moved_count += 1
# Удаляем пустую папку partial
if partial_dir.exists() and not any(partial_dir.iterdir()):
partial_dir.rmdir()
if moved_count > 0:
logger.info(
f"Перенесено {moved_count} файлов для дорожки "
f"{scenario_id}/{speaker_id} (user_id={user_id})"
)
def delete_partial_track(user_id: int, scenario_id: str, speaker_id: int) -> None:
"""Удаляет частичные записи дорожки."""
partial_dir = get_partial_dir(scenario_id)
track_replicas = get_replicas_for_track(scenario_id, speaker_id)
deleted_count = 0
for replica in track_replicas:
filename = get_audio_filename(replica.replica_index, speaker_id, user_id)
filepath = partial_dir / filename
if filepath.exists():
filepath.unlink()
deleted_count += 1
if deleted_count > 0:
logger.info(
f"Удалено {deleted_count} частичных записей для дорожки "
f"{scenario_id}/{speaker_id} (user_id={user_id})"
)
def delete_scenario_files(scenario_id: str) -> int:
"""Удаляет все файлы сценария из data/ и data_partial/. Возвращает число файлов."""
deleted_count = 0
for base_dir in [DATA_DIR, DATA_PARTIAL_DIR]:
scenario_dir = base_dir / scenario_id
if scenario_dir.exists():
for file in scenario_dir.glob("*.wav"):
file.unlink()
deleted_count += 1
# Удаляем пустую папку
try:
scenario_dir.rmdir()
except OSError:
pass # Папка не пуста
if deleted_count > 0:
logger.info(f"Удалено {deleted_count} файлов для сценария {scenario_id}")
return deleted_count