"""Source code for copul.wrapper.inv_gen_wrapper."""

import numpy as np
import sympy

from copul.wrapper.sympy_wrapper import SymPyFuncWrapper


class InvGenWrapper(SymPyFuncWrapper):
    """SymPy function wrapper for an inverse generator with boundary handling.

    Wraps ``expr`` (an expression in ``y_symbol``) and short-circuits
    evaluation at boundary values instead of substituting directly:

    * ``y == 0``                        -> 1
    * ``y == oo``                       -> 0
    * ``y`` printing as ``log(2)`` for a copula class named ``Nelsen11`` -> 0
    * non-SymPy numeric ``y`` equal to or greater than a finite
      ``copula._generator_at_0``        -> 0

    ``__call__`` re-wraps results as ``InvGenWrapper`` so the special handling
    survives chained evaluation; ``subs`` returns ``sympy.Float`` constants at
    the boundaries; ``numpy_func`` applies the same rules element-wise.
    """

    def __init__(self, expr, y_symbol, copula_instance):
        """Create the wrapper.

        Parameters
        ----------
        expr
            Expression handed to ``SymPyFuncWrapper``.
        y_symbol
            The free symbol of the inverse generator; ``subs`` only applies
            boundary handling when substituting for this symbol.
        copula_instance
            Owning copula. Its ``theta`` and ``_generator_at_0`` attributes
            are read when present.
        """
        super().__init__(expr)
        self.y_symbol = y_symbol
        self.copula = copula_instance
        # Cache commonly needed values (snapshots taken at construction time;
        # the methods below re-read ``_generator_at_0`` from the copula).
        self.theta_val = getattr(self.copula, "theta", None)
        self.generator_at_0 = getattr(self.copula, "_generator_at_0", None)

    def _is_nelsen11_log2(self, y_val):
        """Return True for the hard-coded Nelsen11 special case ``y = log(2)``."""
        return str(y_val) == "log(2)" and self.copula.__class__.__name__ == "Nelsen11"

    def _at_or_beyond_generator_at_0(self, y_val):
        """Return True if ``y_val`` equals or exceeds ``copula._generator_at_0``.

        Returns False when the copula has no ``_generator_at_0``, when it is
        ``None`` or ``oo``, or when the comparison cannot be decided (e.g. a
        symbolic threshold, for which SymPy's relational ``bool()`` raises
        ``TypeError``) instead of propagating that error.
        """
        if not hasattr(self.copula, "_generator_at_0"):
            return False
        threshold = self.copula._generator_at_0
        if threshold is None or threshold == sympy.oo:
            return False
        try:
            # Equality first, then strict ">", mirroring the original checks
            # (``==`` on SymPy objects is structural, ``>`` is numeric).
            return bool(y_val == threshold) or bool(y_val > threshold)
        except TypeError:
            return False

    def _finite_generator_at_0(self):
        """Return ``copula._generator_at_0`` as a finite float, or None.

        None is returned when the attribute is missing, ``None``, ``oo``,
        not convertible to float (e.g. still symbolic), or non-finite.
        """
        generator_at_0 = getattr(self.copula, "_generator_at_0", None)
        if generator_at_0 is None or generator_at_0 == sympy.oo:
            return None
        try:
            value = float(generator_at_0)
        except (TypeError, ValueError):
            return None
        return value if np.isfinite(value) else None

    def __call__(self, *args, **kwargs):
        """Evaluate the wrapped expression, short-circuiting boundary values.

        Boundary handling applies only when ``y`` is passed as a keyword
        argument; positional arguments go straight to the parent class.

        Returns
        -------
        InvGenWrapper or object
            A constant ``InvGenWrapper`` for boundary values. Otherwise, the
            parent's result re-wrapped as ``InvGenWrapper`` if the parent
            returned a plain ``SymPyFuncWrapper``; any other result is
            returned unchanged.
        """
        if "y" in kwargs:
            y_val = kwargs["y"]
            boundary = None
            # ``== 0.0`` is kept alongside ``== 0``: depending on the SymPy
            # version, a SymPy Float zero may not compare equal to int 0.
            if y_val == 0 or y_val == 0.0:
                boundary = sympy.Float(1.0)
            elif y_val == sympy.oo:
                boundary = sympy.Float(0.0)
            elif self._is_nelsen11_log2(y_val):
                boundary = sympy.Float(0.0)
            elif not isinstance(y_val, sympy.Expr) and self._at_or_beyond_generator_at_0(
                y_val
            ):
                boundary = sympy.Float(0.0)
            if boundary is not None:
                return InvGenWrapper(boundary, self.y_symbol, self.copula)

        result = super().__call__(*args, **kwargs)
        # Wrap the result in InvGenWrapper to preserve special handling.
        if isinstance(result, SymPyFuncWrapper) and not isinstance(result, InvGenWrapper):
            return InvGenWrapper(result.func, self.y_symbol, self.copula)
        return result

    def subs(self, *args, **kwargs):
        """Substitute into the expression, returning constants at boundaries.

        When ``args[0]`` is ``y_symbol`` and ``args[1]`` is a boundary value,
        a ``sympy.Float`` (not a wrapper) is returned directly. For SymPy
        expressions only exact zero / infinity are recognised. All other
        calls are delegated to the parent ``subs``.
        """
        if len(args) >= 2 and args[0] == self.y_symbol:
            y_val = args[1]
            if self._is_nelsen11_log2(y_val):
                return sympy.Float(0.0)

            if not isinstance(y_val, sympy.Expr):
                # Plain (non-SymPy) values: direct numeric comparisons.
                if y_val == 0 or y_val == 0.0:
                    return sympy.Float(1.0)
                if y_val == sympy.oo:
                    return sympy.Float(0.0)
                if self._at_or_beyond_generator_at_0(y_val):
                    return sympy.Float(0.0)
            else:
                # Symbolic values: only zero and infinity are checked safely.
                # ``is_zero`` may be None (undecided), which is falsy.
                if getattr(y_val, "is_zero", None):
                    return sympy.Float(1.0)
                if y_val == sympy.oo:
                    return sympy.Float(0.0)

        return super().subs(*args, **kwargs)

    def numpy_func(self):
        """Return a NumPy-callable version of the inverse generator.

        The returned function accepts a scalar or array ``y`` and maps:
        values close to 0 -> 1.0, +/-inf -> 0.0, values close to or above a
        finite numeric ``_generator_at_0`` -> 0.0; everything else is
        evaluated with the parent's vectorised function. If vectorised
        evaluation raises, elements are evaluated one by one, and elements
        that still fail become 0.0 (best effort). Scalar input returns a
        ``float``, array input returns an ``ndarray``.
        """
        base_func = super().numpy_func()
        # Resolved once here (not per call) and guarded, so a symbolic or
        # otherwise non-numeric threshold disables the check instead of
        # raising on every evaluation.
        threshold = self._finite_generator_at_0()

        def inv_gen_with_edge_cases(y):
            y_arr = np.asarray(y)
            result = np.empty_like(y_arr, dtype=float)

            zero_mask = np.isclose(y_arr, 0)
            inf_mask = np.isinf(y_arr)
            regular_mask = ~(zero_mask | inf_mask)

            if threshold is not None:
                try:
                    critical_mask = np.isclose(y_arr, threshold) | (y_arr > threshold)
                except (TypeError, ValueError):
                    pass  # skip critical-value handling if comparison fails
                else:
                    result[critical_mask] = 0.0
                    regular_mask = regular_mask & ~critical_mask

            # Assigned after the critical values so these take precedence.
            result[zero_mask] = 1.0
            result[inf_mask] = 0.0

            if np.any(regular_mask):
                try:
                    result[regular_mask] = base_func(y_arr[regular_mask])
                except Exception:
                    # Vectorised evaluation failed: evaluate element-wise.
                    # np.ndindex yields full index tuples, so this is correct
                    # for 0-d and multi-dimensional input alike.
                    regular = np.asarray(regular_mask)
                    for idx in np.ndindex(regular.shape):
                        if not regular[idx]:
                            continue
                        try:
                            result[idx] = base_func(y_arr[idx])
                        except Exception:
                            result[idx] = 0.0  # default if all else fails

            return float(result) if np.isscalar(y) else result

        return inv_gen_with_edge_cases

    def __float__(self):
        """Convert to float, mapping known indeterminate forms to 0.0.

        Expressions whose string form contains both ``0**`` and ``/theta``
        (e.g. ``0**(1/theta)``) yield 0.0. If the parent conversion raises
        ``TypeError``/``ValueError``, 0.0 is returned as a fallback.
        """
        expr_str = str(self._func)
        # Special handling for expressions like 0**(1/theta).
        if "0**" in expr_str and "/theta" in expr_str:
            return 0.0
        try:
            return super().__float__()
        except (TypeError, ValueError):
            # NOTE(review): every unconvertible expression maps to 0.0 here,
            # which can mask genuine errors; kept for backward compatibility.
            return 0.0