Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save farleyknight/454f9e1d96842bdfa8044a57950f14b0 to your computer and use it in GitHub Desktop.
Save farleyknight/454f9e1d96842bdfa8044a57950f14b0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Code Search Tool - A simple command-line tool to search Python code semantically.
This tool indexes Python files in a directory, creates embeddings, and returns the top 5
most relevant files for a given query.
"""
import os
import sys
import argparse
import glob
import ast
import re
from typing import List, Dict, Tuple, Any
import lancedb
import numpy as np
import openai
import time
import json
import dotenv
dotenv.load_dotenv()
# You'll need to set your OpenAI API key
# openai.api_key = "your-api-key-here"
# Or set it as an environment variable:
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"
def extract_code_chunks(file_path: str) -> List[Dict[str, Any]]:
    """
    Extract code chunks (functions and classes) from a Python file using the AST.

    Args:
        file_path: Path to the Python file.

    Returns:
        List of dictionaries containing code chunks with metadata
        (file_path, node_type, name, code, start_line, end_line).
        If the file defines no functions or classes (or fails to parse),
        the whole file is returned as a single 'script' chunk.
        An unreadable file yields an empty list.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        lines = content.split('\n')
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return []

    chunks = []
    try:
        # Parse the Python code
        tree = ast.parse(content)
        # Extract classes and functions (ast.walk also visits nested defs)
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                node_type = 'class' if isinstance(node, ast.ClassDef) else 'function'
                start_line = node.lineno
                # Python 3.8+ records the true end line of a node. This is
                # accurate even when the final statement spans multiple lines
                # (e.g. a trailing multi-line string or a parenthesized
                # expression), which a max-over-children scan of `lineno`
                # would truncate.
                end_line = getattr(node, 'end_lineno', None)
                if end_line is None or end_line < start_line:
                    # Fallback for older Pythons: estimate from child nodes.
                    end_line = start_line
                    for child in ast.walk(node):
                        if hasattr(child, 'lineno'):
                            end_line = max(end_line, child.lineno)
                # Get the code as text
                node_text = '\n'.join(lines[start_line - 1:end_line])
                # Skip trivially small fragments — not worth indexing
                if len(node_text.strip()) < 10:
                    continue
                chunks.append({
                    'file_path': file_path,
                    'node_type': node_type,
                    'name': node.name,
                    'code': node_text,
                    'start_line': start_line,
                    'end_line': end_line,
                })
    except SyntaxError as e:
        print(f"Syntax error in {file_path}: {e}")

    # If no chunks found (e.g., script with no functions/classes), add the whole file
    if not chunks:
        chunks.append({
            'file_path': file_path,
            'node_type': 'script',
            'name': os.path.basename(file_path),
            'code': content,
            'start_line': 1,
            'end_line': len(lines),
        })
    return chunks
