Skip to content

Instantly share code, notes, and snippets.

@g023
Last active November 9, 2025 20:07
Show Gist options
  • Select an option

  • Save g023/4d9dc3c4e35176b889329174f62ee6b5 to your computer and use it in GitHub Desktop.

Select an option

Save g023/4d9dc3c4e35176b889329174f62ee6b5 to your computer and use it in GitHub Desktop.
LaCo (Layer Collapse) Inspired Model Compressor

Model Compression with Layer Collapse

A Python script for compressing large language models using layer merging techniques inspired by LaCo (Layer Collapse). This tool reduces model depth by identifying and merging highly similar transformer layers, significantly decreasing parameter count while maintaining inference quality.

Key Features

- LoRA Adapter Merging: Seamlessly integrates fine-tuned LoRA adapters into the base model before compression.
- Similarity-Based Pruning: Computes cosine similarity between adjacent layers, merging the most redundant pairs above a configurable threshold.
- Configurable Compression: Adjustable similarity thresholds, maximum merges, and frozen layers for controlled compression.
- Performance Evaluation: Includes perplexity calculation and inference testing before/after compression.
- GGUF Export: Optional creation of quantized GGUF files for efficient deployment.
- Unsloth Integration: Leverages Unsloth's optimized transformers for fast model loading and processing.

Usage

Configure the model paths, thresholds, and options at the top of the script, then run:

python model_compress.py
# LaCo based model compression script by https://github.com/g023 (g023) - https://x.com/g023dev (X)
import os
import torch
from unsloth import FastLanguageModel
from typing import Any, Dict, Iterable, List, Sequence, cast
# Configuration variables for model compression
MODEL_PATH = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit" # Path or HuggingFace model ID to load
# Base model ID used only on the LoRA-merge path: main() loads this base model
# with 4-bit/8-bit loading disabled, applies and merges the adapters found in
# MODEL_PATH, saves the result to a temporary directory and reloads it for
# pruning (in 4-bit only when LOAD_QUANTIZED is True).
BASE_MODEL_ID = "unsloth/gemma-3-1b-it" # Base model ID for merging LoRA
OUTPUT_DIR = "compressed_model" # Directory to save the compressed model
MAX_SEQ_LENGTH = 4096 # Maximum sequence length for the model
# Minimum cosine similarity the best adjacent pair must reach to be merged;
# merging stops as soon as the best pair scores below it. Cosine similarity
# lies in [-1, 1], so 0.0 accepts any pair that is not negatively correlated;
# raise the value toward 1.0 for more conservative pruning.
PRUNING_SIMILARITY_THRESHOLD = 0.0 # Similarity threshold for layer merging
PRUNING_MAX_MERGES = 3 # Maximum number of layer merges to perform - lower values prune less aggressively
# Layers with an index below this value are never part of a merge; the first
# candidate pair is (FREEZE_LAYERS, FREEZE_LAYERS + 1).
FREEZE_LAYERS = 10 # Number of leading layers to freeze (not prune) - higher values preserve more of the model
# NOTE: main() only attempts the export when the model exposes `peft_config`;
# otherwise it prints a notice and skips GGUF creation.
CREATE_GGUF = True # Whether to create a GGUF file after compression
GGUF_QUANTIZATION = "q8_0" # Quantization type for GGUF (e.g., "f16", "q8_0", "q4_k_m")
TEST_PROMPT = "Hello, how are you?" # Prompt for the quick inference test and the perplexity check
# Only takes effect when MODEL_PATH is a local directory that contains
# "adapter_model.safetensors"; otherwise MODEL_PATH is loaded directly.
MERGE_LORA = True # Whether to merge LoRA adapters before pruning if they exist
LOAD_QUANTIZED = False # Whether to load the model in 4-bit quantization for pruning
# Potential improvements:
# 1. Use activation-based similarity instead of parameter-based for better layer matching.
# 2. Implement optimal transport or Fisher-weighted merging for better parameter combination.
# 3. Add post-merging fine-tuning or distillation to recover performance.
# 4. Use clustering algorithms to merge multiple layers at once.
# 5. Incorporate importance scores (e.g., from gradients) to decide merge order.
# 6. Combine with other compression techniques like quantization-aware pruning.
# LaCo-inspired Layer Collapse Implementation
def compute_layer_similarity(layer1, layer2):
    """Return the cosine similarity between the weights of two layers.

    Parameters owned by a sub-module whose dotted name contains ``self_attn``
    or ``mlp`` are flattened, cast to float32 and concatenated into a single
    vector per layer. If either layer has no such parameters, all parameters
    of both layers are used instead.

    Every parameter contributes exactly once. Iterating ``named_modules()``
    and calling the recursive ``module.parameters()`` on each match would
    collect a nested parameter once per matching ancestor, so the weighting
    would depend on nesting depth and the temporary vectors would be several
    times larger than needed.

    Args:
        layer1: First ``torch.nn.Module``.
        layer2: Second ``torch.nn.Module``. Its selected parameters must add
            up to the same number of elements as those of ``layer1``.

    Returns:
        The cosine similarity as a Python float.
    """
    import torch.nn.functional as F

    def _selected_params(layer):
        # Match on the owning module's path (the parameter name without its
        # last component), which is what the sub-module names refer to.
        return [
            param.float().flatten()
            for name, param in layer.named_parameters()
            if any(key in name.rpartition(".")[0] for key in ("self_attn", "mlp"))
        ]

    params1 = _selected_params(layer1)
    params2 = _selected_params(layer2)
    if not params1 or not params2:
        # Fallback to all parameters if the expected sub-modules are not found
        params1 = [p.float().flatten() for p in layer1.parameters()]
        params2 = [p.float().flatten() for p in layer2.parameters()]

    w1 = torch.cat(params1)
    # The two layers may live on different devices (e.g. with a sharded
    # device map); do the comparison on the first layer's device.
    w2 = torch.cat(params2).to(w1.device)
    return F.cosine_similarity(w1.unsqueeze(0), w2.unsqueeze(0)).item()
def evaluate_model(model, tokenizer, text="The quick brown fox jumps over the lazy dog."):
    """Return the model's perplexity on a single short text.

    The text is tokenized into PyTorch tensors, moved to ``model.device`` and
    passed through the model with the input ids doubling as labels. The
    perplexity is the exponential of the loss the model reports.
    """
    encoded = tokenizer(text, return_tensors="pt").to(model.device)
    # Gradients are not needed for a pure evaluation pass.
    with torch.no_grad():
        loss = model(**encoded, labels=encoded["input_ids"]).loss
        return torch.exp(loss).item()
def merge_layers(layer1, layer2, alpha=0.5):
    """Blend the parameters of ``layer2`` into ``layer1`` in place.

    Parameters are paired in ``parameters()`` order and each parameter of
    ``layer1`` is overwritten with ``alpha * p1 + (1 - alpha) * p2``.
    ``layer2`` is not modified.

    Args:
        layer1: Module that receives the merged weights.
        layer2: Module whose weights are blended in.
        alpha: Weight given to ``layer1``'s own parameters.

    Raises:
        ValueError: If the layers have a different number of parameters or a
            paired parameter differs in shape. The check runs before any
            weight is touched, so a failed call leaves ``layer1`` unchanged.
    """
    params1 = list(layer1.parameters())
    params2 = list(layer2.parameters())

    # zip() alone would silently ignore the surplus parameters of the longer
    # layer, so mismatches are rejected up front.
    if len(params1) != len(params2):
        raise ValueError(
            f"Cannot merge layers with different parameter counts: "
            f"{len(params1)} vs {len(params2)}"
        )
    for index, (p1, p2) in enumerate(zip(params1, params2)):
        if p1.shape != p2.shape:
            raise ValueError(
                f"Cannot merge parameter {index}: shape {tuple(p1.shape)} "
                f"vs {tuple(p2.shape)}"
            )

    with torch.no_grad():
        for p1, p2 in zip(params1, params2):
            # The layers may sit on different devices; compute on p1's device.
            p1.copy_(alpha * p1 + (1 - alpha) * p2.to(p1.device))
