k-means Lloyd implementation with PyTorch (not fused)
import torch
import math
import numpy as np


def kmeans_single(
    # NB: best performance might depend on the memory layout of `X` and `centroids`
    # TODO: benchmark and warn or error out if the layout is not adapted
    X,  # (n_samples, n_features)
    sample_weight,  # (n_samples,)
    centroids,  # (n_clusters, n_features)
    # NB: centroids data is overridden during compute
    max_iter,  # int
    tol,  # float
    verbose,  # bool
    max_compute_buffer_bytes=1073741824,  # int (default 1 GiB)
):
    n_samples, n_features = X.shape
    n_clusters = centroids.shape[0]
    compute_dtype = X[-1, -1].cpu().numpy().dtype.type
    compute_dtype_itemsize = np.dtype(compute_dtype).itemsize
    # The computation of nearest centroids is batched, and the size of each batch
    # is set so that the buffer of pairwise distances computed for the batch does
    # not exceed `max_compute_buffer_bytes`.
    (
        assignment_batch_size,
        assignment_n_batches,
        assignment_n_full_batches,
        assignment_last_batch_size,
    ) = _get_batch_properties(
        expected_bytes_per_sample=n_clusters * compute_dtype_itemsize,
        max_compute_buffer_bytes=max_compute_buffer_bytes,
        dataset_n_samples=n_samples,
    )
    # Batching the update of the centroids is also necessary to support non-uniform
    # sample weights.
    (
        update_batch_size,
        update_n_batches,
        update_n_full_batches,
        update_last_batch_size,
    ) = _get_batch_properties(
        expected_bytes_per_sample=n_features * compute_dtype_itemsize,
        max_compute_buffer_bytes=max_compute_buffer_bytes,
        dataset_n_samples=n_samples,
    )
    # Pre-allocate buffers that will be reused across iterations (rather than
    # re-allocated).
    new_centroids = torch.zeros_like(centroids)  # TODO: test memory layouts ?
    weight_in_clusters = torch.zeros(n_clusters, dtype=X.dtype, device=X.device)
    new_weight_in_clusters = torch.zeros_like(weight_in_clusters)
    # The buffers that store the centroid assignment of each sample are
    # over-allocated with `n_clusters` extra values ranging from 0 to
    # `n_clusters - 1`, which are used later on to detect empty clusters with
    # `torch.unique`.
    assignments_idx_extended = torch.empty(
        (n_samples + n_clusters, 1), dtype=torch.int64, device=X.device
    )
    assignments_idx_extended[n_samples:] = torch.arange(
        n_clusters, dtype=assignments_idx_extended.dtype, device=X.device
    ).unsqueeze(1)
    assignments_idx = assignments_idx_extended[:n_samples]
    new_assignments_idx_extended = torch.empty_like(assignments_idx_extended)
    new_assignments_idx_extended[n_samples:] = assignments_idx_extended[n_samples:]
    new_assignments_idx = new_assignments_idx_extended[:n_samples]
    dist_to_nearest_centroid = torch.empty(
        (n_samples, 1), dtype=X.dtype, device=X.device
    )
    dist_to_nearest_centroid_sqz = dist_to_nearest_centroid.squeeze(1)

    n_iteration = 0
    strict_convergence = False
    centroid_shifts_sum = torch.inf
    while (n_iteration < max_iter) and (centroid_shifts_sum > tol):
        # NB: the current implementation of `_min_over_pairwise_distance` is
        # underwhelming because, for each batch, it materializes the pairwise
        # distance matrix in memory before searching for the closest centroid.
        # The IO from writing to and reading from global memory becomes the
        # bottleneck. It could be about 3 times faster (or more ?) if the
        # pairwise distance and the min lookup were fused together so that
        # global memory is not used anymore. That would require a custom
        # low-level implementation (e.g. using triton directly);
        # `torch.compile` doesn't seem to support automatically fusing
        # `torch.cdist` and `torch.min`.
        _min_over_pairwise_distance(
            X,
            centroids,
            assignment_n_batches,
            assignment_n_full_batches,
            assignment_batch_size,
            assignment_last_batch_size,
            # OUT
            dist_to_nearest_centroid,
            new_assignments_idx,
        )
        # ???: should we pass `sorted=False` ?
        unique_clusters, counts = torch.unique(
            new_assignments_idx_extended, return_counts=True
        )
        # A count of 1 means only the sentinel value was seen, i.e. the cluster
        # is empty.
        empty_clusters_list = unique_clusters[counts == 1]

        new_centroids[:] = 0
        new_weight_in_clusters[:] = 0

        # Relocate empty clusters (if any) to the samples that are currently
        # farthest from their assigned centroid.
        if (n_empty_clusters := len(empty_clusters_list)) > 0:
            print("relocation event")
            # ???: should we pass `sorted=False` ?
            samples_far_from_center = torch.topk(
                dist_to_nearest_centroid_sqz, n_empty_clusters
            ).indices
            # Re-assigning those samples is enough: the centroid update below
            # then sets each relocated centroid to its single assigned sample.
            # (Writing `X[samples_far_from_center]` directly into
            # `new_centroids` here would double-count those samples.)
            new_assignments_idx[
                samples_far_from_center
            ] = empty_clusters_list.unsqueeze(1)
            dist_to_nearest_centroid[samples_far_from_center] = 0
        if verbose:
            inertia = (
                sample_weight
                * dist_to_nearest_centroid_sqz
                * dist_to_nearest_centroid_sqz
            ).sum().item()
            print(f"Iteration {n_iteration}, inertia {inertia:5.3e}")

        # Update the centroids.
        # NB: (same comment as for `_min_over_pairwise_distance`)
        # Multiplying by the weights and then using `scatter_add_` could be
        # fused together, again with a 2x - 3x speedup.
        batch_start_idx = batch_end_idx = 0
        for batch_idx in range(update_n_batches):
            if batch_idx == update_n_full_batches:
                batch_end_idx += update_last_batch_size
            else:
                batch_end_idx += update_batch_size
            batch_slice = slice(batch_start_idx, batch_end_idx)
            X_weighted = X[batch_slice] * sample_weight[batch_slice].unsqueeze(1)
            new_centroids.scatter_add_(
                dim=0,
                # NB: expand does not allocate memory, it's like a "repeated view"
                index=new_assignments_idx[batch_slice].expand(-1, n_features),
                src=X_weighted,
            )
            del X_weighted
            # HACK: force synchronization to avoid memory overflow.
            # Similar to torch.cuda.synchronize(X.device) but with device
            # interoperability, for a negligible cost.
            new_centroids[-1, -1].item()
            batch_start_idx += update_batch_size
        new_weight_in_clusters.scatter_add_(
            dim=0, index=new_assignments_idx.squeeze(1), src=sample_weight
        )
        new_centroids /= new_weight_in_clusters.unsqueeze(1)

        centroids, new_centroids = new_centroids, centroids
        assignments_idx, new_assignments_idx = new_assignments_idx, assignments_idx
        assignments_idx_extended, new_assignments_idx_extended = (
            new_assignments_idx_extended, assignments_idx_extended
        )
        n_iteration += 1
        if (n_iteration > 1) and (
            strict_convergence := bool(
                (assignments_idx == new_assignments_idx).all())
        ):
            break

        # Compute the sum of squared centroid shifts in-place (after the swap,
        # `new_centroids` holds the previous centroids).
        new_centroids -= centroids
        new_centroids *= new_centroids
        centroid_shifts_sum = new_centroids.sum().item()

    if verbose:
        converged_at = n_iteration - 1
        # NB: `centroid_shifts_sum == 0` is possible if `tol == 0`
        if strict_convergence or (centroid_shifts_sum == 0):
            print(f"Converged at iteration {converged_at}: strict convergence.")
        elif centroid_shifts_sum <= tol:
            print(
                f"Converged at iteration {converged_at}: center shift "
                f"{centroid_shifts_sum} within tolerance {tol}."
            )
    # TODO: if strict_convergence, this recomputation could be skipped since
    # the assignments did not change.
    _min_over_pairwise_distance(
        X,
        centroids,
        assignment_n_batches,
        assignment_n_full_batches,
        assignment_batch_size,
        assignment_last_batch_size,
        # OUT
        dist_to_nearest_centroid,
        assignments_idx,
    )
    inertia = (
        sample_weight
        * dist_to_nearest_centroid_sqz
        * dist_to_nearest_centroid_sqz
    ).sum().item()
    return assignments_idx.squeeze(1), inertia, centroids, n_iteration
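

# Hedged sketch, not part of the original gist: `kmeans_single` runs Lloyd
# iterations from a single initialization. A common convention (followed e.g.
# by scikit-learn) is to restart from several random initializations and keep
# the run with the lowest inertia. The name `kmeans_best_of_n` and its
# signature are illustrative assumptions.
def kmeans_best_of_n(X, sample_weight, n_clusters, n_init, max_iter, tol, rng):
    best = None
    for _ in range(n_init):
        # Initialize the centroids from `n_clusters` distinct random samples.
        init_idx = torch.randperm(X.shape[0], generator=rng, device=X.device)
        centroids = X[init_idx[:n_clusters]].clone()
        result = kmeans_single(
            X, sample_weight, centroids, max_iter, tol, verbose=False
        )
        # result[1] is the inertia returned by `kmeans_single`.
        if best is None or result[1] < best[1]:
            best = result
    return best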


def _get_batch_properties(
    expected_bytes_per_sample,
    max_compute_buffer_bytes,
    dataset_n_samples,
):
    batch_size = max_compute_buffer_bytes / expected_bytes_per_sample
    if batch_size < 1:
        raise RuntimeError("Buffer size is too small")
    batch_size = min(math.floor(batch_size), dataset_n_samples)
    n_batches = math.ceil(dataset_n_samples / batch_size)
    n_full_batches = n_batches - 1
    last_batch_size = ((dataset_n_samples - 1) % batch_size) + 1
    return batch_size, n_batches, n_full_batches, last_batch_size
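

# Illustrative note (added, not from the original gist): with the defaults in
# `__main__` (float32, i.e. 4-byte items, and 127 clusters), one row of the
# pairwise distance buffer costs 127 * 4 = 508 bytes per sample, so a 1 GiB
# buffer allows assignment batches of floor(1073741824 / 508) = 2113665
# samples. The returned values always satisfy
# `n_full_batches * batch_size + last_batch_size == dataset_n_samples`, e.g.:
#
#     bs, nb, nfb, lbs = _get_batch_properties(508, 1073741824, 500000)
#     # the whole dataset fits in one batch:
#     assert (nb, nfb, bs, lbs) == (1, 0, 500000, 500000)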


def _min_over_pairwise_distance(
    X,  # IN (n_samples, n_features)
    centroids,  # IN (n_clusters, n_features)
    n_batches,  # PARAM int
    n_full_batches,  # PARAM int
    batch_size,  # PARAM int
    last_batch_size,  # PARAM int
    dist_to_nearest_centroid,  # OUT (n_samples, 1)
    assignments_idx,  # OUT (n_samples, 1)
):
    """The result is returned in the `dist_to_nearest_centroid` and
    `assignments_idx` arrays, which are modified in place."""
    batch_start_idx = batch_end_idx = 0
    for batch_idx in range(n_batches):
        if batch_idx == n_full_batches:
            batch_end_idx += last_batch_size
        else:
            batch_end_idx += batch_size
        batch_slice = slice(batch_start_idx, batch_end_idx)
        pairwise_distances = torch.cdist(X[batch_slice], centroids)
        torch.min(
            pairwise_distances,
            dim=1,
            keepdim=True,
            out=(
                dist_to_nearest_centroid[batch_slice],
                assignments_idx[batch_slice],
            ),
        )
        del pairwise_distances
        # HACK: force synchronization to avoid memory overflow.
        # Similar to torch.cuda.synchronize(X.device) but with device
        # interoperability, for a negligible cost.
        assignments_idx[-1, -1].item()
        batch_start_idx += batch_size
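

# Hedged reference sketch (added, not part of the original gist): the batched
# loop above is semantically equivalent to the one-shot version below, which
# materializes the full (n_samples, n_clusters) distance matrix. It is only
# practical when that matrix fits in memory, but it can be handy for testing
# `_min_over_pairwise_distance` on small inputs.
def _min_over_pairwise_distance_reference(X, centroids):
    # Returns a (values, indices) named tuple, both of shape (n_samples, 1).
    return torch.min(torch.cdist(X, centroids), dim=1, keepdim=True)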


if __name__ == "__main__":
    n_samples = 500000  # common sizes: 50000, 50000000
    n_features = 14
    n_clusters = 127
    max_iter = 20
    tol = 0
    verbose = True
    device = "cuda"
    dtype = torch.float32
    seed = 543212345

    rng = torch.Generator(device=device).manual_seed(seed)
    X = torch.rand(
        n_samples, n_features, generator=rng, dtype=dtype, device=device
    )
    centroids = torch.rand(
        n_clusters, n_features, generator=rng, dtype=dtype, device=device
    )
    sample_weight = torch.rand(n_samples, generator=rng, dtype=dtype, device=device)

    kmeans_single(
        X,
        sample_weight,
        centroids,
        max_iter,
        tol,
        verbose,
        max_compute_buffer_bytes=1073741824,
    )
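
    # Optional sanity check against scikit-learn (added sketch, not part of the
    # original gist; assumes scikit-learn is installed, and is skipped
    # otherwise). Passing an explicit `init` array with `n_init=1` makes both
    # implementations start from the same centroids. Note that `kmeans_single`
    # overwrites its `centroids` argument in place, hence the clone, and that
    # scikit-learn scales `tol` relative to the variance of X, so the two runs
    # may stop at different iterations; the inertias should still be close.
    try:
        from sklearn.cluster import KMeans

        centroids_init = torch.rand(
            n_clusters, n_features, generator=rng, dtype=dtype, device=device
        )
        _, inertia, _, _ = kmeans_single(
            X, sample_weight, centroids_init.clone(), max_iter, tol, False
        )
        km = KMeans(
            n_clusters=n_clusters,
            init=centroids_init.cpu().numpy(),
            n_init=1,
            max_iter=max_iter,
            tol=tol,
        ).fit(X.cpu().numpy(), sample_weight=sample_weight.cpu().numpy())
        print(f"torch inertia: {inertia:.5e}, sklearn inertia: {km.inertia_:.5e}")
    except ImportError:
        pass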