from typing import Callable, Sequence import numpy as np from numpy.typing import NDArray from .primitive import Operation type Value = NDArray[np.float64] def make_safe( fn: Callable[[Sequence[Value]], Value], ) -> Callable[[Sequence[Value]], Value]: """Обёртка для стабилизации результатов векторных операций.""" def wrapped(args: Sequence[Value]) -> Value: with np.errstate( over="ignore", invalid="ignore", divide="ignore", under="ignore" ): res = fn(args) # гарантируем, что на выходе всегда NDArray[np.float64] if not isinstance(res, np.ndarray): res = np.array(res, dtype=np.float64) res = np.nan_to_num(res, nan=0.0, posinf=1e6, neginf=-1e6) res = np.clip(res, -1e6, 1e6) return res return wrapped # Унарные операции NEG = Operation("-", 1, make_safe(lambda x: -x[0])) SIN = Operation("sin", 1, make_safe(lambda x: np.sin(x[0]))) COS = Operation("cos", 1, make_safe(lambda x: np.cos(x[0]))) SQUARE = Operation("pow2", 1, make_safe(lambda x: np.clip(x[0], -1e3, 1e3) ** 2)) EXP = Operation("exp", 1, make_safe(lambda x: np.exp(np.clip(x[0], -10, 10)))) # Бинарные операции ADD = Operation("+", 2, lambda x: x[0] + x[1]) SUB = Operation("-", 2, lambda x: x[0] - x[1]) MUL = Operation("*", 2, lambda x: x[0] * x[1]) ADD = Operation("+", 2, make_safe(lambda x: x[0] + x[1])) SUB = Operation("-", 2, make_safe(lambda x: x[0] - x[1])) MUL = Operation("*", 2, make_safe(lambda x: x[0] * x[1])) DIV = Operation( "/", 2, make_safe(lambda x: np.divide(x[0], np.where(np.abs(x[1]) < 1e-10, 1e-10, x[1]))), ) POW = Operation( "^", 2, make_safe(lambda x: np.power(np.clip(x[0], -1e3, 1e3), np.clip(x[1], -3, 3))), )