def apply_layer_collapse(model, similarity_threshold=0.85, max_merges=5, freeze_layers=0):
    """Reduce model depth by repeatedly merging the most similar adjacent layers.

    In each round the similarity of every adjacent pair ``(i, i + 1)`` with
    ``i >= freeze_layers`` is computed with ``compute_layer_similarity``. The
    best-scoring pair is merged into layer ``i`` with ``merge_layers`` and
    layer ``i + 1`` is removed from ``model.model.layers``. The loop ends
    after ``max_merges`` merges, when no candidate pair is left, or when the
    best pair scores below ``similarity_threshold``.

    Args:
        model: Model exposing its decoder stack as ``model.model.layers`` and
            a ``config`` object.
        similarity_threshold: Minimum similarity required for a merge.
        max_merges: Upper bound on the number of merges.
        freeze_layers: Number of leading layers that are never merged.
            Negative values are treated as 0.

    Returns:
        The same ``model`` object, modified in place, with
        ``config.num_hidden_layers`` set to the new layer count.
    """
    print(f"Applying layer collapse with threshold {similarity_threshold}, max merges {max_merges}, freezing first {freeze_layers} layers")
    layers = model.model.layers
    initial_num_layers = len(layers)
    merges_done = 0
    all_similarities = []  # Every score seen, across all rounds
    # A negative start would make range() pair layers[-1] with layers[0].
    first_candidate = max(freeze_layers, 0)

    while merges_done < max_merges:
        # Scores are recomputed each round because the previous merge changed
        # one layer's weights and shifted the indices behind it.
        similarities = []
        for i in range(first_candidate, len(layers) - 1):
            sim = compute_layer_similarity(layers[i], layers[i + 1])
            similarities.append((sim, i))
            all_similarities.append(sim)
        if not similarities:
            break

        # Find the pair with the highest similarity (first one wins on ties)
        best_sim, best_i = max(similarities, key=lambda x: x[0])
        if best_sim < similarity_threshold:
            print(f"No more pairs above threshold {similarity_threshold}. Stopping.")
            break
        print(f"Best similarity: {best_sim:.4f} between layers {best_i} and {best_i+1}")
        print(f"Merging layers {best_i} and {best_i+1}")

        merge_layers(layers[best_i], layers[best_i + 1])

        removed = best_i + 1
        # NOTE(review): some Hugging Face configs carry one entry per layer in
        # `layer_types` (e.g. sliding vs. full attention). When present and
        # aligned with the stack, drop the removed layer's entry so the list
        # keeps describing the layers that remain - confirm for the model
        # family in use.
        layer_types = getattr(model.config, "layer_types", None)
        if isinstance(layer_types, list) and len(layer_types) == len(layers):
            layer_types.pop(removed)
        layers.pop(removed)
        merges_done += 1

    final_num_layers = len(layers)
    print(f"Reduced layers from {initial_num_layers} to {final_num_layers}")

    if merges_done:
        # NOTE(review): Hugging Face decoder/attention modules conventionally
        # store their position in an integer `layer_idx`, which is used to
        # address the per-layer KV cache. After removals the surviving layers
        # would keep their old, now too large, indices; renumber any such
        # attribute to the layer's new position.
        for position, layer in enumerate(layers):
            for module in layer.modules():
                if isinstance(getattr(module, "layer_idx", None), int):
                    module.layer_idx = position

    # Output the range of detected similarity scores
    if all_similarities:
        min_sim = min(all_similarities)
        max_sim = max(all_similarities)
        print(f"Similarity score range: {min_sim:.4f} to {max_sim:.4f}")
    else:
        print("No similarities computed.")

    # Keep the config in sync with the shortened layer stack
    model.config.num_hidden_layers = final_num_layers
    return model
def _load_for_pruning(model_name):
    """Load ``model_name`` with the settings shared by both pruning load paths.

    Returns the ``(model, tokenizer)`` pair produced by
    ``FastLanguageModel.from_pretrained``; 4-bit loading follows
    ``LOAD_QUANTIZED``.
    """
    return FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=LOAD_QUANTIZED,
        attn_implementation="flash_attention_2",
        rope_scaling={"type": "dynamic", "factor": 2.0},
        use_gradient_checkpointing="unsloth",
        device_map="auto",
        float8_kv_cache=True,
    )


