36 lines
1006 B
Python
36 lines
1006 B
Python
from dataclasses import dataclass
|
|
from typing import Callable, Sequence
|
|
|
|
from .types import Context, Value
|
|
|
|
# Signature of a primitive's operation: a callable that receives the sequence
# of argument values passed to ``Primitive.eval`` and returns a single value.
type OperationFn = Callable[[Sequence[Value]], Value]
|
|
|
|
|
|
@dataclass(frozen=True)
class Primitive:
    """A named building block of an expression: a variable, constant, or operation.

    A primitive whose ``operation_fn`` is ``None`` is resolved by looking
    itself up in the evaluation context. Any other primitive is evaluated by
    applying ``operation_fn`` to the argument values it is given.

    Attributes:
        name: Identifier of the primitive.
        arity: Number of arguments the primitive takes; must be non-negative.
        operation_fn: Callable mapping the argument values to a result, or
            ``None`` for a context-bound primitive (only allowed when
            ``arity`` is 0).
    """

    name: str
    arity: int
    operation_fn: OperationFn | None

    def eval(self, args: Sequence[Value], context: Context) -> Value:
        """Compute the value of this primitive.

        Args:
            args: Values handed to ``operation_fn``. Ignored for
                context-bound primitives. Presumably these are the
                already-evaluated child results (confirm against callers).
            context: Mapping subscripted with this primitive; only consulted
                when ``operation_fn`` is ``None``.

        Returns:
            ``context[self]`` for a context-bound primitive, otherwise
            ``operation_fn(args)``.

        Raises:
            Whatever ``context[self]`` raises when this primitive is unbound
            (e.g. ``KeyError`` for a plain ``dict``).
        """
        if self.operation_fn is None:
            return context[self]
        return self.operation_fn(args)

    def __post_init__(self) -> None:
        """Validate the ``arity`` / ``operation_fn`` combination.

        Raises:
            ValueError: If ``arity`` is negative, or if a primitive with
                non-zero arity has no ``operation_fn``.
        """
        # A negative arity can never be satisfied by any argument list, so
        # reject it at construction time instead of failing later in eval.
        if self.arity < 0:
            raise ValueError("Arity must be non-negative")
        if self.arity != 0 and self.operation_fn is None:
            raise ValueError("Operation is required for primitive with non-zero arity")
|
|
|
|
|
|
def Var(name: str) -> Primitive:
    """Build a nullary primitive whose value is read from the evaluation context.

    Args:
        name: Identifier of the variable.

    Returns:
        A ``Primitive`` with arity 0 and no ``operation_fn``. Evaluating it
        yields ``context[primitive]``.
    """
    variable = Primitive(name, 0, None)
    return variable
|
|
|
|
|
|
def Const(name: str, val: Value) -> Primitive:
    """Build a nullary primitive that always evaluates to ``val``.

    Args:
        name: Identifier of the constant.
        val: The value returned on every evaluation; arguments are ignored.

    Returns:
        A ``Primitive`` with arity 0 whose operation returns ``val``.
    """
    def _constant(_args: Sequence[Value]) -> Value:
        # Arguments are deliberately ignored: a constant takes no inputs.
        return val

    return Primitive(name=name, arity=0, operation_fn=_constant)
|
|
|
|
|
|
def Operation(name: str, arity: int, operation_fn: OperationFn) -> Primitive:
    """Build a primitive whose value is computed by ``operation_fn``.

    Args:
        name: Identifier of the operation.
        arity: Number of arguments the operation takes.
        operation_fn: Callable applied to the argument values on evaluation.

    Returns:
        The constructed ``Primitive``. Any ``ValueError`` raised by
        ``Primitive``'s construction-time validation propagates unchanged.
    """
    op = Primitive(name, arity, operation_fn)
    return op
|