Skip to content

Commit

Permalink
ADD PyData Global 2021
Browse files: browse the repository at this point in the history
  • Loading branch information
jmschrei committed Oct 28, 2021
1 parent cd8520a commit 5916e91
Show file tree
Hide file tree
Showing 5 changed files with 394 additions and 165 deletions.
7 changes: 2 additions & 5 deletions apricot/functions/featureBased.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from numba import njit
from numba import prange
import torch

@njit('float64[:](float64[:])', fastmath=True)
def sigmoid(X):
Expand All @@ -35,7 +36,6 @@ def calculate_gains_(X, gains, current_values, idxs):

return calculate_gains_


def calculate_gains_sparse(func, dtypes, parallel, fastmath, cache):
@njit(dtypes, parallel=parallel, fastmath=fastmath, cache=cache)
def calculate_gains_sparse_(X_data, X_indices, X_indptr, gains,
Expand Down Expand Up @@ -313,10 +313,7 @@ def _initialize(self, X):
calculate_gains_ = calculate_gains_sparse if self.sparse else calculate_gains
dtypes_ = sparse_dtypes if self.sparse else dtypes

if self.optimizer in (LazyGreedy, ApproximateLazyGreedy):
self.calculate_gains_ = calculate_gains_(self.concave_func,
dtypes_, False, True, False)
elif self.optimizer in ('lazy', 'approimate-lazy'):
if self.optimizer in (LazyGreedy, ApproximateLazyGreedy, 'lazy', 'approximate-lazy'):
self.calculate_gains_ = calculate_gains_(self.concave_func,
dtypes_, False, True, False)
else:
Expand Down
32 changes: 19 additions & 13 deletions apricot/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,26 +114,29 @@ def __init__(self, function=None, random_state=None, n_jobs=None,

def select(self, X, k, sample_cost=None):
cost = 0.0
if sample_cost is None:
sample_cost = numpy.ones(X.shape[0], dtype='float64')

while cost < k:
gains = self.function._calculate_gains(X) / sample_cost[self.function.idxs]
idxs = numpy.lexsort((numpy.arange(gains.shape[0]), -gains))
gains = self.function._calculate_gains(X)

for idx in idxs:
if sample_cost is None:
idx = numpy.argmax(gains)
best_idx = self.function.idxs[idx]
if cost + sample_cost[best_idx] <= k:
break
idx_cost = 1
else:
break
idxs = numpy.lexsort((numpy.arange(gains.shape[0]), -gains / sample_cost[self.function.idxs]))
for idx in idxs:
best_idx = self.function.idxs[idx]
idx_cost = sample_cost[best_idx]
if cost + idx_cost <= k:
break
else:
break

cost += sample_cost[best_idx]
gain = gains[idx] * sample_cost[best_idx]
cost += idx_cost
gain = gains[idx]
self.function._select_next(X[best_idx], gain, best_idx)

if self.verbose == True:
self.function.pbar.update(round(sample_cost[best_idx], 1))
self.function.pbar.update(round(idx_cost, 1))


class LazyGreedy(BaseOptimizer):
Expand Down Expand Up @@ -226,6 +229,7 @@ def select(self, X, k, sample_cost=None):
return

prev_gain, idx = self.pq.pop()
#prev_gain, idx = self.pq.peek()
prev_gain = -prev_gain

if cost + sample_cost[idx] > k:
Expand All @@ -236,8 +240,10 @@ def select(self, X, k, sample_cost=None):

idxs = numpy.array([idx])
gain = self.function._calculate_gains(X, idxs)[0] / sample_cost[idx]
self.pq.add(idx, -gain)

#self.pq.swap(idx, -gain)
self.pq.add(idx, -gain)

if gain > best_gain:
best_gain = gain
best_idx = idx
Expand Down
75 changes: 46 additions & 29 deletions apricot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from heapq import heappush
from heapq import heappop
from heapq import heapify
from heapq import heapreplace

from scipy.sparse import csr_matrix

Expand Down Expand Up @@ -49,14 +50,12 @@ class PriorityQueue(object):

def __init__(self, items=None, weights=None):
self.counter = itertools.count()
self.lookup = {}
self.pq = []

if items is not None and weights is not None:
for item, weight in zip(items, weights):
entry = [weight, next(self.counter), item]
self.pq.append(entry)
self.lookup[item] = entry

heapify(self.pq)

Expand All @@ -83,38 +82,36 @@ def add(self, item, weight):
None
"""

if item in self.lookup:
self.remove(item)
#if item in self.lookup:
# self.remove(item)

entry = [weight, next(self.counter), item]
self.lookup[item] = entry
heappush(self.pq, entry)

def pop(self):
"""Pop the highest priority element from the queue. Runtime is O(log n).
def remove(self, item):
"""Remove an element from the queue.
This will remove the highest priority element from the queue. If there
are no elements left in the queue it will raise an error.
This is not popping the highest priority item, rather it will remove
an element that is present in the queue. If one attempts to remove an
item that is not present in the queue the function will error.
Parameters
----------
item : object
The object to be removed from the queue.
None
Returns
-------
None
weight : double
The weight of the element as passed in in the `add` method
item : object
The item that was passed in in the `add` method
"""

entry = self.lookup.pop(item)
entry[-1] = "DELETED"

def pop(self):
"""Pop the highest priority element from the queue. Runtime is O(log n).
weight, _, item = heappop(self.pq)
return weight, item

This will remove the highest priority element from the queue. If there
are no elements left in the queue it will raise an error.
def peek(self):
"""Peek at the first element in the priority queue.
Parameters
----------
Expand All @@ -126,16 +123,36 @@ def pop(self):
The weight of the element as passed in in the `add` method
item : object
The item that was passed in during the `add` method
The item that was passed in in the `add` method
"""

while self.pq:
weight, _, item = heappop(self.pq)
if item != "DELETED":
del self.lookup[item]
return weight, item

raise KeyError("No elements left in the priority queue.")
return self.pq[0][0], self.pq[0][2]

def swap(self, item, weight):
    """Replace the front of the queue with a new entry in a single step.

    The entry currently at the front of the heap (the one with the lowest
    weight, whatever item it holds) is discarded and a fresh entry for
    `item` is inserted. This covers the pattern of removing an element and
    re-adding it with an updated gain without a separate pop and add.

    Parameters
    ----------
    item : object
        The object to store in the queue.

    weight : double
        The priority of the item. Lower weights are dequeued first, so
        negate the weight if a larger value should mean a higher priority.

    Returns
    -------
    None

    Notes
    -----
    The underlying `heapreplace` call raises an IndexError when the queue
    is empty.
    """

    # The counter value sits between the weight and the item so that
    # entries with equal weights are ordered by insertion.
    tie_breaker = next(self.counter)
    heapreplace(self.pq, [weight, tie_breaker, item])

def check_random_state(seed):
"""Turn seed into a np.random.RandomState instance.
Expand Down
Binary file added slides/apricot PyData Global 2021.pdf
Binary file not shown.
Loading

0 comments on commit 5916e91

Please sign in to comment.