Last active
August 15, 2023 15:30
-
-
Save kordless/aae99946e7e2a5afccc83f3c4eeee65a to your computer and use it in GitHub Desktop.
Instructor Embeddings w/FeatureBase
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# FeatureBase connection settings (read as `config.*` by the query helper).
import os

# API token: create one at https://cloud.featurebase.com/configuration/api-keys
# Set FEATUREBASE_TOKEN in the environment to avoid committing a real key
# here; the placeholder below is only used when that variable is unset.
featurebase_token = os.environ.get("FEATUREBASE_TOKEN", "<token>")

# Database endpoint from FeatureBase Cloud ($300 free credit on signup).
# The console shows a query URL such as
#   https://query.featurebase.com/v2/databases/bc355-t-t-t-362c1416/query/sql
# Only the part up to the database id is wanted, because "/query/sql" is
# appended at request time. FEATUREBASE_ENDPOINT overrides the placeholder,
# and a pasted trailing "/" or "/query/sql" is trimmed so the request URL
# never ends up as ".../query/sql/query/sql".
_QUERY_SUFFIX = "/query/sql"
featurebase_endpoint = os.environ.get(
    "FEATUREBASE_ENDPOINT",
    "https://query.featurebase.com/v2/databases/<uuid-only-no-query-sql>",
).rstrip("/")
if featurebase_endpoint.endswith(_QUERY_SUFFIX):
    featurebase_endpoint = featurebase_endpoint[: -len(_QUERY_SUFFIX)]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import random | |
import string | |
import time | |
import requests | |
from string import Template | |
import config | |
# parse helper
def find_between(s, first, last):
    """Return the text of *s* between the first *first* and the next *last*.

    The search for *last* starts right after *first*. When either delimiter
    cannot be found, an empty string is returned.
    """
    head = s.find(first)
    if head == -1:
        return ""
    begin = head + len(first)
    tail = s.find(last, begin)
    if tail == -1:
        return ""
    return s[begin:tail]
# random strings
def random_string(size=6, chars=string.ascii_letters + string.digits):
    """Return *size* characters picked one at a time (with replacement) from *chars*.

    Uses the non-cryptographic ``random`` module.
    """
    picked = []
    for _ in range(size):
        picked.append(random.choice(chars))
    return "".join(picked)
###############
# FeatureBase #
###############
def apply_schema(list_of_lists, schema):
    """Convert positional rows into dicts keyed by the column names in *schema*.

    Cell ``i`` of each row is stored under ``schema[i]``, so a row with more
    cells than *schema* has names raises IndexError.
    """
    return [
        {schema[col]: cell for col, cell in enumerate(row)}
        for row in list_of_lists
    ]
# "sql" key in document should have a valid query | |
def featurebase_query(document, timeout=None):
    """Run ``document["sql"]`` against FeatureBase and annotate *document*.

    The SQL text is POSTed as plain text to
    ``config.featurebase_endpoint + "/query/sql"`` with
    ``config.featurebase_token`` in the ``X-API-Key`` header. The same
    *document* dict is mutated and returned:

    * missing/blank "sql", request failure, or non-JSON reply: ``error`` set.
    * FeatureBase-reported error: ``error``, ``explain`` and ``data`` set.
    * rows returned: ``data``, ``schema`` and ``results`` (each row as a dict
      keyed by column name) set.
    * success without rows: ``explain`` set.

    Args:
        document: dict whose "sql" key holds the query string.
        timeout: optional seconds forwarded to ``requests.post``. The default
            ``None`` waits indefinitely, which preserves the previous behavior.

    Returns:
        The same *document* dict.
    """
    sql = document.get("sql")
    if not isinstance(sql, str) or not sql.strip():
        # Without this guard a missing key surfaced only as the opaque
        # "'NoneType' object has no attribute 'encode'".
        document['error'] = "document has no 'sql' query to run"
        return document

    # try to run the query
    try:
        response = requests.post(
            config.featurebase_endpoint + "/query/sql",
            data=sql.encode('utf-8'),
            headers={
                'Content-Type': 'text/plain',
                'X-API-Key': '%s' % config.featurebase_token,
            },
            timeout=timeout,
        )
        result = response.json()
    except Exception as ex:
        # network failure, timeout, or a non-JSON body (e.g. a gateway error page)
        exc_type, exc_obj, exc_tb = sys.exc_info()
        document['error'] = "%s: %s" % (exc_tb.tb_lineno, ex)
        return document

    if not isinstance(result, dict):
        document['error'] = "Unexpected response from FeatureBase: %r" % (result,)
        return document

    if result.get('error', ""):
        # FeatureBase reports an error
        document['explain'] = "Error returned by FeatureBase: %s" % result.get('error')
        document['error'] = result.get('error')
        document['data'] = result.get('data')
    elif result.get('data', []):
        # got some data back from FeatureBase
        document['data'] = result.get('data')
        document['schema'] = result.get('schema')
        # A missing/null schema previously raised AttributeError here, outside
        # the try above, crashing the caller instead of reporting an error.
        schema = result.get('schema') or {}
        field_names = [field.get('name') for field in schema.get('fields') or []]
        try:
            document['results'] = apply_schema(result.get('data'), field_names)
        except IndexError:
            document['error'] = "FeatureBase returned rows wider than its schema"
    else:
        document['explain'] = "Query was successful, but returned no data."
    return document
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# run
# python3 lodge.py
"""
Sample output: cosine distance to "cordless drill", then description,
most similar (smallest distance) first.

0.083649576 DeWalt DCD771C2 20V MAX Cordless Drill/Driver Kit
0.11123812 Craftsman CMCS300B V20 Cordless Circular Saw
0.11976445 Milwaukee M18 FUEL 1/2 inch Hammer Drill/Driver Kit
0.12040126 Ryobi P884 18V ONE+ Cordless 6-Tool Combo Kit
0.124932826 Black+Decker BDINF20C 20V MAX Cordless Inflator
0.12507331 Kobalt 24-Volt Max Variable Speed Brushless Cordless Reciprocating Saw
0.12521183 Makita XT505 18V LXT Lithium-Ion Cordless Combo Kit
0.13389325 Dremel 3000 Variable Speed Rotary Tool Kit
0.13613683 Worx WX081L ZipSnip Cordless Electric Scissors
0.14397514 Porter-Cable PCE605K52 Oscillating Multi-Tool Kit
0.14637959 Ridgid R4021 7 inch Portable Tile Saw
0.150545 Husky 3/8 inch Drive Mechanics Tool Set (30-Piece)
0.1546238 Kreg K4 Pocket Hole System
0.15909016 Estwing E3-16S 16 oz. Straight Claw Hammer
0.16314942 Husky 268-Piece Mechanics Tool Set
0.16853327 Werner D6228-2 28 ft. Fiberglass Extension Ladder
0.16925275 Bosch GPL5 5-Point Self-Leveling Alignment Laser
0.17692894 ECHO PB-580T 215 mph 510 CFM 58.2cc Gas Backpack Blower
0.17705077 Stanley FatMax 25 ft. Tape Measure
0.1802178 IRWIN QUICK-GRIP 4-Pack Clamp Set
"""
from InstructorEmbedding import INSTRUCTOR
import torch

model = INSTRUCTOR('hkunlp/instructor-large')

# sample product text descriptions
tools_array = [
    "DeWalt DCD771C2 20V MAX Cordless Drill/Driver Kit",
    "Ryobi P884 18V ONE+ Cordless 6-Tool Combo Kit",
    "Milwaukee M18 FUEL 1/2 inch Hammer Drill/Driver Kit",
    "Makita XT505 18V LXT Lithium-Ion Cordless Combo Kit",
    "Husky 268-Piece Mechanics Tool Set",
    "Kobalt 24-Volt Max Variable Speed Brushless Cordless Reciprocating Saw",
    "Ridgid R4021 7 inch Portable Tile Saw",
    "Bosch GPL5 5-Point Self-Leveling Alignment Laser",
    "Craftsman CMCS300B V20 Cordless Circular Saw",
    "Werner D6228-2 28 ft. Fiberglass Extension Ladder",
    "Black+Decker BDINF20C 20V MAX Cordless Inflator",
    "Husky 3/8 inch Drive Mechanics Tool Set (30-Piece)",
    "Dremel 3000 Variable Speed Rotary Tool Kit",
    "ECHO PB-580T 215 mph 510 CFM 58.2cc Gas Backpack Blower",
    "Kreg K4 Pocket Hole System",
    "Porter-Cable PCE605K52 Oscillating Multi-Tool Kit",
    "Stanley FatMax 25 ft. Tape Measure",
    "IRWIN QUICK-GRIP 4-Pack Clamp Set",
    "Estwing E3-16S 16 oz. Straight Claw Hammer",
    "Worx WX081L ZipSnip Cordless Electric Scissors"
]

# calculate embeddings: one vector per description, stored below in a
# vector(768) column
embeddings = model.encode(tools_array, output_value="sentence_embedding").tolist()

from database import featurebase_query


def _sql_string(text):
    """Render *text* as a SQL string literal, doubling embedded single quotes."""
    return "'" + str(text).replace("'", "''") + "'"


# drop the table (reports an error in the printed document on a first run
# when the table does not exist yet)
sql = "DROP TABLE products;"
print(featurebase_query({"sql": sql}))

# create the table
sql = "CREATE TABLE products (_id id, description string, dabed vector(768));"
print(featurebase_query({"sql": sql}))

# insert into FeatureBase, one row per description; ids start at 1
for _id, (description, dabed) in enumerate(zip(tools_array, embeddings), start=1):
    sql = f"INSERT INTO products VALUES({_id}, {_sql_string(description)}, {dabed});"
    inserted = featurebase_query({"sql": sql})
    if inserted.get('error'):
        # previously discarded; a failed insert silently dropped the product
        print("insert of row %s failed: %s" % (_id, inserted['error']))

s_embedding = model.encode(["cordless drill"]).tolist()

# Rows are ranked in Python below (smaller cosine distance = closer match).
# The SQL has no ORDER BY: "ORDER BY 2" would sort by the description
# column, not by rank.
sql = f"SELECT _id, description, cosine_distance({s_embedding[0]}, dabed) AS rank FROM products;"
results = featurebase_query({"sql": sql})

if results.get('error'):
    raise SystemExit("search failed: %s" % results['error'])
if 'results' not in results:
    print(results.get('explain', "Query returned no rows."))

sorted_results = sorted(results.get('results', []), key=lambda x: x['rank'])

for result in sorted_results:
    print(result.get('rank'), result.get('description'))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
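# pip3 install -r requirements.txt
InstructorEmbedding==1.0.1
# InstructorEmbedding loads models through sentence-transformers; releases
# 2.3.0 and later break INSTRUCTOR model loading, so pin a compatible one.
sentence-transformers==2.2.2
torch==2.0.1
# HTTP client used by the FeatureBase query helper (database module)
requests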
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment