diff --git a/lab4/gp/ops.py b/lab4/gp/ops.py index 6e9119c..6d801fa 100644 --- a/lab4/gp/ops.py +++ b/lab4/gp/ops.py @@ -1,3 +1,5 @@ +from typing import Callable, Sequence + import numpy as np from numpy.typing import NDArray @@ -5,21 +7,40 @@ from .primitive import Operation type Value = NDArray[np.float64] + +def make_safe( + fn: Callable[[Sequence[Value]], Value], +) -> Callable[[Sequence[Value]], Value]: + """Обёртка для стабилизации результатов векторных операций.""" + + def wrapped(args: Sequence[Value]) -> Value: + with np.errstate( + over="ignore", invalid="ignore", divide="ignore", under="ignore" + ): + res = fn(args) + + # гарантируем, что на выходе всегда NDArray[np.float64] + if not isinstance(res, np.ndarray): + res = np.array(res, dtype=np.float64) + + res = np.nan_to_num(res, nan=0.0, posinf=1e6, neginf=-1e6) + res = np.clip(res, -1e6, 1e6) + + return res + + return wrapped + + # Унарные операции -NEG = Operation("-", 1, lambda x: -x[0]) -SQUARE = Operation("pow2", 1, lambda x: x[0] ** 2) -SIN = Operation("sin", 1, lambda x: np.sin(x[0])) -COS = Operation("cos", 1, lambda x: np.cos(x[0])) +NEG = Operation("-", 1, make_safe(lambda x: -x[0])) +SIN = Operation("sin", 1, make_safe(lambda x: np.sin(x[0]))) +COS = Operation("cos", 1, make_safe(lambda x: np.cos(x[0]))) -def _safe_exp(v: Value) -> Value: - v_clipped = np.clip(v, -10.0, 10.0) - out = np.exp(v_clipped) - out[np.isnan(out) | np.isinf(out)] = 0.0 - return out +SQUARE = Operation("pow2", 1, make_safe(lambda x: np.clip(x[0], -1e3, 1e3) ** 2)) -EXP = Operation("exp", 1, lambda x: _safe_exp(x[0])) +EXP = Operation("exp", 1, make_safe(lambda x: np.exp(np.clip(x[0], -10, 10)))) # Бинарные операции @@ -27,30 +48,16 @@ ADD = Operation("+", 2, lambda x: x[0] + x[1]) SUB = Operation("-", 2, lambda x: x[0] - x[1]) MUL = Operation("*", 2, lambda x: x[0] * x[1]) - -def _safe_div(a: Value, b: Value) -> Value: - eps = 1e-12 - denom = np.where(np.abs(b) >= eps, b, eps) - out = np.divide(a, denom) - out = np.where(np.isnan(out) | np.isinf(out), 0.0, out) - return out - - -DIV = Operation("/", 2, lambda x: _safe_div(x[0], x[1])) - - -def _safe_pow(a: Value, b: Value) -> Value: - a_clip = np.clip(a, -1e3, 1e3) - b_clip = np.clip(b, -3.0, 3.0) - - # 0 в отрицательной степени → 0 - mask_zero_neg = (a_clip == 0.0) & (b_clip < 0.0) - with np.errstate(over="ignore", invalid="ignore", divide="ignore", under="ignore"): - out = np.power(a_clip, b_clip) - - out[mask_zero_neg] = 0.0 - out[np.isnan(out) | np.isinf(out)] = 0.0 - return out - - -POW = Operation("^", 2, lambda x: _safe_pow(x[0], x[1])) +ADD = Operation("+", 2, make_safe(lambda x: x[0] + x[1])) +SUB = Operation("-", 2, make_safe(lambda x: x[0] - x[1])) +MUL = Operation("*", 2, make_safe(lambda x: x[0] * x[1])) +DIV = Operation( + "/", + 2, + make_safe(lambda x: np.divide(x[0], np.where(np.abs(x[1]) < 1e-10, 1e-10, x[1]))), +) +POW = Operation( + "^", + 2, + make_safe(lambda x: np.power(np.clip(x[0], -1e3, 1e3), np.clip(x[1], -3, 3))), +)