Merge pull request #24 from TutteInstitute/sample_weights
Add support for sample weights
lmcinnes authored Oct 1, 2024
Commit caf94c6 (2 parents: 163e167 + 30e1c97)
Showing 4 changed files with 182 additions and 54 deletions.
fast_hdbscan/boruvka.py (38 changes: 30 additions & 8 deletions)
```diff
@@ -247,22 +247,44 @@ def initialize_boruvka_from_knn(knn_indices, knn_distances, core_distances, disj
     return result[:result_idx]
 
 
-def parallel_boruvka(tree, min_samples=10):
+@numba.njit(parallel=True)
+def sample_weight_core_distance(distances, neighbors, sample_weights, min_samples):
+    core_distances = np.zeros(distances.shape[0], dtype=np.float32)
+    for i in numba.prange(distances.shape[0]):
+        total_weight = 0.0
+        j = 0
+        while total_weight < min_samples and j < neighbors.shape[1]:
+            total_weight += sample_weights[neighbors[i, j]]
+            j += 1
+
+        core_distances[i] = distances[i, j - 1]
+
+    return core_distances
+
+
+def parallel_boruvka(tree, min_samples=10, sample_weights=None):
     components_disjoint_set = ds_rank_create(tree.data.shape[0])
     point_components = np.arange(tree.data.shape[0])
     node_components = np.full(tree.node_data.shape[0], -1)
     n_components = point_components.shape[0]
 
-    if min_samples > 1:
-        distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
-        core_distances = distances.T[-1]
+    if sample_weights is not None:
+        mean_sample_weight = np.mean(sample_weights)
+        expected_neighbors = min_samples / mean_sample_weight
+        distances, neighbors = parallel_tree_query(tree, tree.data, k=int(2 * expected_neighbors))
+        core_distances = sample_weight_core_distance(distances, neighbors, sample_weights, min_samples)
         edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
         update_component_vectors(tree, components_disjoint_set, node_components, point_components)
-    else:
-        core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
-        distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
-        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
-        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
+    elif min_samples > 1:
+        distances, neighbors = parallel_tree_query(tree, tree.data, k=min_samples + 1, output_rdist=True)
+        core_distances = distances.T[-1]
+        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
+    else:
+        core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
+        distances, neighbors = parallel_tree_query(tree, tree.data, k=2)
+        edges = initialize_boruvka_from_knn(neighbors, distances, core_distances, components_disjoint_set)
+        update_component_vectors(tree, components_disjoint_set, node_components, point_components)
 
     while n_components > 1:
         candidate_distances, candidate_indices = boruvka_tree_query(tree, node_components, point_components,
```
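Note: the new code path changes what "core distance" means when points carry weights. Instead of the distance to the min_samples-th nearest neighbor, it is the distance to the first neighbor at which the cumulative neighbor weight reaches min_samples; the query size k = int(2 * min_samples / mean_sample_weight) is a heuristic fetching roughly twice the expected number of neighbors needed. A plain-NumPy restatement of the same rule, with toy arrays invented for illustration (the committed version runs under numba.njit with prange):

```python
import numpy as np

def weighted_core_distance(distances, neighbors, sample_weights, min_samples):
    # For each point, walk its neighbors in order of increasing distance
    # until their cumulative sample weight reaches min_samples; the
    # distance to that neighbor is the point's core distance.
    core = np.zeros(distances.shape[0], dtype=np.float32)
    for i in range(distances.shape[0]):
        total_weight, j = 0.0, 0
        while total_weight < min_samples and j < neighbors.shape[1]:
            total_weight += sample_weights[neighbors[i, j]]
            j += 1
        core[i] = distances[i, j - 1]
    return core

# Toy data: 3 points, each listing its 3 nearest neighbors (self first).
distances = np.array([[0.0, 1.0, 2.0],
                      [0.0, 1.5, 2.5],
                      [0.0, 0.5, 3.0]], dtype=np.float32)
neighbors = np.array([[0, 1, 2],
                      [1, 0, 2],
                      [2, 0, 1]])
weights = np.array([1.0, 0.5, 2.0], dtype=np.float32)

print(weighted_core_distance(distances, neighbors, weights, min_samples=2))
# [2.  2.5 0. ]  -- point 2 alone carries weight 2.0, so its core distance is 0
```

Note that the point itself counts toward the cumulative weight, since each point is its own nearest neighbor in the k-NN query.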
fast_hdbscan/cluster_trees.py (62 changes: 54 additions & 8 deletions)
```diff
@@ -22,6 +22,16 @@ def create_linkage_merge_data(base_size):
     return LinkageMergeData(parent, size, next_parent)
 
 
+@numba.njit()
+def create_linkage_merge_data_w_sample_weights(sample_weights):
+    base_size = sample_weights.shape[0]
+    parent = np.full(2 * base_size - 1, -1, dtype=np.intp)
+    size = np.concatenate((sample_weights, np.zeros(base_size - 1, dtype=np.float32)))
+    next_parent = np.array([base_size], dtype=np.intp)
+
+    return LinkageMergeData(parent, size, next_parent)
+
+
 @numba.njit()
 def linkage_merge_find(linkage_merge, node):
     relabel = node
```
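Note: relative to create_linkage_merge_data, the only change is how the union-find size array is seeded: each leaf starts at its sample weight rather than 1, and the base_size - 1 slots reserved for future merge nodes start at 0. A quick illustration of the resulting size array (toy weights invented here):

```python
import numpy as np

sample_weights = np.array([0.5, 2.0, 1.0], dtype=np.float32)
base_size = sample_weights.shape[0]

# Leaves carry their weights; the base_size - 1 future merge nodes start empty.
size = np.concatenate((sample_weights, np.zeros(base_size - 1, dtype=np.float32)))
print(size)  # [0.5 2.  1.  0.  0. ]
```

Every cluster "size" downstream is therefore a summed weight, which is why the related arrays switch from integer to float32 dtypes.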
```diff
@@ -78,6 +88,36 @@ def mst_to_linkage_tree(sorted_mst):
     return result
 
 
+@numba.njit()
+def mst_to_linkage_tree_w_sample_weights(sorted_mst, sample_weights):
+    result = np.empty((sorted_mst.shape[0], sorted_mst.shape[1] + 1))
+
+    linkage_merge = create_linkage_merge_data_w_sample_weights(sample_weights)
+
+    for index in range(sorted_mst.shape[0]):
+
+        left = np.intp(sorted_mst[index, 0])
+        right = np.intp(sorted_mst[index, 1])
+        delta = sorted_mst[index, 2]
+
+        left_component = linkage_merge_find(linkage_merge, left)
+        right_component = linkage_merge_find(linkage_merge, right)
+
+        if left_component > right_component:
+            result[index][0] = left_component
+            result[index][1] = right_component
+        else:
+            result[index][1] = left_component
+            result[index][0] = right_component
+
+        result[index][2] = delta
+        result[index][3] = linkage_merge.size[left_component] + linkage_merge.size[right_component]
+
+        linkage_merge_join(linkage_merge, left_component, right_component)
+
+    return result
+
+
 @numba.njit()
 def bfs_from_hierarchy(hierarchy, bfs_root, num_points):
     to_process = [bfs_root]
```
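Note: mst_to_linkage_tree_w_sample_weights mirrors mst_to_linkage_tree line for line; the difference enters only through the weighted LinkageMergeData, so column 3 of each linkage row records combined sample weight rather than point count. A worked single-edge example (values invented for illustration):

```python
import numpy as np

# One MST edge: points 0 and 1 merge at distance 0.7, carrying weights 0.5 and 2.0.
sample_weights = np.array([0.5, 2.0], dtype=np.float32)
left_component, right_component, delta = 0, 1, 0.7

# The row stores (higher component id, lower component id, distance, merged size);
# the merged size is now the summed weight 2.5 instead of the point count 2.
row = [max(left_component, right_component),
       min(left_component, right_component),
       delta,
       float(sample_weights[left_component] + sample_weights[right_component])]
print(row)  # [1, 0, 0.7, 2.5]
```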
```diff
@@ -121,7 +161,7 @@ def eliminate_branch(branch_node, parent_node, lambda_value, parents, children,
 
 
 @numba.njit(fastmath=True)
-def condense_tree(hierarchy, min_cluster_size=10):
+def condense_tree(hierarchy, min_cluster_size=10, sample_weights=None):
     root = 2 * hierarchy.shape[0]
     num_points = hierarchy.shape[0] + 1
     next_label = num_points + 1
@@ -134,10 +174,13 @@ def condense_tree(hierarchy, min_cluster_size=10):
     parents = np.ones(root, dtype=np.int64)
     children = np.empty(root, dtype=np.int64)
     lambdas = np.empty(root, dtype=np.float32)
-    sizes = np.ones(root, dtype=np.int64)
+    sizes = np.ones(root, dtype=np.float32)
 
     ignore = np.zeros(root + 1, dtype=np.bool_)  # 'bool' is no longer an attribute of 'numpy'
 
+    if sample_weights is None:
+        sample_weights = np.ones(num_points, dtype=np.float32)
+
     idx = 0
 
     for node in node_list:
```
```diff
@@ -153,8 +196,8 @@ def condense_tree(hierarchy, min_cluster_size=10):
         else:
             lambda_value = np.inf
 
-        left_count = np.int64(hierarchy[left - num_points, 3]) if left >= num_points else 1
-        right_count = np.int64(hierarchy[right - num_points, 3]) if right >= num_points else 1
+        left_count = np.float32(hierarchy[left - num_points, 3]) if left >= num_points else sample_weights[left]
+        right_count = np.float32(hierarchy[right - num_points, 3]) if right >= num_points else sample_weights[right]
 
         # The logic here is in a strange order, but it has non-trivial performance gains ...
         # The most common case by far is a singleton on the left; and cluster on the right take care of this separately
```
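Note: with these two lines, every "count" in condense_tree becomes a weight: a singleton child contributes its sample weight rather than 1, and a cluster child contributes the weighted size stored in column 3 of the linkage tree. The practical effect is that min_cluster_size is compared against total weight, so a few heavy points can form a valid cluster. A one-line arithmetic check (weights invented for illustration):

```python
import numpy as np

min_cluster_size = 10
weights = np.full(5, 3.0, dtype=np.float32)

# Unweighted, 5 points < 10 and the group would be pruned as too small;
# weighted, the same 5 points carry total weight 15.0 and survive.
print(weights.sum() >= min_cluster_size)  # True
```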
```diff
@@ -391,7 +434,7 @@ def extract_clusters_bcubed(condensed_tree, cluster_tree, label_indices, allow_v
 
 @numba.njit()
 def score_condensed_tree_nodes(condensed_tree):
-    result = {0: 0.0 for i in range(0)}
+    result = {0: np.float32(0.0) for i in range(0)}
 
     for i in range(condensed_tree.parent.shape[0]):
         parent = condensed_tree.parent[i]
```
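Note: `{0: np.float32(0.0) for i in range(0)}` is the usual numba trick for building an empty typed dict: the zero-iteration comprehension inserts nothing but lets type inference fix the key/value types from the literals. Wrapping the value in np.float32 keeps the stability scores in the same precision as the new float32 sizes. A standalone sketch of the idiom (independent of this codebase):

```python
import numba
import numpy as np

@numba.njit()
def empty_float32_dict():
    # Zero iterations: nothing is inserted, but numba infers
    # an int64 -> float32 typed dict from the literals.
    return {0: np.float32(0.0) for i in range(0)}

d = empty_float32_dict()
d[7] = 1.5
print(d)  # {7: 1.5}
```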
```diff
@@ -559,13 +602,16 @@ def get_cluster_labelling_at_cut(linkage_tree, cut, min_cluster_size):
 def get_cluster_label_vector(
     tree,
     clusters,
-    cluster_selection_epsilon
+    cluster_selection_epsilon,
+    n_samples,
 ):
+    if len(tree.parent) == 0:
+        return np.full(n_samples, -1, dtype=np.intp)
     root_cluster = tree.parent.min()
-    result = np.empty(root_cluster, dtype=np.intp)
+    result = np.full(n_samples, -1, dtype=np.intp)
     cluster_label_map = {c: n for n, c in enumerate(np.sort(clusters))}
 
-    disjoint_set = ds_rank_create(tree.parent.max() + 1)
+    disjoint_set = ds_rank_create(max(tree.parent.max() + 1, tree.child.max() + 1))
     clusters = set(clusters)
 
     for n in range(tree.parent.shape[0]):
```
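Note: the get_cluster_label_vector changes are defensive bookkeeping for the weighted path: labels are pre-filled with -1 across all n_samples (so points outside the tree stay noise), an empty condensed tree short-circuits to all-noise, and the disjoint set is sized to cover both parent and child ids. Taken together, these changes thread per-point weights from the k-NN stage down to the final labelling. A hedged end-to-end sketch, assuming the public estimator exposes the feature via a scikit-learn-style sample_weight argument to fit (the estimator-facing change would be in one of the files not shown here):

```python
import numpy as np
from fast_hdbscan import HDBSCAN

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
# E.g. deduplicated data, where each row stands in for `weight` original rows.
weights = rng.uniform(0.5, 3.0, size=1000)

clusterer = HDBSCAN(min_samples=10, min_cluster_size=25)
clusterer.fit(X, sample_weight=weights)  # assumed signature
print(np.unique(clusterer.labels_))
```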
(The remaining two changed files are not shown.)
