Created
February 21, 2023 06:21
-
-
Save briansemrau/c68835edc88a0dea79b092bce4e1ee17 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from merge.py shared on the KoboldAI discord server.
# The original author is believed to be concedo.
import gc
import json
import os
import resource
import shutil
from itertools import zip_longest

import torch
# Task-vector merge: output = A + (B - C) * diff_weight
#   A = model_1 (merge base)
#   B = model_2 (compare model)
#   C = model_0 (base model)
diff_weight = 0.6  # scale applied to the (B - C) delta; originally 1.0
assert 0.0 < diff_weight <= 1.0  # disable if you are brave

model_0_folder = 'gpt-j-6B-shardfp16'  # C: base model
model_1_folder = 'gpt-jt-6B-v1-shardfp16'  # A
model_2_folder = 'ppo_hh_gpt-j-shardfp16'  # B: compare model
merged_model_folder = 'gpt-r-diff_0.6-6B'

# Load every tensor onto the CPU while merging.
torch_map_location = 'cpu'

# Refuse to clobber the results of a previous merge.
if os.path.exists(merged_model_folder) and len(os.listdir(merged_model_folder)) != 0:
    raise Exception(f'Non empty directory "{merged_model_folder}" already exists')
def format_size(num, suffix="B"):
    """Render a byte count as a human-readable IEC string (e.g. '1.5KiB').

    Divides by 1024 until the value fits under the next binary prefix;
    values past Zi fall through to the Yi prefix.
    """
    value = num
    for prefix in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(value) < 1024.0:
            return f"{value:3.1f}{prefix}{suffix}"
        value /= 1024.0
    # Anything still >= 1024 Zi is reported in yobibytes.
    return f"{value:.1f}Yi{suffix}"
# Enumerate each model's checkpoint shards (*.bin), sorted so the three
# lists align shard-by-shard.
# BUG FIX: model_0_files previously listed model_1_folder (copy-paste error),
# so the base model's shard names came from the wrong directory.
model_0_files = sorted(file for file in os.listdir(model_0_folder) if file.endswith('.bin'))
model_1_files = sorted(file for file in os.listdir(model_1_folder) if file.endswith('.bin'))
model_2_files = sorted(file for file in os.listdir(model_2_folder) if file.endswith('.bin'))

# The widest shard count determines how many output shards are written.
max_files_length = max(len(model_0_files), len(model_1_files), len(model_2_files))

# Align the three shard lists positionally; shorter lists pad with None.
model_files = list(zip_longest(model_0_files, model_1_files, model_2_files))

# Total on-disk size of the base model's shards, reported in the output
# index metadata ("total_size").
model_size_bytes = sum(os.path.getsize(f'{model_0_folder}/{file}') for file in model_0_files)

# Maps each merged tensor name -> output shard filename (for the index json).
bin_weight_map = {}
# Layers seen in one model's shard but not yet matched in the other's;
# carried over and retried on the next shard.
backlog_layers = {"model_1": {}, "model_2": {}, "model_02": {}, "model_01": {}}
folder_created = False
# Merge one aligned triple of shard files per iteration. Because shard
# boundaries may differ between the three models, layers that appear in one
# model's shard but not (yet) in its counterpart's are parked in
# backlog_layers and retried on a later shard.
for model_file_idx, model_file in enumerate(model_files):
    print(f'-- {model_file_idx + 1} / {len(model_files)} --')
    diff_model = {}    # per-layer (B - C) * diff_weight for this shard
    merged_model = {}  # per-layer A + diff, written out as one shard
    model_file_idx += 1  # switch to 1-based numbering for output shard names
    # Any of the three may be None when one model has fewer shards (zip_longest).
    model_0_file, model_1_file, model_2_file = model_file
    model_0_layers = {}
    model_1_layers = {}
    model_2_layers = {}
    if model_0_file is not None:
        print('[*] Reading', f"{model_0_folder}/{model_0_file}")
        model_0_layers = torch.load(f"{model_0_folder}/{model_0_file}", map_location=torch_map_location, weights_only=True)
    if model_2_file is not None:
        print('[*] Reading', f"{model_2_folder}/{model_2_file}")
        model_2_layers = torch.load(f"{model_2_folder}/{model_2_file}", map_location=torch_map_location, weights_only=True)
    # Fold in layers deferred from earlier shards for the model_0/model_2 diff.
    model_0_layers.update(backlog_layers['model_02'])
    model_2_layers.update(backlog_layers['model_2'])
    # NOTE(review): 'model_01' and 'model_1' are cleared here but are only
    # consumed further down in this same iteration, so those two backlogs
    # never actually carry layers across shards — confirm this is intended
    # (harmless when all three models are sharded identically).
    backlog_layers['model_02'] = {}
    backlog_layers['model_01'] = {}
    backlog_layers['model_1'] = {}
    backlog_layers['model_2'] = {}
    # Diff: park layers present in only one of (model_0, model_2) ...
    for backlog_layer in set(model_0_layers).symmetric_difference(set(model_2_layers)):
        if backlog_layer in model_0_layers:
            backlog_layers['model_02'][backlog_layer] = model_0_layers[backlog_layer]
        if backlog_layer in model_2_layers:
            backlog_layers['model_2'][backlog_layer] = model_2_layers[backlog_layer]
    # ... and compute the weighted delta for layers present in both.
    for common_layer in set(model_0_layers).intersection(set(model_2_layers)):
        w_model_0 = model_0_layers[common_layer]
        w_model_2 = model_2_layers[common_layer]
        # Skip the multiply entirely when diff_weight is exactly 1.0.
        diff_model[common_layer] = (w_model_2 - w_model_0) * diff_weight if diff_weight != 1.0 else w_model_2 - w_model_0
    # Free the compare model's tensors before loading the merge base.
    del model_2_layers
    gc.collect()
    # Merge
    if model_1_file is not None:
        print('[*] Reading', f"{model_1_folder}/{model_1_file}")
        model_1_layers = torch.load(f"{model_1_folder}/{model_1_file}", map_location=torch_map_location, weights_only=True)
    model_0_layers.update(backlog_layers['model_01'])
    model_1_layers.update(backlog_layers['model_1'])
    # Park layers present in only one of (model_1, model_0) for later.
    for backlog_layer in set(model_1_layers).symmetric_difference(set(model_0_layers)):
        if backlog_layer in model_1_layers:
            backlog_layers['model_1'][backlog_layer] = model_1_layers[backlog_layer]
        if backlog_layer in model_0_layers:
            backlog_layers['model_01'][backlog_layer] = model_0_layers[backlog_layer]
    del model_0_layers
    gc.collect()
    # output = A + diff, for every layer with a computed delta in this shard.
    for common_layer in set(model_1_layers).intersection(set(diff_model)):
        w_model_1 = model_1_layers[common_layer]
        w_model_diff = diff_model[common_layer]
        merged_model[common_layer] = w_model_1 + w_model_diff
        # Record which output shard will hold this tensor (for the index json).
        bin_weight_map[common_layer] = f'pytorch_model-{(model_file_idx):05}-of-{max_files_length:05}.bin'
    # On the first shard, create the output dir and copy the base model's
    # config/tokenizer files (json/txt, excluding any existing weight index).
    if not folder_created:
        os.makedirs(merged_model_folder, exist_ok=True)
        for file_to_copy in [file for file in os.listdir(model_0_folder) if (file.endswith('.json') or file.endswith('.txt')) and not file.endswith('.index.json')]:
            shutil.copy(f'{model_0_folder}/{file_to_copy}', merged_model_folder)
        folder_created = True
    # Single-shard models get the plain filename; otherwise use the
    # conventional sharded naming scheme.
    if len(model_files) == 1:
        print(f'[*] Saving model: {merged_model_folder}/pytorch_model.bin')
        torch.save(merged_model, f'{merged_model_folder}/pytorch_model.bin')
    else:
        print(f'[*] Saving shard: {merged_model_folder}/pytorch_model-{(model_file_idx):05}-of-{max_files_length:05}.bin')
        torch.save(merged_model, f'{merged_model_folder}/pytorch_model-{(model_file_idx):05}-of-{max_files_length:05}.bin')
    # ru_maxrss is in kilobytes on Linux; scale to bytes for format_size.
    print('[*] Memory used:', format_size(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024))
    del model_1_layers
    gc.collect()
# A sharded output needs an index file mapping every tensor to its shard.
if len(model_files) > 1:
    index_path = f'{merged_model_folder}/pytorch_model.bin.index.json'
    print(f'[*] Saving bin weight map:', index_path)
    index_payload = {"metadata": {"total_size": model_size_bytes}, "weight_map": bin_weight_map}
    with open(index_path, 'w+') as index_file:
        index_file.write(json.dumps(index_payload, sort_keys=True, indent=4))

# Any layer still sitting in a backlog never found its counterpart shard.
if any(backlog_layers[key] for key in ('model_1', 'model_2', 'model_02', 'model_01')):
    print('[WARN] Not all layers were merged, model might be in a broken state')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment