@cutecutecat
Last active January 23, 2025 07:12
kmeans toolkit
dump.py

import argparse
import datetime
import logging
import math

import psutil
import numpy as np
import psycopg
from tqdm import tqdm
from pgvector.psycopg import register_vector

GB_TO_B = 1024**3
MAX_POINTS_PER_CLUSTER = 256

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
logger.addHandler(handler)

class CustomFormatter(logging.Formatter):
    def formatTime(self, record, datefmt=None):
        t = datetime.datetime.fromtimestamp(record.created)
        return t.strftime("%H:%M:%S")

    def format(self, record):
        mem = psutil.virtual_memory().used / GB_TO_B
        record.msg = f"Mem: {mem:.1f}GB {record.msg}"
        return super().format(record)

formatter = CustomFormatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
def build_arg_parse():
    parser = argparse.ArgumentParser(description="Dump embeddings to a local file")
    parser.add_argument("-n", "--name", help="table name", required=True)
    parser.add_argument(
        "-u",
        "--db-url",
        help="Database URL",
        default="postgresql://postgres:123@localhost:5432/postgres",
    )
    parser.add_argument("-o", "--output", help="Output filepath", required=True)
    parser.add_argument("-c", "--column", help="Column name", default="embedding")
    parser.add_argument("-d", "--dim", type=int, required=True)
    parser.add_argument("--lists", help="Number of centroids", type=int, required=True)
    return parser
def create_connection(db_url):
    keepalive_kwargs = {
        "keepalives": 1,
        "keepalives_idle": 30,
        "keepalives_interval": 5,
        "keepalives_count": 5,
    }
    conn = psycopg.connect(
        conninfo=db_url,
        autocommit=True,
        **keepalive_kwargs,
    )
    # mark the connection read-only to avoid generating large WAL
    conn.set_read_only(True)
    register_vector(conn)
    return conn
def extract_rows(conn, name):
    # 100M+ rows: estimate from a 1% sample, scaled by 100
    with conn.execute(f"SELECT COUNT(*) FROM {name} TABLESAMPLE SYSTEM (1)") as cursor:
        row = cursor.fetchone()
        count = row[0] * 100
        if count >= 100_000_000:
            return count
    # 10K-100M rows: estimate from a 10% sample, scaled by 10
    with conn.execute(f"SELECT COUNT(*) FROM {name} TABLESAMPLE SYSTEM (10)") as cursor:
        row = cursor.fetchone()
        count = row[0] * 10
        if count >= 10_000:
            return count
    # fewer than 10K rows: count exact rows
    with conn.execute(f"SELECT COUNT(*) FROM {name}") as cursor:
        row = cursor.fetchone()
        return row[0]
def extract_vectors(conn, sql):
    conn.autocommit = False
    # must use a named server-side cursor so rows are streamed as a generator
    with conn.cursor(name="dump_server_cursor") as cursor:
        cursor.itersize = 10000
        cursor.execute(sql)
        for record in cursor:
            yield record[0]
    conn.commit()
    conn.autocommit = True
def pg_to_npy(dim, filepath, name, column, lists):
    select = MAX_POINTS_PER_CLUSTER * lists
    logger.info("Dump begin")
    # uses the module-level `conn` created in __main__ below
    table_rows = math.ceil(extract_rows(conn, name))
    logger.info(f"Table {name} estimate {table_rows} rows")
    rate = math.ceil(select / table_rows * 100)
    init_rows = min(table_rows, select)
    result = np.zeros((init_rows, dim), dtype=np.float32)
    logger.info(
        f"Init samples: {select} x {dim}, will allocate: {result.nbytes/GB_TO_B:.1f}GB"
    )
    if rate < 100:
        logger.info(f"Dump {select} rows, sample rate {rate}%")
        sql = (
            f"SELECT {column} FROM {name} TABLESAMPLE SYSTEM({rate}) LIMIT {init_rows}"
        )
    else:
        logger.info("Dump all rows")
        sql = f"SELECT {column} FROM {name} LIMIT {init_rows}"
    vecs = extract_vectors(conn, sql)
    selected = 0
    for i, v in tqdm(enumerate(vecs), total=init_rows):
        result[i, :] = v
        selected = i
    if selected < init_rows - 1:
        logger.warning(
            f"should have {init_rows} records, but only {selected + 1} available, the result is truncated."
        )
        result = result[: selected + 1, :]
    np.save(filepath, result, allow_pickle=False)
if __name__ == "__main__":
    parser = build_arg_parse()
    args = parser.parse_args()
    print(args)
    conn = create_connection(args.db_url)
    pg_to_npy(args.dim, args.output, args.name, args.column, args.lists)
    logger.info("Dump finish")

kmeans.py

import datetime
from time import perf_counter
import argparse
from pathlib import Path
import logging

import psutil
from numpy import linalg as LA
import faiss
import numpy as np

GB_TO_B = 1024**3
DEFAULT_LISTS = 4096
N_ITER = 25
CHUNKS = 10
SEED = 42
MAX_POINTS_PER_CLUSTER = 256

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
logger.addHandler(handler)

class CustomFormatter(logging.Formatter):
    def formatTime(self, record, datefmt=None):
        t = datetime.datetime.fromtimestamp(record.created)
        return t.strftime("%H:%M:%S")

    def format(self, record):
        mem = psutil.virtual_memory().used / GB_TO_B
        record.msg = f"Mem: {mem:.1f}GB {record.msg}"
        return super().format(record)

formatter = CustomFormatter("%(asctime)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
def build_arg_parse():
    parser = argparse.ArgumentParser(description="Train K-means centroids")
    parser.add_argument("-i", "--input", help="input filepath", required=True)
    parser.add_argument("-o", "--output", help="output filepath", required=True)
    parser.add_argument(
        "--lists",
        help="Number of centroids",
        type=int,
        required=False,
        default=DEFAULT_LISTS,
    )
    parser.add_argument(
        "--niter", help="number of iterations", type=int, default=N_ITER
    )
    parser.add_argument("-d", "--dim", type=int, required=True)
    parser.add_argument("-m", "--metric", choices=["l2", "cos", "dot"], default="l2")
    parser.add_argument(
        "-g", "--gpu", help="enable GPU for KMeans", action="store_true"
    )
    return parser
def kmeans_cluster(
    train,
    k,
    niter,
    metric,
    gpu=False,
):
    n, dim = train.shape
    if not k:
        k = np.sqrt(n).astype(int) * 4  # TODO
    if metric == "cos":
        train = train / LA.norm(train, axis=1, keepdims=True)
    kmeans = faiss.Kmeans(
        dim, k, gpu=gpu, verbose=True, niter=niter, seed=SEED, spherical=metric != "l2"
    )
    kmeans.train(train)
    return kmeans.centroids
if __name__ == "__main__":
    parser = build_arg_parse()
    args = parser.parse_args()
    logger.info(args)

    data = np.load(args.input, allow_pickle=False)
    rows, cols = data.shape
    logger.info(f"Input shape: {rows} x {cols}")

    start_time = perf_counter()
    centroids = kmeans_cluster(
        data,
        args.lists,
        args.niter,
        args.metric,
        args.gpu,
    )
    logger.info(
        f"K-means (k=({args.lists})): {perf_counter() - start_time:.2f}s"
    )
    np.save(Path(args.output), centroids, allow_pickle=False)
@cutecutecat

Files

  • dump.py: dumps embeddings from a Postgres table into a numpy file (.npy)
  • kmeans.py: runs faiss k-means on that numpy file (.npy) to train centroids
  • hd2npy.py: for testing, migrates xxx.hdf5 into xxx.npy (see the sketch below)
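
hd2npy.py itself is not included in this gist; a rough, hypothetical sketch of such a conversion is below. The script shape, the --key flag, and the default "train" dataset key are assumptions (matching ann-benchmarks-style hdf5 files), not the author's actual code.

# hypothetical sketch of hd2npy.py -- the real script is not shown in this gist
import argparse

import h5py
import numpy as np

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Migrate an hdf5 file to a .npy file")
    parser.add_argument("-i", "--input", help="input .hdf5 filepath", required=True)
    parser.add_argument("-o", "--output", help="output .npy filepath", required=True)
    parser.add_argument("--key", help="dataset key inside the hdf5 file", default="train")
    args = parser.parse_args()

    # read the whole dataset into memory as float32 and save it as .npy
    with h5py.File(args.input, "r") as f:
        data = np.asarray(f[args.key], dtype=np.float32)
    np.save(args.output, data, allow_pickle=False)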

Breaking changes

  • Add a new argument -d for the vector dimension to kmeans.py
  • Add a new dependency, psutil, to monitor memory usage

cutecutecat commented Jan 21, 2025

The new script dumps laion-5m in only 6 minutes

$ python dump.py -n laion -o laion.npy --lists 8192 -u postgresql://postgres:123@localhost:5432/postgres -d 768
Namespace(name='laion', db_url='postgresql://postgres:123@localhost:5432/postgres', output='laion.npy', column='embedding', dim=768, lists=8192)
14:26:41 - INFO - Mem: 0.4GB Dump begin
14:26:41 - INFO - Mem: 0.4GB Init samples: 2097152 x 768, will allocate: 6.0GB
14:26:41 - INFO - Mem: 0.4GB Find 4999992 rows, collect 2097152 rows, sample rate 42%
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2097152/2097152 [05:43<00:00, 6098.54it/s]
14:32:42 - INFO - Mem: 0.6GB Dump finish

$ python kmeans.py -i laion.npy -o data.npy --lists 8192 -d 768 -m dot
14:42:13 - INFO - Mem: 0.4GB Namespace(input='laion.npy', output='data.npy', lists=8192, niter=1, dim=768, metric='dot', gpu=False)
14:42:24 - INFO - Mem: 6.4GB input shape: 2097151 x 768
Clustering 2097151 points in 768D to 8192 clusters, redo 1 times, 1 iterations
  Preprocessing in 1.07 s
  Iteration 0 (92.44 s, search 92.01 s): objective=1.52379e+06 imbalance=2.546 nsplit=32       
14:43:57 - INFO - Mem: 6.4GB K-means (k=(8192)): 93.54s