def _generate_test_reply(model, tokenizer):
    """Greedily generate up to 50 new tokens for ``TEST_PROMPT`` and decode them."""
    inputs = tokenizer(TEST_PROMPT, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def main():
    """Run the full compression pipeline driven by the module-level settings.

    Steps: optionally merge LoRA adapters into the base model, load the model,
    report a sample generation, perplexity and parameter count, apply layer
    collapse, report the same figures again, save the result to
    ``OUTPUT_DIR`` and optionally attempt a GGUF export.
    """
    print(f"Loading model from {MODEL_PATH}")

    # The merge path is only taken when MODEL_PATH is a local directory that
    # actually contains adapter weights.
    adapter_path = os.path.join(MODEL_PATH, "adapter_model.safetensors")
    if MERGE_LORA and os.path.exists(adapter_path):
        print("Merging LoRA adapters before compression...")
        # Load base model with 4-bit/8-bit loading disabled
        base_model, base_tokenizer = FastLanguageModel.from_pretrained(
            model_name=BASE_MODEL_ID,
            max_seq_length=MAX_SEQ_LENGTH,
            load_in_4bit=False,
            load_in_8bit=False,
            attn_implementation="flash_attention_2",
            rope_scaling={"type": "dynamic", "factor": 2.0},
            use_gradient_checkpointing="unsloth",
            device_map="auto",
        )
        # Apply LoRA
        # NOTE(review): rank, alpha and target modules are assumptions and
        # must match the values the adapters in MODEL_PATH were trained with.
        base_model = FastLanguageModel.get_peft_model(
            base_model,
            r=8,  # Assuming rank 8, adjust if different
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=3407,
            use_rslora=False,
            loftq_config=None,
        )
        # Load the trained LoRA adapters
        base_model.load_adapter(MODEL_PATH, adapter_name="default")
        # Merge
        merged_model = base_model.merge_and_unload()

        # Round-trip the merged weights through a temporary directory so they
        # can be reloaded with the pruning load settings.
        import shutil
        temp_dir = OUTPUT_DIR + "_temp_merged"
        os.makedirs(temp_dir, exist_ok=True)
        try:
            merged_model.save_pretrained(temp_dir, safe_serialization=True)
            base_tokenizer.save_pretrained(temp_dir)
            # Drop the references to the first copy before loading the second
            # one so both do not have to stay alive at the same time.
            del base_model, merged_model
            # Reload the merged model (4-bit only if LOAD_QUANTIZED is set)
            model, tokenizer = _load_for_pruning(temp_dir)
        finally:
            # Clean up temp, even when saving or reloading failed
            shutil.rmtree(temp_dir)
    else:
        # Load the model directly
        model, tokenizer = _load_for_pruning(MODEL_PATH)

    # Quick inference test before pruning
    print("Testing inference before pruning...")
    generated_before = _generate_test_reply(model, tokenizer)
    print(f"Before pruning: {generated_before}")

    # Evaluate perplexity
    perplexity_before = evaluate_model(model, tokenizer, TEST_PROMPT)
    print(f"Perplexity before pruning: {perplexity_before:.4f}")

    # Count initial parameters
    initial_params = sum(p.numel() for p in model.parameters())
    print(f"Initial parameters: {initial_params:,}")

    # Apply layer collapse pruning
    model = apply_layer_collapse(
        model,
        similarity_threshold=PRUNING_SIMILARITY_THRESHOLD,
        max_merges=PRUNING_MAX_MERGES,
        freeze_layers=FREEZE_LAYERS,
    )

    # Count parameters after pruning
    final_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters after pruning: {final_params:,}")

    # Evaluate perplexity after
    perplexity_after = evaluate_model(model, tokenizer, TEST_PROMPT)
    print(f"Perplexity after pruning: {perplexity_after:.4f}")

    # Quick inference test after pruning
    print("Testing inference after pruning...")
    generated_after = _generate_test_reply(model, tokenizer)
    print(f"After pruning: {generated_after}")

    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save the compressed model
    print(f"Saving compressed model to {OUTPUT_DIR}")
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Optionally create GGUF
    if CREATE_GGUF:
        # Check if model has LoRA adapters (PEFT model)
        # NOTE(review): after merge_and_unload() or a direct load the model is
        # presumably no longer a PEFT model, in which case this export is
        # always skipped - confirm whether that is intended.
        has_lora = hasattr(model, 'peft_config') and model.peft_config is not None
        if has_lora:
            print(f"Creating GGUF file with quantization {GGUF_QUANTIZATION}")
            model.save_pretrained_gguf(OUTPUT_DIR + "_gguf", tokenizer, quantization_method=GGUF_QUANTIZATION)
        else:
            print("Model is not a PEFT model (no LoRA adapters). Skipping GGUF creation as it requires PEFT format.")
            print("To create GGUF, ensure the model has LoRA adapters or use external conversion tools.")

    print("Model compression completed successfully!")
if __name__ == "__main__":
main()
@g023
Copy link
Copy Markdown
Author

g023 commented Nov 9, 2025

model_compress

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment