kmeans toolkit
dump.py
import argparse
import datetime
import logging
import math

import numpy as np
import psutil
import psycopg
from pgvector.psycopg import register_vector
from tqdm import tqdm

GB_TO_B = 1024**3
MAX_POINTS_PER_CLUSTER = 256

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
logger.addHandler(handler)


class CustomFormatter(logging.Formatter):
    def formatTime(self, record, datefmt=None):
        t = datetime.datetime.fromtimestamp(record.created)
        return t.strftime("%H:%M:%S")

    def format(self, record):
        # Prefix every log line with the current memory usage.
        mem = psutil.virtual_memory().used / GB_TO_B
        record.msg = f"Mem: {mem:.1f}GB {record.msg}"
        return super().format(record)


formatter = CustomFormatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)


def build_arg_parse():
    parser = argparse.ArgumentParser(description="Dump embeddings to a local file")
    parser.add_argument("-n", "--name", help="table name", required=True)
    parser.add_argument(
        "-u",
        "--db-url",
        help="database URL",
        default="postgresql://postgres:123@localhost:5432/postgres",
    )
    parser.add_argument("-o", "--output", help="output filepath", required=True)
    parser.add_argument("-c", "--column", help="column name", default="embedding")
    parser.add_argument("-d", "--dim", help="dimension of the vector", type=int, required=True)
    parser.add_argument("--lists", help="number of centroids", type=int, required=True)
    return parser


def create_connection(db_url):
    keepalive_kwargs = {
        "keepalives": 1,
        "keepalives_idle": 30,
        "keepalives_interval": 5,
        "keepalives_count": 5,
    }
    conn = psycopg.connect(
        conninfo=db_url,
        autocommit=True,
        **keepalive_kwargs,
    )
    # Read-only session: avoid generating a large WAL (e.g. from audit)
    conn.set_read_only(True)
    register_vector(conn)
    return conn


def extract_rows(conn, name):
    # 100M+ rows: estimate from a 1% sample
    with conn.execute(f"SELECT COUNT(*) FROM {name} TABLESAMPLE SYSTEM (1)") as cursor:
        row = cursor.fetchone()
        count = row[0] * 100
        if count >= 100_000_000:
            return count
    # 10K-100M rows: estimate from a 10% sample
    with conn.execute(f"SELECT COUNT(*) FROM {name} TABLESAMPLE SYSTEM (10)") as cursor:
        row = cursor.fetchone()
        count = row[0] * 10
        if count >= 10_000:
            return count
    # under 10K rows: count exactly
    with conn.execute(f"SELECT COUNT(*) FROM {name}") as cursor:
        row = cursor.fetchone()
        return row[0]


def extract_vectors(conn, sql):
    conn.autocommit = False
    # must use a named server-side cursor so rows are streamed
    # in batches instead of fetched all at once
    with conn.cursor(name="dump_server_cursor") as cursor:
        cursor.itersize = 10000
        cursor.execute(sql)
        for record in cursor:
            yield record[0]
    conn.commit()
    conn.autocommit = True


def pg_to_npy(conn, dim, filepath, name, column, lists):
    select = MAX_POINTS_PER_CLUSTER * lists
    logger.info("Dump begin")
    table_rows = extract_rows(conn, name)
    logger.info(f"Table {name} estimate {table_rows} rows")
    rate = math.ceil(select / table_rows * 100)
    init_rows = min(table_rows, select)
    result = np.zeros((init_rows, dim), dtype=np.float32)
    logger.info(
        f"Init samples: {select} x {dim}, will allocate: {result.nbytes/GB_TO_B:.1f}GB"
    )
    if rate < 100:
        logger.info(f"Dump {select} rows, sample rate {rate}%")
        sql = (
            f"SELECT {column} FROM {name} TABLESAMPLE SYSTEM ({rate}) LIMIT {init_rows}"
        )
    else:
        logger.info("Dump all rows")
        sql = f"SELECT {column} FROM {name} LIMIT {init_rows}"
    vecs = extract_vectors(conn, sql)
    selected = 0
    for i, v in tqdm(enumerate(vecs), total=init_rows):
        result[i, :] = v
        selected = i + 1
    if selected < init_rows:
        logger.warning(
            f"should have {init_rows} records, but only {selected} available; the result is truncated"
        )
        result = result[:selected, :]
    np.save(filepath, result, allow_pickle=False)


if __name__ == "__main__":
    parser = build_arg_parse()
    args = parser.parse_args()
    print(args)
    conn = create_connection(args.db_url)
    pg_to_npy(conn, args.dim, args.output, args.name, args.column, args.lists)
    logger.info("Dump finish")
kmeans.py
import argparse
import datetime
import logging
from pathlib import Path
from time import perf_counter

import faiss
import numpy as np
import psutil
from numpy import linalg as LA

GB_TO_B = 1024**3
DEFAULT_LISTS = 4096
N_ITER = 25
CHUNKS = 10
SEED = 42
MAX_POINTS_PER_CLUSTER = 256

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
logger.addHandler(handler)


class CustomFormatter(logging.Formatter):
    def formatTime(self, record, datefmt=None):
        t = datetime.datetime.fromtimestamp(record.created)
        return t.strftime("%H:%M:%S")

    def format(self, record):
        # Prefix every log line with the current memory usage.
        mem = psutil.virtual_memory().used / GB_TO_B
        record.msg = f"Mem: {mem:.1f}GB {record.msg}"
        return super().format(record)


formatter = CustomFormatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)


def build_arg_parse():
    parser = argparse.ArgumentParser(description="Train K-means centroids")
    parser.add_argument("-i", "--input", help="input filepath", required=True)
    parser.add_argument("-o", "--output", help="output filepath", required=True)
    parser.add_argument(
        "--lists",
        help="number of centroids",
        type=int,
        required=False,
        default=DEFAULT_LISTS,
    )
    parser.add_argument(
        "--niter", help="number of iterations", type=int, default=N_ITER
    )
    parser.add_argument("-d", "--dim", help="dimension of the vector", type=int, required=True)
    parser.add_argument("-m", "--metric", choices=["l2", "cos", "dot"], default="l2")
    parser.add_argument(
        "-g", "--gpu", help="enable GPU for K-means", action="store_true"
    )
    return parser


def kmeans_cluster(
    train,
    k,
    niter,
    metric,
    gpu=False,
):
    n, dim = train.shape
    if not k:
        # heuristic fallback when no cluster count is given  # TODO
        k = np.sqrt(n).astype(int) * 4
    if metric == "cos":
        # normalize rows so cosine similarity reduces to inner product
        train = train / LA.norm(train, axis=1, keepdims=True)
    kmeans = faiss.Kmeans(
        dim, k, gpu=gpu, verbose=True, niter=niter, seed=SEED, spherical=metric != "l2"
    )
    kmeans.train(train)
    return kmeans.centroids


if __name__ == "__main__":
    parser = build_arg_parse()
    args = parser.parse_args()
    logger.info(args)
    data = np.load(args.input, allow_pickle=False)
    rows, cols = data.shape
    logger.info(f"Input shape: {rows} x {cols}")
    start_time = perf_counter()
    centroids = kmeans_cluster(
        data,
        args.lists,
        args.niter,
        args.metric,
        args.gpu,
    )
    logger.info(
        f"K-means (k=({args.lists})): {perf_counter() - start_time:.2f}s"
    )
    np.save(Path(args.output), centroids, allow_pickle=False)
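kmeans.py only writes the centroids; nothing in this gist consumes them. As a minimal sketch of one possible next step, assuming the l2 metric and the data.npy output from the run below, the centroids can be loaded back for nearest-centroid assignment with a flat faiss index (the query array here is random stand-in data, not part of the gist):

import faiss
import numpy as np

# Hypothetical consumer: assign each vector to its nearest trained centroid.
centroids = np.load("data.npy", allow_pickle=False)  # output of kmeans.py
index = faiss.IndexFlatL2(centroids.shape[1])        # assumes -m l2
index.add(centroids)
queries = np.random.rand(5, centroids.shape[1]).astype(np.float32)
_, assign = index.search(queries, 1)                 # nearest centroid id per row
print(assign.ravel())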
The new script dumps laion-5m in only 6 minutes:
$ python dump.py -n laion -o laion.npy --lists 8192 -u postgresql://postgres:123@localhost:5432/postgres -d 768
Namespace(name='laion', db_url='postgresql://postgres:123@localhost:5432/postgres', output='laion.npy', column='embedding', dim=768, lists=8192)
14:26:41 - INFO - Mem: 0.4GB Dump begin
14:26:41 - INFO - Mem: 0.4GB Init samples: 2097152 x 768, will allocate: 6.0GB
14:26:41 - INFO - Mem: 0.4GB Find 4999992 rows, collect 2097152 rows, sample rate 42%
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2097152/2097152 [05:43<00:00, 6098.54it/s]
14:32:42 - INFO - Mem: 0.6GB Dump finish
$ python kmeans.py -i laion.npy -o data.npy --lists 8192 -d 768 -m dot
14:42:13 - INFO - Mem: 0.4GB Namespace(input='laion.npy', output='data.npy', lists=8192, niter=1, dim=768, metric='dot', gpu=False)
14:42:24 - INFO - Mem: 6.4GB input shape: 2097151 x 768
Clustering 2097151 points in 768D to 8192 clusters, redo 1 times, 1 iterations
Preprocessing in 1.07 s
Iteration 0 (92.44 s, search 92.01 s): objective=1.52379e+06 imbalance=2.546 nsplit=32
14:43:57 - INFO - Mem: 6.4GB K-means (k=(8192)): 93.54s
Files
- dump.py: dump from a postgres table into a numpy file (.npy)
- kmeans.py: use the numpy file (.npy) for faiss kmeans
- hd2npy.py: for test, migrate xxx.hdf5 into a numpy file (a sketch is given below)
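hd2npy.py itself is not included in this page. A minimal sketch of what it might look like, assuming the ann-benchmarks HDF5 layout where the vectors live under a "train" dataset key (the key name and the h5py dependency are assumptions; it writes .npy so the output can be loaded by kmeans.py):

import argparse

import h5py
import numpy as np

# Hypothetical reconstruction of hd2npy.py: copy vectors from an HDF5 file
# into a flat float32 array that kmeans.py can np.load. The "train" dataset
# key follows the ann-benchmarks convention and is an assumption.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Migrate xxx.hdf5 into a numpy file")
    parser.add_argument("-i", "--input", help="input .hdf5 filepath", required=True)
    parser.add_argument("-o", "--output", help="output filepath", required=True)
    args = parser.parse_args()
    with h5py.File(args.input, "r") as f:
        data = np.asarray(f["train"], dtype=np.float32)
    np.save(args.output, data, allow_pickle=False)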
Breaking changes
- -d for the dimension of the vector is now required by kmeans.py
- psutil is now used to monitor memory usage