safe operations
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
@@ -5,21 +7,40 @@ from .primitive import Operation
|
||||
|
||||
type Value = NDArray[np.float64]
|
||||
|
||||
|
||||
def make_safe(
|
||||
fn: Callable[[Sequence[Value]], Value],
|
||||
) -> Callable[[Sequence[Value]], Value]:
|
||||
"""Обёртка для стабилизации результатов векторных операций."""
|
||||
|
||||
def wrapped(args: Sequence[Value]) -> Value:
|
||||
with np.errstate(
|
||||
over="ignore", invalid="ignore", divide="ignore", under="ignore"
|
||||
):
|
||||
res = fn(args)
|
||||
|
||||
# гарантируем, что на выходе всегда NDArray[np.float64]
|
||||
if not isinstance(res, np.ndarray):
|
||||
res = np.array(res, dtype=np.float64)
|
||||
|
||||
res = np.nan_to_num(res, nan=0.0, posinf=1e6, neginf=-1e6)
|
||||
res = np.clip(res, -1e6, 1e6)
|
||||
|
||||
return res
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
# Унарные операции
|
||||
NEG = Operation("-", 1, lambda x: -x[0])
|
||||
SQUARE = Operation("pow2", 1, lambda x: x[0] ** 2)
|
||||
SIN = Operation("sin", 1, lambda x: np.sin(x[0]))
|
||||
COS = Operation("cos", 1, lambda x: np.cos(x[0]))
|
||||
NEG = Operation("-", 1, make_safe(lambda x: -x[0]))
|
||||
SIN = Operation("sin", 1, make_safe(lambda x: np.sin(x[0])))
|
||||
COS = Operation("cos", 1, make_safe(lambda x: np.cos(x[0])))
|
||||
|
||||
|
||||
def _safe_exp(v: Value) -> Value:
|
||||
v_clipped = np.clip(v, -10.0, 10.0)
|
||||
out = np.exp(v_clipped)
|
||||
out[np.isnan(out) | np.isinf(out)] = 0.0
|
||||
return out
|
||||
SQUARE = Operation("pow2", 1, make_safe(lambda x: np.clip(x[0], -1e3, 1e3) ** 2))
|
||||
|
||||
|
||||
EXP = Operation("exp", 1, lambda x: _safe_exp(x[0]))
|
||||
EXP = Operation("exp", 1, make_safe(lambda x: np.exp(np.clip(x[0], -10, 10))))
|
||||
|
||||
|
||||
# Бинарные операции
|
||||
@@ -27,30 +48,16 @@ ADD = Operation("+", 2, lambda x: x[0] + x[1])
|
||||
SUB = Operation("-", 2, lambda x: x[0] - x[1])
|
||||
MUL = Operation("*", 2, lambda x: x[0] * x[1])
|
||||
|
||||
|
||||
def _safe_div(a: Value, b: Value) -> Value:
|
||||
eps = 1e-12
|
||||
denom = np.where(np.abs(b) >= eps, b, eps)
|
||||
out = np.divide(a, denom)
|
||||
out = np.where(np.isnan(out) | np.isinf(out), 0.0, out)
|
||||
return out
|
||||
|
||||
|
||||
DIV = Operation("/", 2, lambda x: _safe_div(x[0], x[1]))
|
||||
|
||||
|
||||
def _safe_pow(a: Value, b: Value) -> Value:
|
||||
a_clip = np.clip(a, -1e3, 1e3)
|
||||
b_clip = np.clip(b, -3.0, 3.0)
|
||||
|
||||
# 0 в отрицательной степени → 0
|
||||
mask_zero_neg = (a_clip == 0.0) & (b_clip < 0.0)
|
||||
with np.errstate(over="ignore", invalid="ignore", divide="ignore", under="ignore"):
|
||||
out = np.power(a_clip, b_clip)
|
||||
|
||||
out[mask_zero_neg] = 0.0
|
||||
out[np.isnan(out) | np.isinf(out)] = 0.0
|
||||
return out
|
||||
|
||||
|
||||
POW = Operation("^", 2, lambda x: _safe_pow(x[0], x[1]))
|
||||
ADD = Operation("+", 2, make_safe(lambda x: x[0] + x[1]))
|
||||
SUB = Operation("-", 2, make_safe(lambda x: x[0] - x[1]))
|
||||
MUL = Operation("*", 2, make_safe(lambda x: x[0] * x[1]))
|
||||
DIV = Operation(
|
||||
"/",
|
||||
2,
|
||||
make_safe(lambda x: np.divide(x[0], np.where(np.abs(x[1]) < 1e-10, 1e-10, x[1]))),
|
||||
)
|
||||
POW = Operation(
|
||||
"^",
|
||||
2,
|
||||
make_safe(lambda x: np.power(np.clip(x[0], -1e3, 1e3), np.clip(x[1], -3, 3))),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user