Files
dataset-tg-bot/src/database.py

554 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sqlite3
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from src.config import DB_PATH
from src.logger import logger
class UserState(Enum):
    """States of a user's session.

    The string value of each member is what gets stored for the session
    (see ``upsert_user_session``) and parsed back by ``get_user_session``.
    """

    # Regular (non-admin) flow.
    INTRO = "intro"
    NO_MORE_SCENARIOS = "no_more_scenarios"
    FIRST_REPLICA = "first_replica"
    SHOW_REPLICA = "show_replica"
    CONFIRM_RESTART = "confirm_restart"
    CONFIRM_SAVE = "confirm_save"
    ASK_REPLICA_NUMBER = "ask_replica_number"
    REPEAT_REPLICA = "repeat_replica"

    # Admin-related states.
    ADMIN = "admin"
    ADMIN_UPLOAD_CONFIRM = "admin_upload_confirm"
    ADMIN_DELETE_CONFIRM = "admin_delete_confirm"
@dataclass
class User:
    """A bot user (a speaker in the dataset)."""

    # Doubles as the dataset_speaker_id.
    id: int
    telegram_id: int
    created_at: datetime
@dataclass
class Scenario:
    """A meeting scenario."""

    # scenario_id taken from the file name.
    id: str
    created_at: datetime
@dataclass
class Replica:
    """A single replica (spoken line) of a scenario."""

    id: int
    scenario_id: str
    # Speaker id, scoped to the scenario.
    speaker_id: int
    # Position within the scenario (0-indexed).
    replica_index: int
    text: str
@dataclass
class Recording:
    """A record that a replica has been voiced by a user."""

    id: int
    # The user's dataset_speaker_id.
    user_id: int
    scenario_id: str
    replica_index: int
    created_at: datetime
@dataclass
class UserSession:
    """The session state of a user."""

    user_id: int
    state: UserState
    scenario_id: str | None
    # speaker_id within the current scenario.
    speaker_id: int | None
    replica_index: int | None
    # Kept so the session can return from ADMIN.
    previous_state: UserState | None
    # Kept so the buttons can be removed later.
    last_bot_message_id: int | None
@contextmanager
def get_connection() -> Generator[sqlite3.Connection, None, None]:
    """Yield a database connection and close it when the block exits.

    The connection is opened on ``DB_PATH`` with
    ``sqlite3.PARSE_DECLTYPES`` and uses ``sqlite3.Row`` as its row factory,
    so columns can be read by name. Nothing is committed here; callers
    commit explicitly.

    Yields:
        The open ``sqlite3.Connection``.
    """
    connection = sqlite3.connect(DB_PATH, detect_types=sqlite3.PARSE_DECLTYPES)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        # Runs on normal exit and when the caller's block raises.
        connection.close()
