|
# LaCo based model compression script by https://github.com/g023 (g023) - https://x.com/g023dev (X) |
|
|
|
import os |
|
import torch |
|
from unsloth import FastLanguageModel |
|
from typing import Any, Dict, Iterable, List, Sequence, cast |
|
|
|
# Configuration variables for model compression

# HuggingFace model ID or local path to load. If this is a local directory that
# contains adapter_model.safetensors and MERGE_LORA is True, it is treated as a
# LoRA adapter directory and merged into BASE_MODEL_ID first.
MODEL_PATH = "unsloth/gemma-3-1b-it-unsloth-bnb-4bit"

# Base model used only when merging LoRA: it is loaded in full precision, the
# adapters from MODEL_PATH are applied and merged, and the merged model is then
# reloaded (in 4-bit only if LOAD_QUANTIZED is True) before pruning, so the
# compression is applied to the fine-tuned weights.
BASE_MODEL_ID = "unsloth/gemma-3-1b-it"

OUTPUT_DIR = "compressed_model"  # Directory to save the compressed model

MAX_SEQ_LENGTH = 4096  # Maximum sequence length for the model

# Minimum cosine similarity (range -1.0 to 1.0) an adjacent layer pair must
# reach to be merged. Higher values prune more conservatively; 0.0 accepts any
# pair with non-negative similarity, i.e. it is very permissive.
PRUNING_SIMILARITY_THRESHOLD = 0.0

# Maximum number of layer merges to perform; each merge removes one decoder layer.
PRUNING_MAX_MERGES = 3

# Number of leading layers excluded from merging: the first candidate pair is
# (FREEZE_LAYERS, FREEZE_LAYERS + 1), so layers 0..FREEZE_LAYERS-1 are never
# changed or removed.
FREEZE_LAYERS = 10

# Whether to attempt GGUF export after compression (only attempted when the
# final model carries a PEFT config).
CREATE_GGUF = True

GGUF_QUANTIZATION = "q8_0"  # Quantization type for GGUF (e.g., "f16", "q8_0", "q4_k_m")

# Prompt used for the before/after generation test and the perplexity check.
TEST_PROMPT = "Hello, how are you?"

# Whether to merge LoRA adapters before pruning when MODEL_PATH contains
# adapter_model.safetensors.
MERGE_LORA = True

# Whether to load the model in 4-bit quantization for pruning.
# NOTE(review): layer similarity/merging operate on the stored parameter
# tensors; packed 4-bit weights are not meaningful to average or compare.
LOAD_QUANTIZED = False
|
|
|
# Potential improvements: |
|
# 1. Use activation-based similarity instead of parameter-based for better layer matching. |
|
# 2. Implement optimal transport or Fisher-weighted merging for better parameter combination. |
|
# 3. Add post-merging fine-tuning or distillation to recover performance. |
|
# 4. Use clustering algorithms to merge multiple layers at once. |
|
# 5. Incorporate importance scores (e.g., from gradients) to decide merge order. |
|
# 6. Combine with other compression techniques like quantization-aware pruning. |
|
|
|
# LaCo-inspired Layer Collapse Implementation |
|
def compute_layer_similarity(layer1, layer2):
    """Return the cosine similarity between the attention/MLP weights of two layers.

    Every parameter whose owning sub-module path contains ``self_attn`` or
    ``mlp`` is selected exactly once per layer (via ``named_parameters``). If
    either layer has no such parameters, all parameters of both layers are used
    instead. The selected parameters are treated as one long flattened vector
    per layer. The dot product and squared norms are accumulated parameter by
    parameter, so the layers never need to be concatenated into one large
    tensor and may live on different devices.

    NOTE(review): on a bitsandbytes 4-bit model the stored tensors are packed
    integers, so the similarity is not meaningful there.

    Args:
        layer1: First module (e.g. a decoder layer).
        layer2: Second module with the same parameter layout as ``layer1``.

    Returns:
        Cosine similarity as a Python float (0.0 if either vector is all zeros).

    Raises:
        ValueError: if no parameters were found, or the selected parameters of
            the two layers do not pair up one-to-one with equal element counts.
    """
    import math

    def _select(layer):
        # Keep a parameter when the module that owns it lives under self_attn/mlp.
        chosen = []
        for name, param in layer.named_parameters():
            owner_path = name.rpartition(".")[0]
            if "self_attn" in owner_path or "mlp" in owner_path:
                chosen.append(param)
        return chosen

    params1 = _select(layer1)
    params2 = _select(layer2)

    if not params1 or not params2:
        # Fallback to all parameters if sub-modules not found
        params1 = list(layer1.parameters())
        params2 = list(layer2.parameters())

    if not params1 or len(params1) != len(params2):
        raise ValueError(
            f"Cannot compare layers with {len(params1)} and {len(params2)} selected parameters"
        )

    dot = 0.0
    sq_norm1 = 0.0
    sq_norm2 = 0.0
    with torch.no_grad():
        for p1, p2 in zip(params1, params2):
            if p1.numel() != p2.numel():
                raise ValueError(
                    f"Cannot compare parameters with {p1.numel()} and {p2.numel()} elements"
                )
            a = p1.detach().float().flatten()
            # Bring p2 onto p1's device; layers may be spread across devices.
            b = p2.detach().to(a.device).float().flatten()
            dot += torch.dot(a, b).item()
            sq_norm1 += torch.dot(a, a).item()
            sq_norm2 += torch.dot(b, b).item()

    # Same epsilon clamp on each norm as torch.nn.functional.cosine_similarity.
    eps = 1e-8
    return dot / (max(math.sqrt(sq_norm1), eps) * max(math.sqrt(sq_norm2), eps))
