# Source code for copul.wrapper.cdi_wrapper

import sympy
from copul.wrapper.conditional_wrapper import ConditionalWrapper
from copul.wrapper.sympy_wrapper import SymPyFuncWrapper


class CDiWrapper(ConditionalWrapper):
    """
    General wrapper for the conditional distribution when conditioning on the
    i-th variable.

    Handles the multivariate case (u1, u2, ..., un) as well as expressions
    written in the bivariate (u, v) naming convention.

    Boundary conditions, checked in this order before any substitution:
    - If any variable other than ui is 0, returns 0 (a copula vanishes as soon
      as one argument is 0, and therefore so does its partial derivative in ui).
    - If ui = 0, returns 0; if ui = 1, returns 1.
    - In the bivariate (u, v) naming, if the non-conditioning variable is 0,
      returns 0; if it is 1, returns 1.
    - Otherwise (including indexed variables equal to 1) the supplied values
      are substituted into the expression.
    """

    def __init__(self, func, i):
        """
        Initialize with a sympy expression and the index for conditioning.

        Parameters
        ----------
        func : sympy.Expr
            The symbolic expression representing the conditional distribution.
        i : int
            The index of the variable being conditioned on (1-based indexing).
        """
        # Stored separately so that __call__ can build new instances.
        self._index = i
        super().__init__(func, condition_index=i)

    @staticmethod
    def _indexed_variables(names):
        """
        Return the sorted, de-duplicated indices j of all names shaped ``u<j>``.

        Names that are not strings, the plain name ``"u"`` and names with a
        non-numeric suffix are ignored instead of raising ``ValueError``.
        """
        return sorted(
            {
                int(name[1:])
                for name in names
                if isinstance(name, str)
                and name.startswith("u")
                and name[1:].isdecimal()
            }
        )

    @staticmethod
    def _lookup_value(name, u_symbols, vars_dict, kwargs):
        """
        Return the value supplied for the variable ``name``, or None.

        A keyword argument takes precedence; otherwise ``vars_dict`` is
        consulted with the sympy symbol for ``name`` (taken from ``u_symbols``
        when the expression contains it). Returns None when ``name`` is
        neither in the expression nor among the keyword arguments.
        """
        if name in kwargs:
            return kwargs[name]
        if name not in u_symbols:
            return None
        symbol = u_symbols[name]
        if symbol in vars_dict:
            return vars_dict[symbol]
        return None

    def _check_boundary_conditions(self, u_symbols, vars_dict, kwargs):
        """
        Check boundary conditions for conditioning on the i-th variable.

        Parameters
        ----------
        u_symbols : dict
            Dictionary of u-symbols in the expression (name -> sympy symbol).
        vars_dict : dict
            Dictionary of variable substitutions (sympy symbol -> value).
        kwargs : dict
            Keyword arguments for substitution (name -> value).

        Returns
        -------
        SymPyFuncWrapper or None
            A wrapper with the boundary value, or None if no boundary
            condition applies.
        """
        i = self.condition_index

        # Other indexed variables first. Every index present in the
        # expression or in kwargs is checked (the previous
        # range(1, max_dim) skipped the highest one).
        for j in self._indexed_variables([*u_symbols, *kwargs]):
            if j == i:
                continue
            if self._lookup_value(f"u{j}", u_symbols, vars_dict, kwargs) == 0:
                return SymPyFuncWrapper(sympy.S.Zero)
            # A value of 1 is deliberately not short-circuited. The derivative
            # equals 1 only when *all* other arguments are 1, which cannot be
            # decided from a possibly partially-evaluated expression, so the
            # expression itself is evaluated instead.

        # Now the conditioning variable itself.
        # NOTE(review): this is the documented contract, but dC/dui at ui = 0
        # or 1 is not 0/1 for every copula (e.g. C = u*v gives dC/du = v);
        # confirm intent before relying on it.
        ui_val = self._lookup_value(f"u{i}", u_symbols, vars_dict, kwargs)
        if ui_val == 0:
            return SymPyFuncWrapper(sympy.S.Zero)
        if ui_val == 1:
            return SymPyFuncWrapper(sympy.S.One)

        # Bivariate (u, v) naming: check the non-conditioning variable.
        other_name = {1: "v", 2: "u"}.get(i)
        if other_name is not None and "u" in u_symbols and "v" in u_symbols:
            other_val = self._lookup_value(other_name, u_symbols, vars_dict, kwargs)
            if other_val == 0:
                return SymPyFuncWrapper(sympy.S.Zero)
            if other_val == 1:
                return SymPyFuncWrapper(sympy.S.One)

        return None

    def _map_positional_args(self, args, u_symbols):
        """
        Translate one or two positional arguments into keyword arguments.

        Returns the pair ``(args, kwargs)`` to hand to ``_prepare_call``.
        Three or more positional arguments are passed through unchanged.
        """
        if len(args) == 2:
            # Expressions written purely in u1/u2 must not be remapped to
            # u/v, otherwise nothing would be substituted.
            uses_uv = "u" in u_symbols or "v" in u_symbols
            indices = set(self._indexed_variables(u_symbols))
            if not uses_uv and indices and indices <= {1, 2}:
                return (), {"u1": args[0], "u2": args[1]}
            # Standard bivariate signature (u, v), independent of how many
            # free symbols the expression actually contains.
            return (), {"u": args[0], "v": args[1]}
        if len(args) == 1:
            # Prefer 'u', then 'v', else the conditioning variable ui.
            if "u" in u_symbols:
                return (), {"u": args[0]}
            if "v" in u_symbols:
                return (), {"v": args[0]}
            return (), {f"u{self.condition_index}": args[0]}
        return args, {}

    def __call__(self, *args, **kwargs):
        """
        Evaluate the conditional distribution with the given arguments.

        Parameters
        ----------
        *args, **kwargs
            Arguments to substitute into the expression.

        Returns
        -------
        CDiWrapper or SymPyFuncWrapper
            A new wrapper with the substituted expression, or a constant
            wrapper when a boundary condition applies.
        """
        # Names of the u-symbols present in the sympy expression.
        u_symbols = self._get_u_symbols()

        # Purely positional calls are mapped to named variables up front.
        if args and not kwargs:
            args, kwargs = self._map_positional_args(args, u_symbols)

        vars_, kwargs = self._prepare_call(args, kwargs)

        boundary_result = self._check_boundary_conditions(u_symbols, vars_, kwargs)
        if boundary_result is not None:
            return boundary_result

        func = self._func.subs(vars_)
        # Return a new wrapper of the same type with the same index.
        return CDiWrapper(func, self._index)