def init_db() -> None:
    """Create the database schema if it does not exist yet.

    Runs one script that creates the ``users``, ``scenarios``, ``replicas``,
    ``recordings`` and ``user_sessions`` tables and their indexes, each with
    ``IF NOT EXISTS``, then commits. Logs before and after.
    """
    logger.info("Инициализация базы данных...")

    with get_connection() as db:
        db.executescript("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                telegram_id INTEGER UNIQUE NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS scenarios (
                id TEXT PRIMARY KEY,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            );
            CREATE TABLE IF NOT EXISTS replicas (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                scenario_id TEXT NOT NULL,
                speaker_id INTEGER NOT NULL,
                replica_index INTEGER NOT NULL,
                text TEXT NOT NULL,
                FOREIGN KEY (scenario_id) REFERENCES scenarios(id),
                UNIQUE(scenario_id, replica_index)
            );
            CREATE TABLE IF NOT EXISTS recordings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER NOT NULL,
                scenario_id TEXT NOT NULL,
                replica_index INTEGER NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (user_id) REFERENCES users(id),
                FOREIGN KEY (scenario_id) REFERENCES scenarios(id),
                UNIQUE(user_id, scenario_id, replica_index)
            );
            CREATE TABLE IF NOT EXISTS user_sessions (
                user_id INTEGER PRIMARY KEY,
                state TEXT NOT NULL,
                scenario_id TEXT,
                speaker_id INTEGER,
                replica_index INTEGER,
                previous_state TEXT,
                last_bot_message_id INTEGER,
                FOREIGN KEY (user_id) REFERENCES users(id),
                FOREIGN KEY (scenario_id) REFERENCES scenarios(id)
            );
            CREATE INDEX IF NOT EXISTS idx_replicas_scenario
                ON replicas(scenario_id);
            CREATE INDEX IF NOT EXISTS idx_replicas_scenario_speaker
                ON replicas(scenario_id, speaker_id);
            CREATE INDEX IF NOT EXISTS idx_recordings_user
                ON recordings(user_id);
            CREATE INDEX IF NOT EXISTS idx_recordings_scenario
                ON recordings(scenario_id);
        """)
        db.commit()

    logger.info("База данных инициализирована")
# === Users CRUD ===
def get_or_create_user(telegram_id: int) -> User:
    """Return the user with the given Telegram id, inserting one if absent.

    Args:
        telegram_id: Value looked up in ``users.telegram_id``.

    Returns:
        The existing ``User`` when a row is found; otherwise the ``User``
        built from the newly inserted row. Creation is logged.
    """
    lookup_sql = "SELECT id, telegram_id, created_at FROM users WHERE telegram_id = ?"
    insert_sql = (
        "INSERT INTO users (telegram_id) VALUES (?) "
        "RETURNING id, telegram_id, created_at"
    )
    params = (telegram_id,)

    with get_connection() as conn:
        found = conn.execute(lookup_sql, params).fetchone()
        if found:
            return User(
                id=found["id"],
                telegram_id=found["telegram_id"],
                created_at=found["created_at"],
            )

        # The RETURNING row is read before the insert is committed.
        row = conn.execute(insert_sql, params).fetchone()
        conn.commit()

    logger.info(f"Создан новый пользователь: dataset_speaker_id={row['id']}")
    return User(
        id=row["id"], telegram_id=row["telegram_id"], created_at=row["created_at"]
    )
def get_user_by_telegram_id(telegram_id: int) -> User | None:
    """Look up a user by Telegram id.

    Args:
        telegram_id: Value looked up in ``users.telegram_id``.

    Returns:
        The matching ``User``, or ``None`` when no row is found.
    """
    query = "SELECT id, telegram_id, created_at FROM users WHERE telegram_id = ?"
    with get_connection() as conn:
        record = conn.execute(query, (telegram_id,)).fetchone()

    if not record:
        return None
    return User(
        id=record["id"],
        telegram_id=record["telegram_id"],
        created_at=record["created_at"],
    )
# === Scenarios CRUD ===
def create_scenario(scenario_id: str) -> Scenario:
    """Insert a new scenario row and return it.

    Args:
        scenario_id: Primary key of the scenario to insert.

    Returns:
        The ``Scenario`` built from the inserted row.
    """
    insert_sql = "INSERT INTO scenarios (id) VALUES (?) RETURNING id, created_at"
    with get_connection() as conn:
        # The RETURNING row is read before the insert is committed.
        inserted = conn.execute(insert_sql, (scenario_id,)).fetchone()
        conn.commit()

    logger.info(f"Создан сценарий: {scenario_id}")
    return Scenario(id=inserted["id"], created_at=inserted["created_at"])
def get_scenario(scenario_id: str) -> Scenario | None:
    """Look up a scenario by id.

    Args:
        scenario_id: Primary key of the scenario.

    Returns:
        The matching ``Scenario``, or ``None`` when no row is found.
    """
    query = "SELECT id, created_at FROM scenarios WHERE id = ?"
    with get_connection() as conn:
        record = conn.execute(query, (scenario_id,)).fetchone()

    if not record:
        return None
    return Scenario(id=record["id"], created_at=record["created_at"])
def get_all_scenarios() -> list[Scenario]:
    """Return all scenarios ordered by ``created_at``.

    Returns:
        A list of ``Scenario`` objects; empty when the table has no rows.
    """
    query = "SELECT id, created_at FROM scenarios ORDER BY created_at"
    with get_connection() as conn:
        records = conn.execute(query).fetchall()

    # The rows were fully fetched above, so they are converted after the
    # connection is closed.
    return [
        Scenario(id=record["id"], created_at=record["created_at"])
        for record in records
    ]
# === Replicas CRUD ===
def create_replicas(scenario_id: str, replicas: list[tuple[int, int, str]]) -> None:
    """Insert the replicas of a scenario in one batch and commit.

    Args:
        scenario_id: Scenario the replicas belong to.
        replicas: ``(speaker_id, replica_index, text)`` tuples, one per
            replica.
    """
    insert_sql = (
        "INSERT INTO replicas (scenario_id, speaker_id, replica_index, text) "
        "VALUES (?, ?, ?, ?)"
    )
    with get_connection() as conn:
        # Prefix every tuple with the scenario id to match the column order.
        parameter_rows = [
            (scenario_id, speaker, index, line) for speaker, index, line in replicas
        ]
        conn.executemany(insert_sql, parameter_rows)
        conn.commit()
def get_replicas_for_scenario(scenario_id: str) -> list[Replica]:
    """Return every replica of a scenario, ordered by ``replica_index``.

    Args:
        scenario_id: Scenario whose replicas are fetched.

    Returns:
        A list of ``Replica`` objects; empty when no rows match.
    """
    query = (
        "SELECT id, scenario_id, speaker_id, replica_index, text FROM replicas "
        "WHERE scenario_id = ? ORDER BY replica_index"
    )
    with get_connection() as conn:
        records = conn.execute(query, (scenario_id,)).fetchall()

    return [
        Replica(
            id=record["id"],
            scenario_id=record["scenario_id"],
            speaker_id=record["speaker_id"],
            replica_index=record["replica_index"],
            text=record["text"],
        )
        for record in records
    ]
def get_replicas_for_track(scenario_id: str, speaker_id: int) -> list[Replica]:
    """Return the replicas of one track, ordered by ``replica_index``.

    A track is identified by a ``speaker_id`` within a scenario.

    Args:
        scenario_id: Scenario the track belongs to.
        speaker_id: Speaker id within that scenario.

    Returns:
        A list of ``Replica`` objects; empty when no rows match.
    """
    query = (
        "SELECT id, scenario_id, speaker_id, replica_index, text FROM replicas "
        "WHERE scenario_id = ? AND speaker_id = ? ORDER BY replica_index"
    )
    with get_connection() as conn:
        records = conn.execute(query, (scenario_id, speaker_id)).fetchall()

    return [
        Replica(
            id=record["id"],
            scenario_id=record["scenario_id"],
            speaker_id=record["speaker_id"],
            replica_index=record["replica_index"],
            text=record["text"],
        )
        for record in records
    ]
def get_track_speaker_ids(scenario_id: str) -> list[int]:
    """Return the distinct speaker ids (tracks) of a scenario, sorted.

    Args:
        scenario_id: Scenario whose tracks are listed.

    Returns:
        The distinct ``speaker_id`` values in ascending order.
    """
    query = (
        "SELECT DISTINCT speaker_id FROM replicas "
        "WHERE scenario_id = ? ORDER BY speaker_id"
    )
    with get_connection() as conn:
        records = conn.execute(query, (scenario_id,)).fetchall()

    return [record["speaker_id"] for record in records]
# === Recordings CRUD ===
def create_recording(user_id: int, scenario_id: str, replica_index: int) -> Recording:
    """Insert a recording row for a voiced replica and return it.

    Args:
        user_id: Id of the user who voiced the replica.
        scenario_id: Scenario the replica belongs to.
        replica_index: Index of the replica within the scenario.

    Returns:
        The ``Recording`` built from the inserted row.
    """
    insert_sql = (
        "INSERT INTO recordings (user_id, scenario_id, replica_index) "
        "VALUES (?, ?, ?) "
        "RETURNING id, user_id, scenario_id, replica_index, created_at"
    )
    params = (user_id, scenario_id, replica_index)

    with get_connection() as conn:
        # The RETURNING row is read before the insert is committed.
        inserted = conn.execute(insert_sql, params).fetchone()
        conn.commit()

    return Recording(
        id=inserted["id"],
        user_id=inserted["user_id"],
        scenario_id=inserted["scenario_id"],
        replica_index=inserted["replica_index"],
        created_at=inserted["created_at"],
    )
def upsert_recording(user_id: int, scenario_id: str, replica_index: int) -> Recording:
    """Insert a recording row, or refresh ``created_at`` if it already exists.

    A conflict on ``(user_id, scenario_id, replica_index)`` updates the
    existing row's ``created_at`` to the current timestamp instead of
    inserting.

    Args:
        user_id: Id of the user who voiced the replica.
        scenario_id: Scenario the replica belongs to.
        replica_index: Index of the replica within the scenario.

    Returns:
        The ``Recording`` built from the inserted or updated row.
    """
    params = (user_id, scenario_id, replica_index)

    with get_connection() as conn:
        saved = conn.execute(
            """
            INSERT INTO recordings (user_id, scenario_id, replica_index)
            VALUES (?, ?, ?)
            ON CONFLICT(user_id, scenario_id, replica_index)
            DO UPDATE SET created_at = CURRENT_TIMESTAMP
            RETURNING id, user_id, scenario_id, replica_index, created_at
            """,
            params,
        ).fetchone()
        conn.commit()

    return Recording(
        id=saved["id"],
        user_id=saved["user_id"],
        scenario_id=saved["scenario_id"],
        replica_index=saved["replica_index"],
        created_at=saved["created_at"],
    )
def get_user_recordings_for_scenario(user_id: int, scenario_id: str) -> list[Recording]:
    """Return a user's recordings for a scenario, ordered by ``replica_index``.

    Args:
        user_id: Id of the user whose recordings are fetched.
        scenario_id: Scenario to restrict the recordings to.

    Returns:
        A list of ``Recording`` objects; empty when no rows match.
    """
    query = (
        "SELECT id, user_id, scenario_id, replica_index, created_at "
        "FROM recordings WHERE user_id = ? AND scenario_id = ? "
        "ORDER BY replica_index"
    )
    with get_connection() as conn:
        records = conn.execute(query, (user_id, scenario_id)).fetchall()

    return [
        Recording(
            id=record["id"],
            user_id=record["user_id"],
            scenario_id=record["scenario_id"],
            replica_index=record["replica_index"],
            created_at=record["created_at"],
        )
        for record in records
    ]
def delete_user_recordings_for_scenario(user_id: int, scenario_id: str) -> None:
    """Delete all of a user's recordings for a scenario and commit.

    Args:
        user_id: Id of the user whose recordings are removed.
        scenario_id: Scenario to restrict the deletion to.
    """
    delete_sql = "DELETE FROM recordings WHERE user_id = ? AND scenario_id = ?"
    with get_connection() as conn:
        conn.execute(delete_sql, (user_id, scenario_id))
        conn.commit()
# === User Sessions CRUD ===
def get_user_session(user_id: int) -> UserSession | None:
    """Load a user's session.

    Args:
        user_id: Id of the user whose session is loaded.

    Returns:
        The stored ``UserSession``, or ``None`` when the user has no session
        row.

    Raises:
        ValueError: If a stored state string is not a ``UserState`` value.
    """
    query = (
        "SELECT user_id, state, scenario_id, speaker_id, replica_index, "
        "previous_state, last_bot_message_id FROM user_sessions WHERE user_id = ?"
    )
    with get_connection() as conn:
        record = conn.execute(query, (user_id,)).fetchone()

    if not record:
        return None

    current_state = UserState(record["state"])

    # previous_state is nullable; any falsy stored value becomes None.
    stored_previous = record["previous_state"]
    if stored_previous:
        previous_state = UserState(stored_previous)
    else:
        previous_state = None

    return UserSession(
        user_id=record["user_id"],
        state=current_state,
        scenario_id=record["scenario_id"],
        speaker_id=record["speaker_id"],
        replica_index=record["replica_index"],
        previous_state=previous_state,
        last_bot_message_id=record["last_bot_message_id"],
    )
def upsert_user_session(session: UserSession) -> None:
    """Insert a user's session, or overwrite the existing one, and commit.

    On a conflict on ``user_id`` every other column is replaced with the
    values from ``session``.

    Args:
        session: Session to persist. Enum states are stored by their string
            value; a missing ``previous_state`` is stored as NULL.
    """
    if session.previous_state:
        previous_value = session.previous_state.value
    else:
        previous_value = None

    params = (
        session.user_id,
        session.state.value,
        session.scenario_id,
        session.speaker_id,
        session.replica_index,
        previous_value,
        session.last_bot_message_id,
    )

    with get_connection() as conn:
        conn.execute(
            """
            INSERT INTO user_sessions (user_id, state, scenario_id, speaker_id,
                replica_index, previous_state, last_bot_message_id)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(user_id) DO UPDATE SET
                state = excluded.state,
                scenario_id = excluded.scenario_id,
                speaker_id = excluded.speaker_id,
                replica_index = excluded.replica_index,
                previous_state = excluded.previous_state,
                last_bot_message_id = excluded.last_bot_message_id
            """,
            params,
        )
        conn.commit()
def get_users_in_state(state: UserState) -> list[int]:
    """Return the ids of users whose session is in the given state.

    Args:
        state: State to match; compared by its string value.

    Returns:
        The ``user_id`` of every matching session row.
    """
    query = "SELECT user_id FROM user_sessions WHERE state = ?"
    with get_connection() as conn:
        records = conn.execute(query, (state.value,)).fetchall()

    return [record["user_id"] for record in records]
# === Statistics ===
def get_stats() -> dict[str, int]:
    """Collect dataset-wide statistics for the admin panel.

    Returns:
        A dict with the integer counters ``total_scenarios``,
        ``total_replicas``, ``total_tracks``, ``total_users``,
        ``total_recordings`` and ``completed_tracks``, in that order.
    """
    # Plain counters, listed in the order they appear in the result.
    counter_queries = {
        "total_scenarios": "SELECT COUNT(*) FROM scenarios",
        "total_replicas": "SELECT COUNT(*) FROM replicas",
        # A track is a distinct (scenario_id, speaker_id) pair. The pairs are
        # counted directly instead of through a concatenated string key,
        # which could merge distinct pairs such as ('a-', 1) and ('a', -1).
        "total_tracks": (
            "SELECT COUNT(*) FROM "
            "(SELECT DISTINCT scenario_id, speaker_id FROM replicas)"
        ),
        "total_users": "SELECT COUNT(*) FROM users",
        "total_recordings": "SELECT COUNT(*) FROM recordings",
    }

    # A track counts as completed for a user when that user's number of
    # recordings for the track equals the number of replicas in the track.
    completed_tracks_query = """
        SELECT COUNT(*) FROM (
            SELECT r.user_id, r.scenario_id, rep.speaker_id, COUNT(*) as cnt
            FROM recordings r
            JOIN replicas rep ON r.scenario_id = rep.scenario_id
                AND r.replica_index = rep.replica_index
            GROUP BY r.user_id, r.scenario_id, rep.speaker_id
            HAVING cnt = (
                SELECT COUNT(*) FROM replicas rp
                WHERE rp.scenario_id = r.scenario_id
                AND rp.speaker_id = rep.speaker_id
            )
        )
    """

    with get_connection() as conn:
        stats = {
            name: conn.execute(sql).fetchone()[0]
            for name, sql in counter_queries.items()
        }
        stats["completed_tracks"] = conn.execute(completed_tracks_query).fetchone()[0]

    return stats
def get_scenario_stats(scenario_id: str) -> dict:
    """Collect statistics for a single scenario.

    Args:
        scenario_id: Scenario to compute the counters for.

    Returns:
        A dict with ``total_replicas``, ``total_tracks`` (distinct speaker
        ids among the scenario's replicas) and ``total_recordings``.
    """
    # Every query takes the scenario id as its only parameter.
    counter_queries = {
        "total_replicas": "SELECT COUNT(*) FROM replicas WHERE scenario_id = ?",
        "total_tracks": (
            "SELECT COUNT(DISTINCT speaker_id) FROM replicas WHERE scenario_id = ?"
        ),
        "total_recordings": "SELECT COUNT(*) FROM recordings WHERE scenario_id = ?",
    }
    params = (scenario_id,)

    with get_connection() as conn:
        stats = {}
        for name, sql in counter_queries.items():
            stats[name] = conn.execute(sql, params).fetchone()[0]

    return stats
def get_users_with_scenario(scenario_id: str) -> list[tuple[int, int]]:
    """Return the users whose session currently points at a scenario.

    Args:
        scenario_id: Scenario to match against ``user_sessions.scenario_id``.

    Returns:
        Distinct ``(user id, telegram_id)`` pairs of the matching users.
    """
    with get_connection() as conn:
        records = conn.execute(
            """
            SELECT DISTINCT u.id, u.telegram_id
            FROM user_sessions us
            JOIN users u ON us.user_id = u.id
            WHERE us.scenario_id = ?
            """,
            (scenario_id,),
        ).fetchall()

    return [(record["id"], record["telegram_id"]) for record in records]
def delete_scenario_data(scenario_id: str) -> None:
    """Delete a scenario together with its recordings and replicas.

    The three deletes are issued on one connection and committed together.
    Rows in ``user_sessions`` are not touched here.

    Args:
        scenario_id: Scenario to remove.
    """
    # Dependent rows are removed before the scenario row itself.
    delete_statements = (
        "DELETE FROM recordings WHERE scenario_id = ?",
        "DELETE FROM replicas WHERE scenario_id = ?",
        "DELETE FROM scenarios WHERE id = ?",
    )
    params = (scenario_id,)

    with get_connection() as conn:
        for statement in delete_statements:
            conn.execute(statement, params)
        conn.commit()

    logger.info(f"Deleted scenario {scenario_id} from database")