vectorized

This commit is contained in:
2025-11-05 20:32:09 +03:00
parent 26bd6da1b4
commit e6765c9254
8 changed files with 140 additions and 349 deletions

View File

@@ -1,20 +1,21 @@
import math
import numpy as np
from numpy.typing import NDArray
from .primitive import Operation
type Value = NDArray[np.float64]
# Унарные операции
NEG = Operation("-", 1, lambda x: -x[0])
SIN = Operation("sin", 1, lambda x: math.sin(x[0]))
COS = Operation("cos", 1, lambda x: math.cos(x[0]))
SIN = Operation("sin", 1, lambda x: np.sin(x[0]))
COS = Operation("cos", 1, lambda x: np.cos(x[0]))
def _safe_exp(a: float) -> float:
if a < 700.0:
if a > -700.0:
return math.exp(a)
return 0.0
else:
return float("inf")
def _safe_exp(v: Value) -> Value:
v_clipped = np.clip(v, -10.0, 10.0)
out = np.exp(v_clipped)
out[np.isnan(out) | np.isinf(out)] = 0.0
return out
EXP = Operation("exp", 1, lambda x: _safe_exp(x[0]))
@@ -24,28 +25,34 @@ EXP = Operation("exp", 1, lambda x: _safe_exp(x[0]))
ADD = Operation("+", 2, lambda x: x[0] + x[1])
SUB = Operation("-", 2, lambda x: x[0] - x[1])
MUL = Operation("*", 2, lambda x: x[0] * x[1])
DIV = Operation("/", 2, lambda x: x[0] / x[1] if x[1] != 0 else float("inf"))
def safe_pow(a, b):
# 0 в отрицательной степени
if abs(a) <= 1e-12 and b < 0:
return float("inf")
# отрицательное основание при нецелой степени
if a < 0 and abs(b - round(b)) > 1e-12:
return float("inf")
# грубое насыщение (настрой пороги под задачу)
if abs(a) > 1 and b > 20:
return float("inf")
if abs(a) < 1 and b < -20:
return float("inf")
try:
return a**b
except (OverflowError, ValueError):
return float("inf")
def _safe_div(a: Value, b: Value) -> Value:
eps = 1e-12
denom = np.where(np.abs(b) >= eps, b, eps)
out = np.divide(a, denom)
out = np.where(np.isnan(out) | np.isinf(out), 0.0, out)
return out
POW = Operation("^", 2, lambda x: safe_pow(x[0], x[1]))
DIV = Operation("/", 2, lambda x: _safe_div(x[0], x[1]))
def _safe_pow(a: Value, b: Value) -> Value:
a_clip = np.clip(a, -1e3, 1e3)
b_clip = np.clip(b, -3.0, 3.0)
# 0 в отрицательной степени → 0
mask_zero_neg = (a_clip == 0.0) & (b_clip < 0.0)
with np.errstate(over="ignore", invalid="ignore", divide="ignore", under="ignore"):
out = np.power(a_clip, b_clip)
out[mask_zero_neg] = 0.0
out[np.isnan(out) | np.isinf(out)] = 0.0
return out
POW = Operation("^", 2, lambda x: _safe_pow(x[0], x[1]))
# Все операции в либе
ALL = (NEG, SIN, COS, EXP, ADD, SUB, MUL, DIV, POW)