#!/usr/bin/env python3
# embed-search
# Takes query text (utext) on the commandline,
# and many lines of text from stdin or a file (lines).
# Outputs the highest semantically-matching line(s)
# gist-paste -u https://gist.github.com/jaggzh/d57ef8757c10e945d99c77f91449e33a embed-search-wrapped
# Dependencies:
#   bansi.py from https://gist.github.com/jaggzh/c41a5437d1f16696493fffdb658f68dc
#   (and all the stuff in imports that'll get ya)
# One of the py libs caused a library conflict, so I currently
# have to have a wrapper script ('embed-search') call this one.
# Thus, my 'embed-search' is actually:
#   #!/bin/bash
#   export LD_PRELOAD=/usr/lib/gcc/x86_64-linux-gnu/12/libstdc++.so
#   embed-search-wrapped "$@"
# Examples:
# 1. cmd | embed-search "some text"
#    Outputs the closest line from stdin matching utext.
#    Uses the cached file /tmp/embc-{username}.pkl if it exists. Otherwise it
#    writes the cached vectors there, so a re-run with a different "some text"
#    will not have to re-generate the embeddings.
# 2. cmd | embed-search "some text"
#    If run again, outputs the closest line from stdin matching utext
#    (re-using the cached vectors).
# 3. cmd | embed-search -5 "some text"
#    Outputs the top 5 matching lines.
# 4a. cmd | embed-search "some text" \
#        --re '\S+\t(\S+)\t(.+)' \
#        --ore '(\S+)\t(.+)' \
#        --ofmt '{1} - {0}'
# 4b. cmd | embed-search "some text" \
#        --re '\S+\t(\S+)\t(.+)' \
#        --ore '(\S+)\t(.+)' \
#        --ofmt 'Line {line}. Score: {score}. {0} - {text}'
#     Outputs a line's 1st 'field', a dash, and the complete text line.
#     This is similar to the default output, but shows using --ore's match
#     groups to extract something out of the text.
# 5. cmd | embed-search "some text" -s
#    Uses the generated cache name "/tmp/embc-{username}-{hash}.pkl" (hash is
#    of all lines from input). Thus, if cmd's output has not changed and this
#    is run twice, vectorization will not re-occur.
# 6. cmd | embed-search "some text" \
#       --re '\S+\s(.+)' \
#       --ofmt "[{line}] {score}: {0}"
#    This shows that, because --re was provided, its match groups are
#    available to --ofmt as {0}.
# 7. cmd | embed-search "some text" --ofmt "[{line}] {score}: {0}"
#    This will automatically error when ofmt.format(line=linen, score=score,
#    *mgroups) is called, because --re and --ore were not provided, so no
#    match groups will be available for {0}. This error is intended and we
#    let it happen.
# Options:
# -n / --count #    - Count of results to display, sorted (default 1)
# -f / --textfile   - Text file of lines (Default: '-' reads from STDIN)
# -1 .. -9          - Shorthand for a single-digit -n #
# -0                - Unlimited -- rank and output all lines sorted
# -m / --model str  - Model name for sentence transformers. (Default: 'all-MiniLM-L6-v2')
# -h / --help
# -v / --verbose    - Increase verbosity (stored as a global var)
# -C / --nocache    - By default the result vector DB is cached; this disables writing it
# --cfile str       - Full path/filename to store the cached pickle.
#                     If no hashing (-s) is selected, this defaults to:
#                       '/tmp/embc-{username}.pkl'
#                     With -s / --cachehash:
#                       f'/tmp/embc-{username}-{hash}.pkl'
#                     (If you manually specify --cfile, you can use a plain
#                     filename, or include the vars username or hash to be
#                     filled in. This uses .format(), not f-strings, for
#                     security.)
# -s / --cachehash  - Generate a hash checksum (currently uses MurmurHash3) of
#                     all input lines combined as one.
# --re / --re_input str - Regex run on lines of input text to select the text
#                     to vectorize. The first match group is used as the text
#                     to vectorize; other groups are ignored. (Default, None,
#                     means to use the full line of text (with newline
#                     stripped).)
# --ore / --output_re str - Final output lines are processed with this regex.
#                     (Default, None: the whole line is output. Ex.
#                     '(\S+)\t(.*)'. Requires --ofmt to be provided; warns if
#                     it's not, i.e. pel(0, "--ore specified but no --ofmt.
#                     Output will be full line.").)
# --ofmt / --output_fmt str - Output formatting. Can use named placeholders:
#                     {line} (line number), {score}, and {text}. If --re or
#                     --ore are provided, their match groups are provided as
#                     {0} ... {n}. --ore takes precedence over --re for the
#                     provided groups. (Default: "[{line}] Score:
#                     {score:.4f}. {text}")
#                     A commented sketch of how this expansion could work
#                     follows the Options list below.
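# (Illustrative sketch only, not executed here: roughly how an --ofmt string
# could be expanded with str.format(). The names 'ofmt', 'mgroups', 'linen',
# 'score', and 'text' are assumptions for illustration.)
#   linen, score, text = 12, 0.8731, 'the original line'
#   mgroups = ('fieldA', 'rest')   # from --ore (or --re) match groups
#   ofmt = "[{line}] Score: {score:.4f}. {0} - {text}"
#   print(ofmt.format(*mgroups, line=linen, score=score, text=text))
#   # -> [12] Score: 0.8731. fieldA - the original line
#   # Positional {0}..{n} come from the match groups; {line}/{score}/{text}
#   # are always passed as named arguments.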
# A. If -s is not specified, cfile is used. Otherwise the cbase method of
#    filename generation is used.
# B. BEFORE vectorization, the parent dirs of cfile and cbase are tested for
#    existence. An error is output if the directory does not exist, and we
#    exit with an error code.
# C. Notice in the code the safety where, when the default location(s)
#    (/tmp/) are used, the umask is set when the file is created/written,
#    then reset (omask = os.umask(0o077); save; os.umask(omask)).
# D. -b / --build is no longer needed; vectorization is done if the cached db
#    is not found.
# E. -h will display normal usage help texts as well as have room for our
#    lengthy usage/examples text to be added. Provide a couple-line
#    placeholder of Examples: for now.
# F.i. If user_text is not provided, the cache is checked for existence, and
#    built and saved if necessary.
# F.ii. If -C is enabled (disable caching) and no user text is provided, an
#    error is output: pel(0, "No query, and storing is disabled. We will
#    still vectorize, but won't store it.")
# G. Additional output:
#    Keep stats on line counts, e.g. total, lines matching --re if used,
#    lines not matching --re, lines matching --ore, and lines failing. These
#    can be useful when the user is messing with regex. OUTPUT THESE STATS
#    ONLY IF -v IS USED (and to STDERR is fine).
# H. It's only designed for Linux right now. You can use os.path.join(), but
#    otherwise stick with /tmp/ and slashes.
# I. Use pel() levels for the amount of output verbosity. pel(0, ...) is
#    useful for outputting errors that need to be seen (i.e. no -v
#    specified). -v enables data that might be useful for the user to debug
#    their options. -v -v displays file lines as processed, for now. Can
#    modify/improve this later. A commented sketch of the -s cache-name
#    generation follows below.
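# (Illustrative sketch only, not executed here: one way the -s / --cachehash
# name could be generated. The 'mmh3' dependency and the exact fields are
# assumptions; this script doesn't ship this helper.)
#   import getpass, mmh3
#   def hashed_cache_path(lines):
#       h = mmh3.hash_bytes("\n".join(lines).encode("utf-8")).hex()
#       # .format(), not an f-string, so a user-supplied --cfile template
#       # can't execute arbitrary expressions:
#       return '/tmp/embc-{username}-{hash}.pkl'.format(
#           username=getpass.getuser(), hash=h)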
import argparse
import re
import sys
import numpy as np
import os
import hashlib
import logging
import pickle
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Before tf comes in
logging.getLogger('tensorflow').setLevel(logging.ERROR)
import tensorflow as tf
from sentence_transformers import SentenceTransformer, util, models
from bansi import *

verbose = 0
def_model = os.path.expanduser("~/data/models/text/all-MiniLM-L6-v2")
def_cachedir = os.path.expanduser("~/.cache/embed-search")
def_cnt = 9
os.makedirs(def_cachedir, exist_ok=True)

# Utility functions for printing errors and verbose information
def pe(*x, **y):
    print(*x, **y, file=sys.stderr)

def pel(l, *x, **y):
    if verbose >= l:
        print(*x, **y, file=sys.stderr)

def is_older_than(file1, file2, noerror=False):
    # Returns True if file1 is older than file2.
    # noerror: prevents an error being raised if one of the files doesn't
    # exist (and False will be returned)
    try:
        file1_mtime = os.path.getmtime(file1)
        file2_mtime = os.path.getmtime(file2)
        return file1_mtime < file2_mtime
    except FileNotFoundError as e:
        if noerror:
            return False
        else:
            raise e
def preprocess_file(file_input, line_count, overlap=True, *, regex, cmd):
    """Reads and preprocesses the text input based on a regex and combines lines."""
    # overlap: If True, we go line by line, even if taking line_count lines
    # at a time. (This is if you're grouping into lines of, say, 3.)
    # Default is most-sensical (True), where we allow the overlap so each
    # line forms its own match with the subsequent lines.
    if file_input == '-':
        data = sys.stdin.readlines()
    else:
        with open(file_input, 'r', encoding='utf-8') as file:
            data = file.readlines()
    # Strip newlines
    data = [line.rstrip('\n') for line in data]
    if regex:
        pattern = re.compile(regex)
    texts = []
    orig_texts = []
    num_lines = len(data)
    if not overlap:
        line_skip = line_count
    else:
        line_skip = 1
    # Pass 1/3: Make the original combined texts:
    for start_idx in range(0, num_lines, line_skip):
        chunk = data[start_idx:start_idx + line_count]
        combined_text = "\n".join(chunk).strip()
        orig_texts.append(combined_text)
    # Pass 2/3: Run regex/cmd, replacing all texts:
    if cmd or regex:
        for idx in range(0, num_lines):
            if regex:
                match = pattern.search(data[idx])
                if match:
                    data[idx] = match.group(1)
            if cmd:
                try:
                    pel(3, f"Line: '{data[idx]}'")
                    local_vars = {'s': data[idx]}
                    pel(3, f"Evaluating command '{cmd}'")
                    pel(3, f"STR: s <- '{local_vars['s']}'")
                    exec(cmd, globals(), local_vars)
                    s = local_vars['s']
                    pel(3, f"STR: s -> '{s}'")
                    pel(3, f"[{idx}] In <- {{{data[idx]}}}")
                    data[idx] = s
                    pel(3, f"[{idx}] Out <- {{{data[idx]}}}")
                except Exception as e:
                    pe(f"Error processing line with command '{cmd}': {e}")
                    # Give up if errors are still occurring this deep in:
                    if idx > 1000:
                        sys.exit()
    # Pass 3/3: Generate final combined versions from the modified data
    for start_idx in range(0, num_lines, line_skip):
        chunk = data[start_idx:start_idx + line_count]
        combined_text = "\n".join(chunk).strip()
        texts.append(combined_text)
    return texts, orig_texts
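
# Example of the overlap behavior above (illustrative values): with
# line_count=2 and overlap=True, input lines ['a', 'b', 'c'] become the
# chunks 'a\nb', 'b\nc', and 'c' -- every line starts its own chunk. With
# overlap=False the same input yields just 'a\nb' and 'c'.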
def build_embeddings(model_name, texts, output_path=None):
    """Generates embeddings and optionally stores them if output_path is provided."""
    if os.path.isdir(model_name):
        pel(2, f"Loading LOCAL model from path ({model_name})...")
        pel(2, f" Loading tokenizer model ({model_name})...")
        word_embedding_model = models.Transformer(model_name)
        pel(2, f" Loading pooling model ({model_name})...")
        pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
        pel(2, f" Assembling the sentence transformer model...")
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    else:
        pel(2, f"Loading model from the Hugging Face hub ({model_name})...")
        model = SentenceTransformer(model_name)
    pel(2, f"Encoding all texts (cnt:{len(texts)})...")
    embeddings = model.encode(texts, convert_to_tensor=True)
    if output_path:
        with open(output_path, 'wb') as file:
            pickle.dump((model, embeddings), file)
        pel(1, f"Embeddings stored at: {output_path}")
    return model, embeddings  # Return embeddings for in-memory use
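
# Example call (illustrative; model name and path are assumptions): build
# once, then search many times against the cached pickle.
#   model, emb = build_embeddings('all-MiniLM-L6-v2',
#                                 ['first line', 'second line'],
#                                 output_path='/tmp/embc-example.pkl')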
def find_top_matches(*, query, model, embeddings, original_texts, count):
    """Finds and returns top matches based on the query."""
    pel(2, f"Encoding query...")
    pel(2, f" Text: {{{query}}}")
    query_embedding = model.encode([query], convert_to_tensor=True)
    pel(2, f" Embedding: {{{query_embedding}}}")
    pel(2, f"Running cos similarity on all embeddings...")
    scores = util.cos_sim(query_embedding, embeddings)[0]
    scores_np = scores.cpu().numpy()
    top_results = np.argsort(-scores_np)[:count]
    # Build (line, text, score) tuples for the top results
    try:
        formatted_results = [
            (
                index + 1,
                original_texts[index],
                scores_np[index]
            )
            for index in top_results
        ]
    except Exception as e:
        # (An IndexError here usually means the cached embeddings no longer
        # match the current input lines.)
        raise ValueError(f"There was an error when summarizing the results. "
                         f"It's likely your cached embedding wasn't updated. {e}") from e
    return formatted_results
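
# The top-k selection above is plain cosine similarity plus argsort. A toy
# run (illustrative values):
#   scores_np = np.array([0.12, 0.87, 0.55])
#   np.argsort(-scores_np)[:2]   # -> array([1, 2]): indices of the 2 best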
def load_embeddings(file_path):
    """Loads embeddings and model from a given pickle file path."""
    try:
        with open(file_path, 'rb') as file:
            model, embeddings = pickle.load(file)
        pel(1, f"Loaded model and embeddings from {file_path}")
        return model, embeddings
    except Exception as e:
        pe(f"Failed to load embeddings from {file_path}: {str(e)}")
        sys.exit(1)
def main():
    parser = argparse.ArgumentParser(
        description="Vectorize and lookup texts based on semantic processing with SBERT.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="Examples:\n"
               "  cmd | embed-search \"some text\"\n"
               "  embed-search -f mydata.txt \"some text\"\n"
               "  embed-search -f mydata.txt -c 's=s.upper()'\n"
               "  embed-search -f mydata.txt --re '\\d+ (.*)'\n"
               "  embed-search -f mydata.txt -d mydata.txt.pkl \"some text\"\n"
               "  embed-search --nocache --model 'bert-base-nli-mean-tokens' -n 3 \"query text\"\n"
               "DB storage:\n"
               "  To prevent name collisions, the full path to the -f file is converted to a single file name, ending in .pkl.\n"
               "  Specifying -d overrides this.\n"
               f"(Default db cache path: {def_cachedir}/your_text_file_path.pkl)"
    )
    # REM: parser.add_argument('-c', '--cache_dir', type=str, help=f"Dir for cached dbs ({def_cachedir})")
    parser.add_argument('-f', '--textfile', default='-', help="Text file of lines (Default: '-' reads from STDIN)")
    parser.add_argument('-d', '--db_path', type=str, help=f"Store/Use this FILE as DB instead of storing in the cache path at {def_cachedir}.")
    parser.add_argument('-m', '--model', default=def_model, help=f"Model name for sentence transformers ({def_model}).")
    parser.add_argument('-c', '--cmd', type=str, help="Line proc: Python command to process each line of text.")
    parser.add_argument('--re', default='(.*)', help="Regular expression to FILTER the incoming text lines (match group 1 is used)")
    parser.add_argument('user_text', nargs='?', help="User text for finding closest match.")
    parser.add_argument('-n', '--count', type=int, default=def_cnt, help=f"Number of top matches to return ({def_cnt})")
    parser.add_argument('-v', '--verbose', action='count', default=0, help="Increase verbosity.")
    parser.add_argument('-C', '--nocache', action='store_true', help="Do not cache results to disk.")
    parser.add_argument('-O', '--overwrite', action='store_true', help="Overwrite cache if it exists.")
    parser.add_argument('-l', '--line_count', type=int, default=1, help="Number of lines to group together for processing.")
    args = parser.parse_args()
    global verbose
    verbose = args.verbose

    def file_to_cachename_plain(file):
        file = os.path.realpath(file)
        file = file.replace('/', '%')
        return file

    def file_to_cachepath(file):
        file = file_to_cachename_plain(file)
        return os.path.join(def_cachedir, file + ".pkl")
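
    # Example of the cache-name mangling above (illustrative path): a text
    # file at /home/me/notes.txt caches to
    # ~/.cache/embed-search/%home%me%notes.txt.pkl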
    uset_dbpath = False
    if args.db_path:  # keep both so we can distinguish a user-set path
        uset_dbpath = True
        db_path = args.db_path
    else:
        db_path = file_to_cachepath(args.textfile)
    if not os.path.exists(args.model):
        pel(0, f"{red}Warning: Model (-m '{args.model}') doesn't seem to be a local file.{rst}")
        pel(0, f"{red}If you're building the DB we're probably going to error.{rst}")
    pel(1, f"Cache db: {db_path}")
    pel(1, "Preprocessing file...")
    texts, original_texts = preprocess_file(args.textfile, regex=args.re, line_count=args.line_count, cmd=args.cmd)
    if args.nocache:
        # Process everything in memory, do not write to disk
        pel(1, "Building embeddings (NOT CACHED)...")
        model, embeddings = build_embeddings(args.model, texts)
    else:
        # Determine cache file and process accordingly
        if args.overwrite and os.path.exists(db_path):
            pel(1, f"Removing existing db cache (-O) ({db_path}).")
            os.unlink(db_path)
        if not os.path.exists(db_path):
            pel(1, f"Building embeddings (cached to {db_path})...")
            model, embeddings = build_embeddings(args.model, texts, db_path)
        else:
            pel(1, f"Loading existing embeddings ({db_path})...")
            # Test text file time against cache for safety:
            if args.textfile and args.textfile != '-':
                if is_older_than(db_path, args.textfile):
                    pe(f"{yel}WARNING: Cached embeddings are older than your text file.{rst}")
                    pe(f"{yel}         Add -O to automatically update them{rst}")
                    # pe(f"  cache db: {db_path}")
                    # pe(f"  textfile: {args.textfile}")
            model, embeddings = load_embeddings(db_path)
    if args.user_text:
        pel(1, f"Finding matches with user text {{{args.user_text}}}")
        matches = find_top_matches(query=args.user_text,
                                   model=model,
                                   embeddings=embeddings,
                                   original_texts=original_texts,
                                   count=args.count,
                                   )
        for index, match, score in matches:
            pfx = f"Match at line {whi}{index}{rst}. Score:{yel}{score:.4f}{rst} Text:"
            txtcolor = a24bg(10, 20, 34) + a24fg(255, 150, 245)
            if args.line_count == 1:
                print(f"{pfx}{{{txtcolor}{match}{rst}}}")
            else:
                # Don't shadow the 'matches' list we're iterating over:
                sublines = match.split('\n')
                print(pfx)
                for m in sublines:
                    print(f"  {txtcolor}{m}{rst}")
    else:
        pe("Warning: No query text provided for lookup (embeddings were still built).")
        sys.exit(1)

if __name__ == "__main__":
    main()