64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
from typing import Callable, Sequence
|
|
|
|
import numpy as np
|
|
from numpy.typing import NDArray
|
|
|
|
from .primitive import Operation
|
|
|
|
type Value = NDArray[np.float64]
|
|
|
|
|
|
def make_safe(
|
|
fn: Callable[[Sequence[Value]], Value],
|
|
) -> Callable[[Sequence[Value]], Value]:
|
|
"""Обёртка для стабилизации результатов векторных операций."""
|
|
|
|
def wrapped(args: Sequence[Value]) -> Value:
|
|
with np.errstate(
|
|
over="ignore", invalid="ignore", divide="ignore", under="ignore"
|
|
):
|
|
res = fn(args)
|
|
|
|
# гарантируем, что на выходе всегда NDArray[np.float64]
|
|
if not isinstance(res, np.ndarray):
|
|
res = np.array(res, dtype=np.float64)
|
|
|
|
res = np.nan_to_num(res, nan=0.0, posinf=1e6, neginf=-1e6)
|
|
res = np.clip(res, -1e6, 1e6)
|
|
|
|
return res
|
|
|
|
return wrapped
|
|
|
|
|
|
# Унарные операции
|
|
NEG = Operation("-", 1, make_safe(lambda x: -x[0]))
|
|
SIN = Operation("sin", 1, make_safe(lambda x: np.sin(x[0])))
|
|
COS = Operation("cos", 1, make_safe(lambda x: np.cos(x[0])))
|
|
|
|
|
|
SQUARE = Operation("pow2", 1, make_safe(lambda x: np.clip(x[0], -1e3, 1e3) ** 2))
|
|
|
|
|
|
EXP = Operation("exp", 1, make_safe(lambda x: np.exp(np.clip(x[0], -10, 10))))
|
|
|
|
|
|
# Бинарные операции
|
|
ADD = Operation("+", 2, lambda x: x[0] + x[1])
|
|
SUB = Operation("-", 2, lambda x: x[0] - x[1])
|
|
MUL = Operation("*", 2, lambda x: x[0] * x[1])
|
|
|
|
ADD = Operation("+", 2, make_safe(lambda x: x[0] + x[1]))
|
|
SUB = Operation("-", 2, make_safe(lambda x: x[0] - x[1]))
|
|
MUL = Operation("*", 2, make_safe(lambda x: x[0] * x[1]))
|
|
DIV = Operation(
|
|
"/",
|
|
2,
|
|
make_safe(lambda x: np.divide(x[0], np.where(np.abs(x[1]) < 1e-10, 1e-10, x[1]))),
|
|
)
|
|
POW = Operation(
|
|
"^",
|
|
2,
|
|
make_safe(lambda x: np.power(np.clip(x[0], -1e3, 1e3), np.clip(x[1], -3, 3))),
|
|
)
|