def create_embeddings(chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Attach OpenAI embeddings to each code chunk.

    Args:
        chunks: List of code chunk dictionaries.

    Returns:
        The chunks with an 'embedding' key added to each. A zero vector is
        stored as a placeholder for any batch that fails to embed.

    Raises:
        ValueError: If no OpenAI API key is configured.
    """
    # Refuse to proceed without credentials
    if not (openai.api_key or os.environ.get("OPENAI_API_KEY")):
        raise ValueError("OpenAI API key is not set. Please set it in the script or as an environment variable.")

    # Embed in small batches to stay under API rate limits
    batch_size = 20
    processed = []
    for start in range(0, len(chunks), batch_size):
        batch = chunks[start:start + batch_size]
        # Prefix each chunk's code with its kind and name for better retrieval
        inputs = [f"{c['node_type']} {c['name']}: {c['code']}" for c in batch]
        try:
            response = openai.embeddings.create(
                model="text-embedding-3-small",
                input=inputs,
                dimensions=1536  # You can adjust this for different performance/cost tradeoffs
            )
            vectors = [item.embedding for item in response.data]
            for chunk, vector in zip(batch, vectors):
                chunk['embedding'] = vector
            processed.extend(batch)
            # Brief pause between batches to avoid rate limits
            if start + batch_size < len(chunks):
                time.sleep(0.5)
        except Exception as e:
            print(f"Error generating embeddings for batch {start//batch_size + 1}: {e}")
            # Keep the pipeline moving: store zero vectors as placeholders
            for chunk in batch:
                chunk['embedding'] = [0] * 1536
            processed.extend(batch)
    return processed
def index_directory(directory: str, db_path: str = 'lancedb') -> str:
    """
    Index all Python files under a directory into a LanceDB vector database.

    Args:
        directory: Directory containing Python files (searched recursively).
        db_path: Path at which to store the LanceDB database.

    Returns:
        Path to the database.
    """
    # Discover every .py file under the directory
    py_files = glob.glob(f"{directory}/**/*.py", recursive=True)
    print(f"Found {len(py_files)} Python files in {directory}")

    # Gather code chunks from every discovered file
    all_chunks = []
    for path in py_files:
        all_chunks.extend(extract_code_chunks(path))
    print(f"Extracted {len(all_chunks)} code chunks")

    # Attach embeddings to every chunk
    chunks_with_embeddings = create_embeddings(all_chunks)

    # Build rows for the vector table. The embedding column is named
    # 'vector' to match what the search side queries against.
    rows = []
    for idx, chunk in enumerate(chunks_with_embeddings):
        rows.append({
            'id': idx,
            'file_path': chunk['file_path'],
            'node_type': chunk['node_type'],
            'name': chunk['name'],
            'code': chunk['code'],
            'start_line': chunk['start_line'],
            'end_line': chunk['end_line'],
            'vector': chunk['embedding'],
        })

    # Create (or replace) the table in the LanceDB database
    db = lancedb.connect(db_path)
    db.create_table("code_chunks", data=rows, mode="overwrite")
    print(f"Created database at {db_path}")
    return db_path
def search_code(query: str, db_path: str = 'lancedb', top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Search the indexed database for code chunks relevant to a query.

    Args:
        query: Search query.
        db_path: Path to the LanceDB database.
        top_k: Number of results to return.

    Returns:
        List of file entries, each with the matching chunks and a distance
        score (lower is better). Empty if the query could not be embedded.

    Raises:
        ValueError: If no OpenAI API key is configured.
    """
    # Refuse to proceed without credentials
    if not (openai.api_key or os.environ.get("OPENAI_API_KEY")):
        raise ValueError("OpenAI API key is not set. Please set it in the script or as an environment variable.")

    # Embed the query with the same model/dimensions used at index time
    try:
        response = openai.embeddings.create(
            model="text-embedding-3-small",
            input=[query],
            dimensions=1536
        )
        query_embedding = response.data[0].embedding
    except Exception as e:
        print(f"Error generating query embedding: {e}")
        return []

    # Run the vector search; the column name matches the indexing schema
    db = lancedb.connect(db_path)
    table = db.open_table("code_chunks")
    hits = table.search(query_embedding, vector_column_name="vector").limit(top_k).to_pandas().to_dict('records')

    # Group chunk hits by file. The first hit seen for a file supplies the
    # file's score (presumably results arrive sorted by ascending distance,
    # so that is the file's best chunk — TODO confirm against LanceDB docs).
    files: Dict[str, Dict[str, Any]] = {}
    for hit in hits:
        entry = files.setdefault(hit['file_path'], {
            'file_path': hit['file_path'],
            'score': hit['_distance'],
            'chunks': [],
        })
        entry['chunks'].append({
            'node_type': hit['node_type'],
            'name': hit['name'],
            'start_line': hit['start_line'],
            'end_line': hit['end_line'],
            'score': hit['_distance'],
        })

    # Lower distance means a closer match
    return sorted(files.values(), key=lambda f: f['score'])[:top_k]
def run_code_search(command: str, directory: str = None, query: str = None, db_path: str = 'lancedb', top_k: int = 5) -> List[Dict[str, Any]]:
    """
    Run the code search tool programmatically.

    Args:
        command: The command to run ('index' or 'search').
        directory: Directory to index (required for 'index').
        query: Search query (required for 'search').
        db_path: Path to the database.
        top_k: Number of results to return (for 'search').

    Returns:
        For 'index': path to the database.
        For 'search': list of relevant code chunks.

    Raises:
        ValueError: If the command is unknown or a required argument is missing.
    """
    if command == 'index':
        if not directory:
            raise ValueError("Directory is required for 'index' command")
        return index_directory(directory, db_path)
    if command == 'search':
        if not query:
            raise ValueError("Query is required for 'search' command")
        return search_code(query, db_path, top_k)
    raise ValueError(f"Unknown command: {command}")
def main():
    """Parse command-line arguments and dispatch to the code search tool."""
    parser = argparse.ArgumentParser(description='Code Search Tool')
    subparsers = parser.add_subparsers(dest='command', help='Command to run')

    # 'index' subcommand: build the database from a directory of .py files
    index_parser = subparsers.add_parser('index', help='Index a directory of Python files')
    index_parser.add_argument('directory', help='Directory to index')
    index_parser.add_argument('--db-path', default='lancedb', help='Path to store the database')

    # 'search' subcommand: query an existing database
    search_parser = subparsers.add_parser('search', help='Search for code')
    search_parser.add_argument('query', help='Search query')
    search_parser.add_argument('--db-path', default='lancedb', help='Path to the database')
    search_parser.add_argument('--top-k', type=int, default=5, help='Number of results to return')

    args = parser.parse_args()
    if not args.command:
        # No subcommand given: show usage instead of erroring out
        parser.print_help()
        return

    try:
        if args.command == 'index':
            run_code_search('index', directory=args.directory, db_path=args.db_path)
        elif args.command == 'search':
            top_files = run_code_search('search', query=args.query, db_path=args.db_path, top_k=args.top_k)
            print(f"\nTop {len(top_files)} files for query: '{args.query}'")
            for rank, match in enumerate(top_files, start=1):
                print(f"\n{rank}. {match['file_path']} (score: {match['score']:.4f})")
                for chunk in match['chunks']:
                    print(f" - {chunk['node_type']} '{chunk['name']}' (lines {chunk['start_line']}-{chunk['end_line']})")
    except ValueError as e:
        print(f"Error: {e}")
        sys.exit(1)


if __name__ == '__main__':
    main()
@farleyknight
Copy link
Author

Screenshot 2025-04-04 at 1 36 43 PM

@farleyknight
Copy link
Author

Editor _ Mermaid Chart-2025-04-04-174230

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment