Skip to content

Instantly share code, notes, and snippets.

@cutecutecat
Last active September 5, 2024 14:07
Show Gist options
Save cutecutecat/bd0897eb9adea4ab3851494c00e9b514 to your computer and use it in GitHub Desktop.
Citus with pgvecto.rs
import numpy as np
from pgvecto_rs.psycopg import register_vector
import psycopg
# Build a synthetic workload:
#   * `rows` random embeddings of `dimensions` floats in [0, 1),
#   * one integer category id in [0, 100) per row (later the distribution column),
#   * 10 random query vectors of the same dimensionality.
rows, dimensions = 100000, 128
embeddings = np.random.random_sample((rows, dimensions))
categories = np.random.randint(0, 100, size=rows).tolist()
queries = np.random.random_sample((10, dimensions))
# enable extensions
conn = psycopg.connect(
conninfo="postgres://postgres:123@localhost:5432/postgres",
dbname="postgres",
autocommit=True,
)
conn.execute("CREATE EXTENSION IF NOT EXISTS citus")
conn.execute("CREATE EXTENSION IF NOT EXISTS vectors")
conn.execute("ALTER DATABASE postgres SET hnsw.ef_search = 20")
conn.close()
# reconnect for updated GUC variables to take effect
conn = psycopg.connect(
conninfo="postgres://postgres:123@localhost:5432/postgres",
dbname="postgres",
autocommit=True,
)
register_vector(conn)
print("Creating distributed table")
# Recreate `items` from scratch, then shard it on category_id.
# NOTE(review): category_id is part of the primary key, presumably because
# Citus requires unique constraints on a distributed table to include the
# distribution column — confirm against the Citus docs.
# `SET citus.shard_count` is session-scoped and applies to the
# create_distributed_table call that follows it on this connection.
for statement in (
    "DROP TABLE IF EXISTS items",
    (
        "CREATE TABLE items (id bigserial, "
        f"embedding vector({dimensions}), "
        "category_id bigint, PRIMARY KEY (id, category_id))"
    ),
    "SET citus.shard_count = 4",
    "SELECT create_distributed_table('items', 'category_id')",
):
    conn.execute(statement)
print("Loading data in parallel")
# Stream every (embedding, category_id) pair through a single binary COPY.
# Both the cursor and the COPY are context-managed, so the cursor is closed
# as soon as the COPY finishes.
with conn.cursor() as cur, cur.copy(
    "COPY items (embedding, category_id) FROM STDIN WITH (FORMAT BINARY)"
) as copy:
    # Binary COPY needs explicit column types; "vector" is the pgvecto.rs type.
    copy.set_types(["vector", "bigint"])
    for embedding, category in zip(embeddings, categories):
        copy.write_row([embedding, category])
        # PQflush returns 1 while output is still pending; keep flushing so
        # rows are pushed to the server incrementally instead of piling up
        # in the client-side buffer.
        while conn.pgconn.flush() == 1:
            pass
print("Creating index in parallel")
# Build an L2-distance pgvecto.rs index; the `options` string `[indexing.hnsw]`
# requests the HNSW index type. The value is passed as a standard SQL string
# literal. A double-quoted value is an identifier, which PostgreSQL truncates
# to 63 bytes (NAMEDATALEN - 1), so a longer or multi-key config would be
# cut short.
conn.execute(
    "CREATE INDEX ON items USING vectors (embedding vector_l2_ops) "
    "WITH (options = '[indexing.hnsw]')"
)
print("Running distributed queries")
# For each query vector, print the ids of the 10 rows closest to it by
# L2 distance (`<->`, the operator matching the vector_l2_ops index).
for query in queries:
    neighbors = conn.execute(
        "SELECT id FROM items ORDER BY embedding <-> %s LIMIT 10", (query,)
    ).fetchall()
    print([row[0] for row in neighbors])

# All work is done: close the session explicitly instead of leaving it to
# interpreter shutdown.
conn.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment