I run all my searches in FZF as I edit my notes in Vim. See https://share.cleanshot.com/x9ZkfBQQ. It will require some tinkering, but it runs the inference/search server so the searches are fast. It's deployed to Modal.com.
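There are two parts: the fzf wrapper script below, and search_webhook.py, which embeds every paragraph of my notes with a bi-encoder, reranks candidates with a cross-encoder, and keeps a warm TCP server on port 8090 so each keystroke's reload returns quickly. Embeddings are computed on a GPU via Modal and cached locally with diskcache.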
#!/bin/bash
# FZF quotes field names, so I can't easily get the preview to work otherwise
cd "$ZK_PATH"
# --bind="2:preview(echo '\"{q}\",\"{4..}\",2' >> judge.csv && echo Rated 2!)" \
export FZF_DEFAULT_COMMAND="python ~/src/semsearch/search_webhook.py learning"
fzf \
  --bind="0:preview(ruby ~/src/semsearch/write.rb {q} {4..} 0 && echo Rated 0!)" \
  --bind="1:preview(ruby ~/src/semsearch/write.rb {q} {4..} 1 && echo Rated 1!)" \
  --bind="2:preview(ruby ~/src/semsearch/write.rb {q} {4..} 2 && echo Rated 2!)" \
  --header="Press [0], [1], [2] to add an item to the judgement list with that rating" \
  --prompt "Semantic > " \
  --bind "change:reload(python ~/src/semsearch/search_webhook.py {q})+change-prompt(Semantic > )" \
  --bind "tab:reload(textgrep --scores \"{q}\")+change-prompt(BM25 > )+unbind(change)" \
  --bind "btab:reload(python ~/src/semsearch/search_webhook.py \"{q}\")+change-prompt(Semantic > )+rebind(change)" \
  --disabled \
  --ansi \
  --with-nth '1,4..' \
  --no-hscroll \
  --preview-window 'top:65%,+{3}' \
  --no-multi \
  --height 100% \
  --tac \
  --query "learning" \
  --preview "bat --language md --color always --plain {4..} --highlight-line {2}"
import glob
import multiprocessing
import os
import pathlib
import signal
import socket
import sys
import time

import modal

volume = modal.SharedVolume().persist("model-cache")
stub = modal.Stub(
    "search-webhook",
    image=modal.Image.debian_slim().pip_install(["sentence-transformers"]),
)

note_path = "~/Documents/Zettelkasten/**/*.md"
paragraphs = None
bi_encoder = None
cross_encoder = None


# https://stackoverflow.com/questions/17412304/hashing-an-array-or-object-in-python-3
def checksum(data):
    import hashlib

    hashId = hashlib.md5()
    hashId.update(repr(data).encode("utf-8"))
    return hashId.hexdigest()
def paragraphs_from_glob(paths):
    paragraphs = []
    paragraph_idx_to_path = {}
    paragraph_idx_to_lines = {}
    paths.sort()  # stable checksum!
    for path in paths:
        title = pathlib.Path(path)
        # The file's stem doubles as paragraph 0, so the title is searchable too.
        new_paragraphs = [title.stem.strip()]
        new_paragraphs.extend(title.read_text().split("\n\n"))
        new_paragraphs = list(filter(None, new_paragraphs))
        line = 1
        for paragraph_idx, paragraph in zip(
            range(len(paragraphs), len(paragraphs) + len(new_paragraphs)),
            new_paragraphs,
        ):
            assert paragraph_idx >= len(paragraphs)
            assert paragraph
            assert paragraph_idx not in paragraph_idx_to_path
            paragraph_idx_to_path[paragraph_idx] = path
            paragraph_idx_to_lines[paragraph_idx] = [
                line,
                line + paragraph.count("\n"),
            ]
            if paragraph_idx > len(paragraphs):  # not the title
                line += paragraph.count("\n") + 2  # for the \n\n
        paragraphs.extend(new_paragraphs)
    return {
        "paragraphs": paragraphs,
        "paragraph_idx_to_path": paragraph_idx_to_path,
        "paragraph_idx_to_lines": paragraph_idx_to_lines,
        "checksum": checksum(paragraphs),
    }
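# Shape sketch (hypothetical file): for notes/demo.md containing
# "First paragraph.\n\nSecond paragraph.", paragraphs_from_glob(["notes/demo.md"])
# returns
#   paragraphs             -> ["demo", "First paragraph.", "Second paragraph."]
#   paragraph_idx_to_lines -> {0: [1, 1], 1: [1, 1], 2: [3, 3]}
# i.e. the title maps to line 1 and each paragraph to its 1-based line range.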
def cached_models():
    from sentence_transformers import CrossEncoder, SentenceTransformer

    cache_path = "/tmp/model-cache"
    if stub.is_inside():  # running remotely on Modal: use the shared volume
        cache_path = "/root/models"
    os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_path  # SentenceTransformer
    os.environ["TORCH_HOME"] = cache_path  # CrossEncoder
    before = time.monotonic()
    bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    print("Instantiate bi and cross encoders: ", time.monotonic() - before, "seconds", file=sys.stderr)
    return bi_encoder, cross_encoder
@stub.function(
    gpu=True,
    shared_volumes={
        "/root/models": volume,
    },
)
def encode_paragraph_embeddings(paragraphs):
    bi_encoder, _cross_encoder = cached_models()
    print("Encoding paragraphs", file=sys.stderr)
    before = time.monotonic()
    paragraph_embeddings = bi_encoder.encode(
        paragraphs["paragraphs"],
        # batch_size=128,
        show_progress_bar=True,
        convert_to_tensor=True,
    )
    print("Encode paragraphs: ", time.monotonic() - before, "seconds", file=sys.stderr)
    return paragraph_embeddings.cpu()  # move off the GPU for serialization
def cached_paragraphs(key) -> dict:
    import diskcache

    dc = diskcache.Cache("/tmp/zk-cache/")
    print("Cache key is", key, file=sys.stderr)
    if key in dc:
        before = time.monotonic()
        cached = dc.get(key)
        print("Get cached paragraph embeddings: ", time.monotonic() - before, "seconds", file=sys.stderr)
        return cached
    else:
        print("Running stub..", file=sys.stderr)
        with stub.run():
            before = time.monotonic()
            print("Getting paragraphs", file=sys.stderr)
            paragraphs = paragraphs_from_glob(glob.glob(os.path.expanduser(note_path), recursive=True))
            print("Load paragraphs: ", time.monotonic() - before, "seconds", file=sys.stderr)
            paragraphs["embeddings"] = encode_paragraph_embeddings(paragraphs)
            dc.set(key, paragraphs, expire=24 * 3600 * 5)  # cache for five days
            dc.expire()
            return paragraphs
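# Note: the cache key is the fixed string "expiring-paragraphs" (see search()
# below), not the content checksum, so edited notes are only re-embedded once
# the five-day expiry lapses or /tmp/zk-cache is cleared.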
# @stub.function(
#     gpu=False,
#     shared_volumes={
#         "/root/models": volume,
#     },
#     mounts=[
#         modal.Mount(
#             remote_dir="/root/notes",
#             local_dir="~/Documents/Zettelkasten",
#             # only allow markdown files
#             condition=lambda path: path.endswith(".md"),
#             recursive=True,
#         )
#     ],
# )
def search(query_string: str):
    global paragraphs
    global bi_encoder
    global cross_encoder
    from sentence_transformers import util

    before = time.monotonic()
    if paragraphs is None:
        paragraphs = cached_paragraphs("expiring-paragraphs")
        print("Get paragraph embeddings: ", time.monotonic() - before, "seconds", file=sys.stderr)
    if bi_encoder is None:
        bi_encoder, cross_encoder = cached_models()
    query_string = query_string.strip()
    query_embedding = bi_encoder.encode(query_string, convert_to_tensor=True)

    before = time.monotonic()
    results = util.semantic_search(query_embedding, paragraphs["embeddings"], top_k=32)[0]
    print("Search time: ", time.monotonic() - before, file=sys.stderr)

    deduped_results = {}
    before = time.monotonic()
    # Rerank the bi-encoder's candidates with the more accurate cross-encoder.
    cross_inp = [[query_string, paragraphs["paragraphs"][result["corpus_id"]]] for result in reversed(results)]
    cross_scores = cross_encoder.predict(cross_inp)
    print("Cross encoding time: ", time.monotonic() - before, "seconds\n", file=sys.stderr)

    for idx, match in enumerate(reversed(results)):
        paragraph_id = match["corpus_id"]
        path = paragraphs["paragraph_idx_to_path"][paragraph_id]
        start_line, end_line = paragraphs["paragraph_idx_to_lines"][paragraph_id]
        dirname = os.path.basename(os.path.dirname(path))
        title = os.path.basename(path)
        if dirname == "highlights":
            title = f"highlights/{title}"
        score = cross_scores[idx]
        # line range, preview scroll offset, path -- fields 2..4 of the output line
        title_with_lines = f"{start_line}:{end_line} {max(start_line - 10, 0)} {title}"
        if title not in deduped_results:
            deduped_results[title] = [score, title_with_lines]
        elif score > deduped_results[title][0]:
            deduped_results[title] = [
                score,
                title_with_lines,
            ]
        else:
            # Another match in the same file with a lower score; arguably it should
            # still contribute to the file's score, similar to term frequency.
            continue

    output = ""
    # Sort ascending by score; the fzf script's --tac puts the best match on top.
    for _, match in sorted(deduped_results.items(), key=lambda item: item[1]):
        output += f"{match[0]:.2f} {match[1]}\n"
    return output[0:-1]  # drop the trailing newline
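# Output sketch (hypothetical scores and paths): search("spaced repetition")
# returns one match per line, lowest score first, e.g.:
#   6.12 10:14 0 highlights/some-book.md
#   7.31 42:48 32 another-note.md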
# Check if a port is open on the machine.
def port_open(port):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    result = sock.connect_ex(("127.0.0.1", port))
    sock.close()
    return result == 0
def simple_tcp_server():
    # Silence this process: its stdout would otherwise leak into fzf's reload output.
    with open(os.devnull, "w") as f1, open(os.devnull, "w") as f2:
        sys.stderr = f1
        sys.stdout = f2
        port = 8090
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("127.0.0.1", port))
        s.listen(5)

        def signal_handler(signal, frame):
            s.close()
            sys.exit(0)

        signal.signal(signal.SIGTERM, signal_handler)
        print("Listening on port", port, file=sys.stderr)
        while True:
            c, addr = s.accept()
            print("PID: ", os.getpid(), file=sys.stderr)
            query = c.recv(1024)
            if len(query) > 0:  # empty reads are port_open()'s probe connections
                result = search(query.decode("utf-8"))
                c.send(result.encode("utf-8"))
            c.close()
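# Protocol sketch: one-shot TCP on 127.0.0.1:8090 -- the client sends the raw
# query bytes (only the first recv of up to 1024 bytes is used), receives the
# formatted results, and the connection closes.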
def query_via_tcp(query):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect(("localhost", 8090))
    s.send(query.encode("utf-8"))
    print(s.recv(10_000).decode("utf-8"))
    s.close()
if __name__ == "__main__":
    query = " ".join(sys.argv[1:])
    if not query:
        print("No query was provided")
        sys.exit(0)
    if port_open(8090):
        query_via_tcp(query)
    else:
        # No warm server yet: fork one, wait up to ~50s for it to come up
        # (the first run may need to embed the whole corpus), then query it.
        p = multiprocessing.Process(target=simple_tcp_server, daemon=False, name="python search.py")
        p.start()
        for _ in range(500):
            time.sleep(0.1)
            if port_open(8090):
                break
        query_via_tcp(query)
        os.kill(os.getpid(), signal.SIGTERM)
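# Design note: the parent kills itself with SIGTERM after answering the first
# query; the non-daemon child survives as the warm server, so subsequent
# invocations take the fast port_open() path above.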