Skip to content

Commit

Permalink
replace interp2d with RegularGridInterpolator
Browse files Browse the repository at this point in the history
  • Loading branch information
AaronDJohnson committed Dec 8, 2024
1 parent 2702157 commit a69c8d0
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions enterprise_extensions/empirical_distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
sklearn_available=True
except ModuleNotFoundError:
sklearn_available=False
from scipy.interpolate import interp1d, interp2d
from scipy.interpolate import interp1d, interp2d, RegularGridInterpolator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -175,7 +175,8 @@ def __init__(self, param_names, samples, minvals=None, maxvals=None, bandwidth=0
self._Nbins = [yvals.size for ii in range(xvals.size)]
scores = np.array([self.kde.score(np.array([xvals[ii], yvals[jj]]).reshape((1, 2))) for ii in range(xvals.size) for jj in range(yvals.size)])
# interpolate within prior
self._logpdf = interp2d(xvals, yvals, scores, kind='linear', fill_value=-1000)
self._logpdf = RegularGridInterpolator((xvals, yvals), scores, method='linear', bounds_error=False, fill_value=-1000)


def draw(self):
params = self.kde.sample(1).T
Expand Down

0 comments on commit a69c8d0

Please sign in to comment.