#!/usr/bin/env python3
"""
Code Search Tool - A simple command-line tool to search Python code semantically.

This tool indexes Python files in a directory, creates embeddings, and returns the
top-k (default 5) most relevant files for a given query.
"""

import os
import sys
import argparse
import glob
import ast
import time
from typing import List, Dict, Any

import lancedb
import openai
import dotenv

dotenv.load_dotenv()

# You'll need to set your OpenAI API key:
# openai.api_key = "your-api-key-here"
# Or set it as an environment variable:
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"
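# If you use a .env file (loaded above via dotenv), it might contain a single
# line like the following -- the key shown is a placeholder, not a real one:
#
#   OPENAI_API_KEY=sk-...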


def extract_code_chunks(file_path: str) -> List[Dict[str, Any]]:
    """
    Extract code chunks from a Python file using Python's AST module.

    Args:
        file_path: Path to the Python file

    Returns:
        List of dictionaries containing code chunks with metadata
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        lines = content.split('\n')
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

    chunks = []
    try:
        # Parse the Python code
        tree = ast.parse(content)

        # Extract classes and functions
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                # Get node type
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                    node_type = 'function'
                else:
                    node_type = 'class'

                # Get name and line numbers
                name = node.name
                start_line = node.lineno

                # Python 3.8+ records the true end line on the node itself;
                # fall back to scanning children on older versions
                end_line = getattr(node, 'end_lineno', None)
                if end_line is None:
                    end_line = 0
                    for child in ast.walk(node):
                        if hasattr(child, 'lineno'):
                            end_line = max(end_line, child.lineno)

                # If we still couldn't determine the end line, estimate it
                if end_line < start_line:
                    end_line = start_line + 1

                # Get the code as text
                node_text = '\n'.join(lines[start_line-1:end_line])

                # Skip if too short
                if len(node_text.strip()) < 10:
                    continue

                # Add to chunks
                chunks.append({
                    'file_path': file_path,
                    'node_type': node_type,
                    'name': name,
                    'code': node_text,
                    'start_line': start_line,
                    'end_line': end_line,
                })
    except SyntaxError as e:
        print(f"Syntax error in {file_path}: {e}")

    # If no chunks found (e.g., script with no functions/classes), add the whole file
    if not chunks:
        chunks.append({
            'file_path': file_path,
            'node_type': 'script',
            'name': os.path.basename(file_path),
            'code': content,
            'start_line': 1,
            'end_line': len(lines),
        })

    return chunks
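
# Example usage (a minimal sketch -- "example.py" is a hypothetical path):
#
#   chunks = extract_code_chunks("example.py")
#   for c in chunks:
#       print(f"{c['node_type']} {c['name']} (lines {c['start_line']}-{c['end_line']})")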


def create_embeddings(chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Create embeddings for code chunks using OpenAI's embedding API.

    Args:
        chunks: List of code chunks

    Returns:
        List of chunks with embeddings added
    """
    # Check if API key is set
    if not (openai.api_key or os.environ.get("OPENAI_API_KEY")):
        raise ValueError("OpenAI API key is not set. Please set it in the script or as an environment variable.")

    # Process in batches to avoid rate limits
    batch_size = 20
    all_embeddings = []

    for i in range(0, len(chunks), batch_size):
        batch = chunks[i:i+batch_size]

        # Prepare texts for embedding
        texts = [f"{chunk['node_type']} {chunk['name']}: {chunk['code']}" for chunk in batch]

        try:
            # Generate embeddings using OpenAI's text-embedding-3-small model
            response = openai.embeddings.create(
                model="text-embedding-3-small",
                input=texts,
                dimensions=1536  # You can adjust this for different performance/cost tradeoffs
            )

            # Extract embeddings from response
            embeddings = [item.embedding for item in response.data]

            # Add embeddings to chunks
            for j, embedding in enumerate(embeddings):
                batch[j]['embedding'] = embedding
            all_embeddings.extend(batch)

            # Avoid rate limits
            if i + batch_size < len(chunks):
                time.sleep(0.5)
        except Exception as e:
            print(f"Error generating embeddings for batch {i//batch_size + 1}: {e}")
            # Add a float zero vector as placeholder so indexing can continue
            # (floats keep the table schema consistent with real embeddings)
            for chunk in batch:
                chunk['embedding'] = [0.0] * 1536
            all_embeddings.extend(batch)

    return all_embeddings
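
# Example usage (assumes OPENAI_API_KEY is set and `chunks` came from
# extract_code_chunks above; each returned dict gains an 'embedding' key):
#
#   chunks = extract_code_chunks("example.py")
#   embedded = create_embeddings(chunks)
#   print(len(embedded[0]['embedding']))  # 1536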


def index_directory(directory: str, db_path: str = 'lancedb') -> str:
    """
    Index all Python files in a directory and create a vector database.

    Args:
        directory: Directory containing Python files
        db_path: Path to store the LanceDB database

    Returns:
        Path to the database
    """
    # Find all Python files
    py_files = glob.glob(f"{directory}/**/*.py", recursive=True)
    print(f"Found {len(py_files)} Python files in {directory}")

    # Extract code chunks
    all_chunks = []
    for file_path in py_files:
        chunks = extract_code_chunks(file_path)
        all_chunks.extend(chunks)
    print(f"Extracted {len(all_chunks)} code chunks")

    # Create embeddings
    chunks_with_embeddings = create_embeddings(all_chunks)

    # Create LanceDB database
    db = lancedb.connect(db_path)

    # Build rows for the table; LanceDB treats the 'vector' column as the
    # default column for similarity search
    data = [{
        'id': i,
        'file_path': chunk['file_path'],
        'node_type': chunk['node_type'],
        'name': chunk['name'],
        'code': chunk['code'],
        'start_line': chunk['start_line'],
        'end_line': chunk['end_line'],
        'vector': chunk['embedding'],
    } for i, chunk in enumerate(chunks_with_embeddings)]

    # Create or overwrite table
    db.create_table("code_chunks", data=data, mode="overwrite")
    print(f"Created database at {db_path}")
    return db_path
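
# Example usage (a minimal sketch -- "./src" and "./my_index" are hypothetical
# paths; indexing calls the OpenAI API, so a key must be set):
#
#   index_directory("./src", db_path="./my_index")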


def search_code(query: str, db_path: str = 'lancedb', top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Search for code chunks relevant to a query.

    Args:
        query: Search query
        db_path: Path to the LanceDB database
        top_k: Number of results to return

    Returns:
        List of relevant files with their matching chunks
    """
    # Check if API key is set
    if not (openai.api_key or os.environ.get("OPENAI_API_KEY")):
        raise ValueError("OpenAI API key is not set. Please set it in the script or as an environment variable.")

    # Create query embedding
    try:
        response = openai.embeddings.create(
            model="text-embedding-3-small",
            input=[query],
            dimensions=1536
        )
        query_embedding = response.data[0].embedding
    except Exception as e:
        print(f"Error generating query embedding: {e}")
        return []

    # Connect to database
    db = lancedb.connect(db_path)
    table = db.open_table("code_chunks")

    # Search with explicit vector column name
    results = table.search(query_embedding, vector_column_name="vector").limit(top_k).to_pandas()

    # Convert to list of dictionaries
    results_list = results.to_dict('records')

    # Group by file_path to get top files; results arrive sorted by distance,
    # so the first chunk seen for a file carries that file's best score
    files = {}
    for result in results_list:
        file_path = result['file_path']
        if file_path not in files:
            files[file_path] = {
                'file_path': file_path,
                'score': result['_distance'],
                'chunks': []
            }
        files[file_path]['chunks'].append({
            'node_type': result['node_type'],
            'name': result['name'],
            'start_line': result['start_line'],
            'end_line': result['end_line'],
            'score': result['_distance']
        })

    # Sort files by score (lower distance is better)
    top_files = sorted(files.values(), key=lambda x: x['score'])[:top_k]
    return top_files
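
# Example usage (assumes a database was already built by index_directory at
# the default 'lancedb' path):
#
#   for f in search_code("parse a python file into functions", top_k=3):
#       print(f"{f['score']:.4f}  {f['file_path']}")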


def run_code_search(command: str, directory: str = None, query: str = None,
                    db_path: str = 'lancedb', top_k: int = 5):
    """
    Run the code search tool programmatically.

    Args:
        command: The command to run ('index' or 'search')
        directory: Directory to index (required for 'index' command)
        query: Search query (required for 'search' command)
        db_path: Path to the database
        top_k: Number of results to return (for 'search' command)

    Returns:
        For 'index' command: Path to the database
        For 'search' command: List of relevant files with their matching chunks
    """
    if command == 'index':
        if not directory:
            raise ValueError("Directory is required for 'index' command")
        return index_directory(directory, db_path)
    elif command == 'search':
        if not query:
            raise ValueError("Query is required for 'search' command")
        return search_code(query, db_path, top_k)
    else:
        raise ValueError(f"Unknown command: {command}")


def main():
    """Main function to parse arguments and run the code search tool."""
    parser = argparse.ArgumentParser(description='Code Search Tool')
    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # Index command
    index_parser = subparsers.add_parser('index', help='Index a directory of Python files')
    index_parser.add_argument('directory', help='Directory to index')
    index_parser.add_argument('--db-path', default='lancedb', help='Path to store the database')

    # Search command
    search_parser = subparsers.add_parser('search', help='Search for code')
    search_parser.add_argument('query', help='Search query')
    search_parser.add_argument('--db-path', default='lancedb', help='Path to the database')
    search_parser.add_argument('--top-k', type=int, default=5, help='Number of results to return')

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    try:
        if args.command == 'index':
            run_code_search('index', directory=args.directory, db_path=args.db_path)
        elif args.command == 'search':
            top_files = run_code_search('search', query=args.query, db_path=args.db_path, top_k=args.top_k)
            print(f"\nTop {len(top_files)} files for query: '{args.query}'")
            for i, file in enumerate(top_files):
                print(f"\n{i+1}. {file['file_path']} (score: {file['score']:.4f})")
                for chunk in file['chunks']:
                    print(f"  - {chunk['node_type']} '{chunk['name']}' (lines {chunk['start_line']}-{chunk['end_line']})")
    except ValueError as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == '__main__':
    main()
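
# Command-line usage (a sketch; "code_search.py" is whatever name you save
# this file under, and "./my_project" is a hypothetical directory):
#
#   python code_search.py index ./my_project
#   python code_search.py search "load environment variables" --top-k 5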