|
|
|
def evaluate_model(model, tokenizer, text="The quick brown fox jumps over the lazy dog."):
    """Return a quick perplexity estimate of *model* on a short *text*.

    The text is tokenized, moved to ``model.device``, and passed to the model
    with ``labels`` set to its own ``input_ids``. The returned value is
    ``exp(loss)`` of the loss the model reports. For Hugging Face causal LMs this
    is presumably the mean next-token cross-entropy — confirm for other models.

    The forward pass runs in eval mode, so dropout (if any) is disabled and the
    number is deterministic. The model's previous train/eval mode is restored
    afterwards.

    Args:
        model: Model callable as ``model(**inputs, labels=...)`` returning an
            object with a ``loss`` tensor, and exposing ``.device``.
        tokenizer: Tokenizer callable with ``return_tensors="pt"``.
        text: Text to score.

    Returns:
        Perplexity as a Python float.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    return torch.exp(outputs.loss).item()
|
|
|
def merge_layers(layer1, layer2, alpha=0.5):
    """Merge ``layer2`` into ``layer1`` in place by a weighted average of parameters.

    Every parameter ``p1`` of ``layer1`` is overwritten with
    ``alpha * p1 + (1 - alpha) * p2``, where ``p2`` is the parameter at the same
    position in ``layer2``. ``layer2`` itself is left untouched. The blend is
    computed in float32 on ``p1``'s device and written back in ``p1``'s dtype.
    This allows merging layers placed on different devices and avoids extra
    rounding for low-precision weights.

    All pairs are validated before anything is written, so a failed merge
    leaves ``layer1`` unchanged.

    Args:
        layer1: Module whose parameters receive the merged values.
        layer2: Module with the same parameter layout as ``layer1``.
        alpha: Weight of ``layer1``'s own parameters (0.5 = plain mean).

    Raises:
        ValueError: if the layers have a different number of parameters, or a
            pair of parameters differs in shape.
        TypeError: if a parameter is not floating point (e.g. packed quantized
            storage), where arithmetic averaging would corrupt the weights.
    """
    params1 = list(layer1.parameters())
    params2 = list(layer2.parameters())
    if len(params1) != len(params2):
        raise ValueError(
            f"Cannot merge layers with {len(params1)} and {len(params2)} parameters"
        )
    for p1, p2 in zip(params1, params2):
        if p1.shape != p2.shape:
            raise ValueError(
                f"Cannot merge parameters of shapes {tuple(p1.shape)} and {tuple(p2.shape)}"
            )
        if not (p1.is_floating_point() and p2.is_floating_point()):
            raise TypeError(
                f"Cannot average non-floating-point parameters ({p1.dtype}, {p2.dtype}); "
                "load the model unquantized before merging layers"
            )

    with torch.no_grad():
        for p1, p2 in zip(params1, params2):
            other = p2.to(device=p1.device, dtype=torch.float32)
            p1.copy_(alpha * p1.float() + (1 - alpha) * other)
|
|
|
def _reindex_decoder_layers(layers):
    """Renumber integer ``layer_idx`` attributes in *layers* to each layer's current position.

    NOTE(review): Hugging Face attention modules commonly use ``layer_idx`` to
    address per-layer KV-cache slots. After layers are removed, stale indices no
    longer match ``config.num_hidden_layers``. Confirm for the architecture in use.
    """
    for position, layer in enumerate(layers):
        for module in layer.modules():
            if isinstance(getattr(module, "layer_idx", None), int):
                module.layer_idx = position


def apply_layer_collapse(model, similarity_threshold=0.85, max_merges=5, freeze_layers=0):
    """Greedily collapse adjacent decoder layers of ``model`` in place.

    On each round the adjacent pair ``(i, i + 1)`` with ``i >= freeze_layers``
    that has the highest ``compute_layer_similarity`` score is found. If that
    score is below ``similarity_threshold``, collapsing stops. Otherwise layer
    ``i + 1`` is averaged into layer ``i`` with ``merge_layers`` and then
    removed. At most ``max_merges`` rounds are performed.

    Only the two pairs touching the merged layer change after a merge, so only
    those are rescored; all other pair scores are reused. After merging:
    - integer ``layer_idx`` attributes are renumbered to the new positions;
    - a per-layer ``config.layer_types`` list (if present) drops the removed
      layer's entry;
    - ``config.num_hidden_layers`` is set to the new layer count.

    Args:
        model: Model exposing ``model.model.layers`` (a module list supporting
            ``pop``) and ``model.config``.
        similarity_threshold: Minimum similarity required to merge a pair.
        max_merges: Maximum number of merges (layers removed).
        freeze_layers: Number of leading layers that are never merged or
            removed; negative values are treated as 0.

    Returns:
        The same ``model`` object, modified in place.
    """
    print(f"Applying layer collapse with threshold {similarity_threshold}, max merges {max_merges}, freezing first {freeze_layers} layers")

    layers = model.model.layers
    initial_num_layers = len(layers)
    # A negative count would make layers[-1] pair with layers[0].
    first = max(0, freeze_layers)
    merges_done = 0
    all_similarities = []  # Every similarity score computed, for the range report

    def _score(i):
        sim = compute_layer_similarity(layers[i], layers[i + 1])
        all_similarities.append(sim)
        return sim

    # pair_sims[k] is the similarity of layers[first + k] and layers[first + k + 1];
    # computed lazily so nothing is scored when max_merges <= 0.
    pair_sims = None

    while merges_done < max_merges:
        if pair_sims is None:
            pair_sims = [_score(i) for i in range(first, len(layers) - 1)]

        if not pair_sims:
            break

        # Highest-scoring pair; ties resolve to the earliest pair.
        best_k = max(range(len(pair_sims)), key=pair_sims.__getitem__)
        best_sim = pair_sims[best_k]
        best_i = first + best_k

        if best_sim < similarity_threshold:
            print(f"No more pairs above threshold {similarity_threshold}. Stopping.")
            break

        print(f"Best similarity: {best_sim:.4f} between layers {best_i} and {best_i+1}")
        print(f"Merging layers {best_i} and {best_i+1}")

        # Merge the best pair into layer best_i, then remove layer best_i + 1.
        merge_layers(layers[best_i], layers[best_i + 1])
        layer_types = getattr(model.config, "layer_types", None)
        if isinstance(layer_types, list) and len(layer_types) == len(layers):
            # Keep per-layer metadata aligned with the remaining layers.
            layer_types.pop(best_i + 1)
        layers.pop(best_i + 1)
        merges_done += 1

        # The merged pair disappears; later pairs shift down unchanged; only the
        # pairs on either side of the merged layer need new scores.
        del pair_sims[best_k]
        if best_k > 0:
            pair_sims[best_k - 1] = _score(best_i - 1)
        if best_k < len(pair_sims):
            pair_sims[best_k] = _score(best_i)

    if merges_done:
        _reindex_decoder_layers(layers)

    final_num_layers = len(layers)
    print(f"Reduced layers from {initial_num_layers} to {final_num_layers}")

    # Output the range of detected similarity scores
    if all_similarities:
        min_sim = min(all_similarities)
        max_sim = max(all_similarities)
        print(f"Similarity score range: {min_sim:.4f} to {max_sim:.4f}")
    else:
        print("No similarities computed.")

    # Update model config
    model.config.num_hidden_layers = final_num_layers

    return model
|
|
|
def _load_for_pruning(model_name):
    """Load *model_name* with the settings used for pruning (4-bit only if LOAD_QUANTIZED)."""
    return FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=LOAD_QUANTIZED,
        attn_implementation="flash_attention_2",
        rope_scaling={"type": "dynamic", "factor": 2.0},
        use_gradient_checkpointing="unsloth",
        device_map="auto",
        float8_kv_cache=True,
    )


def _read_lora_hyperparameters(adapter_dir):
    """Return LoRA settings for ``get_peft_model`` taken from ``adapter_dir/adapter_config.json``.

    Settings missing from the file fall back to the values this script used to
    hard-code: r=8, lora_alpha=16, the seven standard projection modules,
    bias="none", no rsLoRA. If the file does not exist at all, these defaults
    are returned unchanged.

    These values matter because ``load_adapter`` into the already-created
    "default" adapter keeps that adapter's configuration. A wrong alpha, r or
    rsLoRA flag would therefore scale the merged delta incorrectly.
    """
    import json

    settings = {
        "r": 8,
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        "lora_alpha": 16,
        "bias": "none",
        "use_rslora": False,
    }
    config_path = os.path.join(adapter_dir, "adapter_config.json")
    if not os.path.exists(config_path):
        print(f"No adapter_config.json in {adapter_dir}; assuming LoRA settings {settings}")
        return settings

    with open(config_path, encoding="utf-8") as f:
        trained = json.load(f)
    for key in ("r", "lora_alpha", "bias", "use_rslora"):
        if trained.get(key) is not None:
            settings[key] = trained[key]
    # Only an explicit module list is forwarded; other forms (e.g. a regex
    # string) keep the default list. NOTE(review): confirm for such adapters.
    modules = trained.get("target_modules")
    if isinstance(modules, (list, tuple)) and modules:
        settings["target_modules"] = list(modules)
    return settings


def _merge_lora_and_reload():
    """Merge the LoRA adapters at MODEL_PATH into BASE_MODEL_ID and reload the merged model.

    The base model is loaded in full precision and the adapters are merged. The
    result is saved to a temporary directory, then reloaded with
    ``_load_for_pruning``. The temporary directory is removed even if saving or
    reloading fails.

    Returns:
        ``(model, tokenizer)`` for the merged checkpoint.
    """
    import gc
    import shutil

    # Load base model in full precision
    base_model, base_tokenizer = FastLanguageModel.from_pretrained(
        model_name=BASE_MODEL_ID,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        attn_implementation="flash_attention_2",
        rope_scaling={"type": "dynamic", "factor": 2.0},
        use_gradient_checkpointing="unsloth",
        device_map="auto",
    )

    # Apply LoRA with the hyperparameters the adapters were trained with.
    lora = _read_lora_hyperparameters(MODEL_PATH)
    base_model = FastLanguageModel.get_peft_model(
        base_model,
        r=lora["r"],
        target_modules=lora["target_modules"],
        lora_alpha=lora["lora_alpha"],
        lora_dropout=0,
        bias=lora["bias"],
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=lora["use_rslora"],
        loftq_config=None,
    )

    # Load the trained LoRA adapters and merge them into the base weights.
    base_model.load_adapter(MODEL_PATH, adapter_name="default")
    merged_model = base_model.merge_and_unload()

    temp_dir = OUTPUT_DIR + "_temp_merged"
    os.makedirs(temp_dir, exist_ok=True)
    try:
        merged_model.save_pretrained(temp_dir, safe_serialization=True)
        base_tokenizer.save_pretrained(temp_dir)

        # Release the full-precision copy before loading the merged checkpoint
        # so both are not resident at the same time.
        del base_model, merged_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return _load_for_pruning(temp_dir)
    finally:
        # Best-effort cleanup; must not mask an exception from the try body.
        shutil.rmtree(temp_dir, ignore_errors=True)


def _generate_sample(model, tokenizer, prompt):
    """Greedily generate up to 50 new tokens after *prompt* and return the decoded text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def main():
    """Load the model (merging LoRA first if configured), collapse layers, report, and save.

    Steps:
    1. Load MODEL_PATH. If it holds LoRA adapters and MERGE_LORA is set, they
       are merged into BASE_MODEL_ID first.
    2. Run a greedy generation test and a perplexity check on TEST_PROMPT.
    3. Apply ``apply_layer_collapse`` with the PRUNING_* / FREEZE_LAYERS settings.
    4. Repeat the checks and report parameter counts.
    5. Save the model and tokenizer to OUTPUT_DIR. If CREATE_GGUF is set,
       attempt a GGUF export.
    """
    print(f"Loading model from {MODEL_PATH}")

    # Check if we need to merge LoRA
    adapter_path = os.path.join(MODEL_PATH, "adapter_model.safetensors")
    if MERGE_LORA and os.path.exists(adapter_path):
        print("Merging LoRA adapters before compression...")
        model, tokenizer = _merge_lora_and_reload()
    else:
        # Load the model directly
        model, tokenizer = _load_for_pruning(MODEL_PATH)

    # Quick inference test before pruning
    print("Testing inference before pruning...")
    generated_before = _generate_sample(model, tokenizer, TEST_PROMPT)
    print(f"Before pruning: {generated_before}")

    # Evaluate perplexity
    perplexity_before = evaluate_model(model, tokenizer, TEST_PROMPT)
    print(f"Perplexity before pruning: {perplexity_before:.4f}")

    # Count initial parameters
    initial_params = sum(p.numel() for p in model.parameters())
    print(f"Initial parameters: {initial_params:,}")

    # Apply layer collapse pruning
    model = apply_layer_collapse(
        model,
        similarity_threshold=PRUNING_SIMILARITY_THRESHOLD,
        max_merges=PRUNING_MAX_MERGES,
        freeze_layers=FREEZE_LAYERS,
    )

    # Count parameters after pruning
    final_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters after pruning: {final_params:,}")

    # Evaluate perplexity after
    perplexity_after = evaluate_model(model, tokenizer, TEST_PROMPT)
    print(f"Perplexity after pruning: {perplexity_after:.4f}")

    # Quick inference test after pruning
    print("Testing inference after pruning...")
    generated_after = _generate_sample(model, tokenizer, TEST_PROMPT)
    print(f"After pruning: {generated_after}")

    # Create output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save the compressed model
    print(f"Saving compressed model to {OUTPUT_DIR}")
    model.save_pretrained(OUTPUT_DIR, safe_serialization=True)
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Optionally create GGUF
    if CREATE_GGUF:
        # Check if model has LoRA adapters (PEFT model)
        has_lora = hasattr(model, 'peft_config') and model.peft_config is not None
        if has_lora:
            print(f"Creating GGUF file with quantization {GGUF_QUANTIZATION}")
            model.save_pretrained_gguf(OUTPUT_DIR + "_gguf", tokenizer, quantization_method=GGUF_QUANTIZATION)
        else:
            print("Model is not a PEFT model (no LoRA adapters). Skipping GGUF creation as it requires PEFT format.")
            print("To create GGUF, ensure the model has LoRA adapters or use external conversion tools.")

    print("Model compression completed successfully!")


if __name__ == "__main__":
    main()