Source code for copul.wrapper.sympy_wrapper

from typing import Any, Callable, Dict, Optional, Set, Tuple, Union

import numpy as np
import sympy
from sympy.utilities.lambdify import lambdify


class SymPyFuncWrapper:
    """
    A wrapper class for SymPy expressions that provides additional functionality
    and a more intuitive interface for working with symbolic mathematics.

    The wrapper supports substitution by calling the object, differentiation,
    integration, arithmetic with other wrappers / SymPy expressions / plain
    numbers, and conversion to NumPy-callable functions.

    Note:
        ``__eq__`` compares expressions mathematically (via ``sympy.simplify``)
        while ``__hash__`` hashes the expression structurally. Two wrappers that
        compare equal may therefore hash differently if their expressions are
        written in structurally different forms.
    """

    def __init__(self, sympy_func: Union["SymPyFuncWrapper", sympy.Expr, float, int]):
        """
        Initialize a SymPyFuncWrapper with a SymPy expression or numeric value.

        Args:
            sympy_func: A SymPy expression, another SymPyFuncWrapper, or a
                numeric value.

        Raises:
            AssertionError: If the input is not a valid SymPy expression or
                numeric value.
        """
        if isinstance(sympy_func, SymPyFuncWrapper):
            sympy_func = sympy_func.func
        if not isinstance(sympy_func, (sympy.Expr, float, int)):
            # Raised explicitly instead of via `assert` so that the validation
            # still runs under `python -O`; the exception type is unchanged.
            raise AssertionError(
                f"Function must be from sympy, but is {type(sympy_func)}"
            )
        # Convert plain Python numbers to SymPy numbers so that every wrapped
        # object offers the SymPy expression API.
        if isinstance(sympy_func, (float, int)):
            self._func = sympy.Number(sympy_func)
        else:
            self._func = sympy_func

    @staticmethod
    def _unwrap(value):
        """Return the wrapped expression if *value* is a wrapper, else *value*."""
        if isinstance(value, SymPyFuncWrapper):
            return value.func
        return value

    def __str__(self) -> str:
        """Return a string representation of the wrapped expression."""
        return str(self._func)

    def __float__(self) -> float:
        """
        Convert the expression to a float if possible.

        Returns:
            The float value of the expression.

        Raises:
            TypeError: If the expression cannot be converted to a float
                (its numeric evaluation is not a SymPy number).
        """
        if isinstance(self._func, (sympy.Number, float, int)):
            return float(self._func)
        result = self._func.evalf()
        if isinstance(result, sympy.Number):
            return float(result)
        # Previously this fell through and returned None; fail explicitly.
        raise TypeError(f"Cannot convert expression to float: {self._func}")

    def __repr__(self) -> str:
        """
        Return a string representation of the wrapped expression.

        Note:
            For backward compatibility, this returns the representation of the
            underlying SymPy expression, not the wrapper itself.
        """
        return repr(self._func)

    def __call__(self, *args, **kwargs) -> "SymPyFuncWrapper":
        """
        Substitute values for symbols in the expression.

        Args:
            *args: Positional arguments matched to free symbols in
                alphabetical order.
            **kwargs: Keyword arguments matched to free symbol names.

        Returns:
            A new SymPyFuncWrapper with the substitutions applied.

        Raises:
            ValueError: If both args and kwargs are provided, or if the number
                of positional arguments does not match the free symbols.
        """
        vars_, _ = self._prepare_call(args, kwargs)
        return SymPyFuncWrapper(self._func.subs(vars_))

    def _prepare_call(
        self, args: tuple, kwargs: dict
    ) -> Tuple[Dict[sympy.Symbol, Any], Dict[str, Any]]:
        """
        Prepare arguments for substitution in the expression.

        Args:
            args: Positional arguments.
            kwargs: Keyword arguments.

        Returns:
            A tuple of (substitution_dict, processed_kwargs), where
            ``substitution_dict`` maps symbols to values and
            ``processed_kwargs`` maps symbol names to values with ``None``
            entries removed.

        Raises:
            ValueError: If both args and kwargs are provided, or if the
                positional arguments cannot be matched to the free symbols.
        """
        symbols = list(self._func.free_symbols)
        free_symbols = sorted(str(s) for s in symbols)

        # Keyword arguments whose value is None count as "not provided".
        has_real_kwargs = any(v is not None for v in kwargs.values())
        if args and has_real_kwargs:
            raise ValueError("Cannot provide both positional and keyword arguments")

        if args and len(free_symbols) == len(args):
            # Positional values are assigned to symbol names in sorted order.
            kwargs = dict(zip(free_symbols, args))
        elif args:
            # A count mismatch is tolerated only if every positional is None.
            args = tuple(arg for arg in args if arg is not None)
            if args:
                raise ValueError(
                    f"Expected {len(free_symbols)} or 0 positional arguments for "
                    f"expression with {len(free_symbols)} free symbols, got {len(args)}"
                )

        kwargs = {k: v for k, v in kwargs.items() if v is not None}

        # Names that do not correspond to a free symbol are ignored.
        vars_ = {}
        for sym in symbols:
            name = str(sym)
            if name in kwargs:
                vars_[sym] = self._unwrap(kwargs[name])
        return vars_, kwargs

    @property
    def func(self) -> sympy.Expr:
        """Return the underlying SymPy expression."""
        return self._func

    @property
    def free_symbols(self) -> Set[sympy.Symbol]:
        """Return the set of free symbols in the expression."""
        return self._func.free_symbols

    def subs(self, *args, **kwargs) -> "SymPyFuncWrapper":
        """
        Substitute values in the expression.

        This method directly uses SymPy's subs method.

        Returns:
            A new SymPyFuncWrapper with the substitutions applied.
        """
        return SymPyFuncWrapper(self._func.subs(*args, **kwargs))

    def diff(self, *args, **kwargs) -> "SymPyFuncWrapper":
        """
        Differentiate the expression.

        This method directly uses SymPy's diff method.

        Returns:
            A new SymPyFuncWrapper with the differentiated expression.
        """
        return SymPyFuncWrapper(self._func.diff(*args, **kwargs))

    def integrate(self, *args, **kwargs) -> "SymPyFuncWrapper":
        """
        Integrate the expression.

        This method directly uses SymPy's integrate function.

        Returns:
            A new SymPyFuncWrapper with the integrated expression.
        """
        return SymPyFuncWrapper(sympy.integrate(self._func, *args, **kwargs))

    def simplify(self) -> "SymPyFuncWrapper":
        """
        Simplify the expression.

        Returns:
            A new SymPyFuncWrapper with the simplified expression.
        """
        return SymPyFuncWrapper(sympy.simplify(self._func))

    def expand(self) -> "SymPyFuncWrapper":
        """
        Expand the expression.

        Returns:
            A new SymPyFuncWrapper with the expanded expression.
        """
        return SymPyFuncWrapper(self._func.expand())

    def factor(self) -> "SymPyFuncWrapper":
        """
        Factor the expression.

        Returns:
            A new SymPyFuncWrapper with the factored expression.
        """
        return SymPyFuncWrapper(sympy.factor(self._func))

    def has(self, *args, **kwargs) -> bool:
        """
        Check whether the expression contains the given pattern(s).

        This method directly uses SymPy's has method.

        Returns:
            The result of SymPy's ``has`` on the wrapped expression.
        """
        return self._func.has(*args, **kwargs)

    def to_latex(self) -> str:
        """
        Convert the expression to LaTeX.

        Returns:
            A LaTeX string representation of the expression.
        """
        return sympy.latex(self._func)

    def evalf(
        self, n: Optional[int] = None
    ) -> Union[float, sympy.Expr, "SymPyFuncWrapper"]:
        """
        Evaluate the expression numerically.

        Args:
            n: Optional number of significant digits.

        Returns:
            For backward compatibility:
            - a float when the expression has no free symbols and evaluates
              to a SymPy number;
            - otherwise the raw SymPy result of ``evalf``.
        """
        result = self._func.evalf() if n is None else self._func.evalf(n=n)
        if isinstance(result, sympy.Number) and not self._func.free_symbols:
            return float(result)
        return result

    def __eq__(self, other) -> bool:
        """
        Check if this expression equals another expression or value.

        Args:
            other: Another SymPyFuncWrapper, SymPy expression, or numeric value.

        Returns:
            True if the simplified difference of both sides is zero, False
            otherwise (including when the operands cannot be subtracted).
        """
        try:
            return sympy.simplify(self._func - self._unwrap(other)) == 0
        except (TypeError, ValueError):
            return False

    def __ne__(self, other) -> bool:
        """
        Check if this expression does not equal another expression or value.

        Args:
            other: Another SymPyFuncWrapper, SymPy expression, or numeric value.

        Returns:
            The negation of ``self == other``.
        """
        return not self == other

    def isclose(self, other, tolerance: float = 1e-10) -> bool:
        """
        Check if this expression is numerically close to another expression
        or value.

        Args:
            other: Another SymPyFuncWrapper, SymPy expression, or numeric value.
            tolerance: Maximum absolute difference allowed for equality
                (default: 1e-10).

        Returns:
            True if both sides are numerically within ``tolerance``. When either
            side cannot be converted to a float, the symbolic ``==`` comparison
            is used instead.
        """
        try:
            self_val = float(self.evalf())
            if isinstance(other, SymPyFuncWrapper):
                other_val = float(other.evalf())
            else:
                try:
                    other_val = float(other)
                except (TypeError, ValueError):
                    return self == other
            return abs(self_val - other_val) < tolerance
        except (TypeError, ValueError):
            # Fall back to symbolic comparison if numeric comparison fails.
            return self == other

    # Arithmetic operations. The left-hand variants unwrap a wrapper operand;
    # the reflected variants receive a non-wrapper left operand as-is.

    def __add__(self, other) -> "SymPyFuncWrapper":
        """Add another expression or value to this expression."""
        return SymPyFuncWrapper(self.func + self._unwrap(other))

    def __radd__(self, other) -> "SymPyFuncWrapper":
        """Add this expression to another value (right-side operation)."""
        return SymPyFuncWrapper(other + self.func)

    def __sub__(self, other) -> "SymPyFuncWrapper":
        """Subtract another expression or value from this expression."""
        return SymPyFuncWrapper(self.func - self._unwrap(other))

    def __rsub__(self, other) -> "SymPyFuncWrapper":
        """Subtract this expression from another value (right-side operation)."""
        return SymPyFuncWrapper(other - self.func)

    def __mul__(self, other) -> "SymPyFuncWrapper":
        """Multiply this expression by another expression or value."""
        return SymPyFuncWrapper(self.func * self._unwrap(other))

    def __rmul__(self, other) -> "SymPyFuncWrapper":
        """Multiply another value by this expression (right-side operation)."""
        return SymPyFuncWrapper(other * self.func)

    def __truediv__(self, other) -> "SymPyFuncWrapper":
        """Divide this expression by another expression or value."""
        return SymPyFuncWrapper(self.func / self._unwrap(other))

    def __rtruediv__(self, other) -> "SymPyFuncWrapper":
        """Divide another value by this expression (right-side operation)."""
        return SymPyFuncWrapper(other / self.func)

    def __pow__(self, other) -> "SymPyFuncWrapper":
        """Raise this expression to the power of another expression or value."""
        return SymPyFuncWrapper(self.func ** self._unwrap(other))

    def __rpow__(self, other) -> "SymPyFuncWrapper":
        """Raise another value to the power of this expression (right-side operation)."""
        return SymPyFuncWrapper(other**self.func)

    def __neg__(self) -> "SymPyFuncWrapper":
        """Negate this expression."""
        return SymPyFuncWrapper(-self.func)

    def __abs__(self) -> "SymPyFuncWrapper":
        """Return the absolute value of this expression."""
        return SymPyFuncWrapper(sympy.Abs(self.func))

    def __hash__(self) -> int:
        """Return the hash of the underlying expression (structural hash)."""
        return hash(self.func)

    def numpy(self) -> Union[np.ndarray, Callable[..., Any]]:
        """
        Convert the expression to a numpy value or function.

        Returns:
            For an expression without free symbols, a 0-d numpy array holding
            its float value. Otherwise a function produced by ``lambdify`` with
            the "numpy" backend, taking the free symbols as positional
            arguments in name-sorted order.

        Raises:
            ValueError: If the expression cannot be converted.
        """
        try:
            symbols = sorted(self.free_symbols, key=str)
            if not symbols:
                # Constant expression
                return np.array(float(self.evalf()))
            return lambdify(symbols, self.func, "numpy")
        except Exception as e:
            # Deliberately broad: any failure is reported as ValueError, with
            # the original exception chained for debugging.
            raise ValueError(f"Failed to convert to numpy function: {e}") from e

    def _piecewise_numpy_func(self, symbols) -> Callable[..., Any]:
        """
        Build a numpy evaluator for a top-level Piecewise expression.

        Every (expression, condition) branch is lambdified and combined with
        ``np.select``. SymPy stores a catch-all branch with the condition
        ``sympy.true`` (never the Python ``True`` object), so it is handled
        like any other branch: its condition lambdifies to a constant that
        ``np.select`` broadcasts. Points matched by no branch evaluate to 0.

        Args:
            symbols: Free symbols in the positional order of the evaluator.

        Returns:
            A function taking one array-like per symbol.
        """
        branches = [
            (lambdify(symbols, cond, "numpy"), lambdify(symbols, expr, "numpy"))
            for expr, cond in self.func.args
        ]

        def efficient_piecewise(*values):
            # Convert non-array inputs (scalars, lists) to arrays.
            values = [
                v if isinstance(v, np.ndarray) else np.asarray(v) for v in values
            ]
            conds = [cond_func(*values) for cond_func, _ in branches]
            exprs = [expr_func(*values) for _, expr_func in branches]
            fallback = np.zeros(np.broadcast(*values).shape)
            # np.select picks, per element, the first branch whose condition holds.
            return np.select(conds, exprs, fallback)

        return efficient_piecewise

    def numpy_func(self, *args):
        """
        Create a vectorized numpy function that can be called with array inputs.

        When called without arguments, returns a callable function that can
        evaluate the expression with numpy arrays. When called with arguments,
        evaluates the expression directly with those arguments.

        For an expression without free symbols a callable is always returned;
        it yields an array filled with the constant, shaped like the broadcast
        of its inputs (or ``(1,)`` when called without inputs).

        Parameters
        ----------
        *args : array_like, optional
            If provided, input arrays corresponding to the symbols in the
            expression, in name-sorted symbol order.

        Returns
        -------
        callable or numpy.ndarray
            If args are provided, returns the result of evaluating the expression.
            If no args are provided, returns a function that can be called later.

        Raises
        ------
        ValueError
            If args are provided and their number differs from the number of
            free symbols.

        Examples
        --------
        >>> import sympy
        >>> from copul.wrapper.sympy_wrapper import SymPyFuncWrapper
        >>> x, y = sympy.symbols('x y')
        >>> expr = SymPyFuncWrapper(x**2 + y)
        >>> import numpy as np
        >>>
        >>> # Get a function and call it later
        >>> f = expr.numpy_func()
        >>> f(np.array([1, 2]), np.array([3, 4]))
        array([4, 8])
        >>>
        >>> # Or evaluate directly
        >>> expr.numpy_func(np.array([1, 2]), np.array([3, 4]))
        array([4, 8])
        """
        # Get all free symbols in the expression, sorted by name.
        symbols = sorted(self.free_symbols, key=str)

        # Handle constant case
        if not symbols:
            constant_value = float(self.evalf())
            return lambda *args: np.full(
                np.broadcast(*args).shape if args else (1,), constant_value
            )

        if self.func.is_Piecewise:
            func = self._piecewise_numpy_func(symbols)
        else:
            func = lambdify(symbols, self.func, "numpy")

        if not args:
            return func
        if len(symbols) != len(args):
            raise ValueError(f"Expected {len(symbols)} arguments, got {len(args)}")
        return func(*args)