Created September 22, 2022 01:40
import os
from typing import List, Optional

import torch
from torch import distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.data import IterableDataset
from torchrec.datasets.criteo import DEFAULT_CAT_NAMES, DEFAULT_INT_NAMES
from torchrec.datasets.random import RandomRecDataset
from torchrec.distributed import TrainPipelineSparseDist
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.models.dlrm import DLRM, DLRMTrain
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import fuse_embedding_optimizer
from torchrec.optim.keyed import KeyedOptimizerWrapper
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
from tqdm import tqdm
from torchsnapshot import Snapshot
import torchrec.quant as trec_quant
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
import torchsnapshot
from torchsnapshot.io_preparer import TensorIOPreparer
from torchsnapshot.io_preparer import ShardedTensorIOPreparer
from torchsnapshot.manifest import Shard, ShardedTensorEntry
import tempfile
import functools


def start_multi(_per_rank, nprocs: int = 2):
    with tempfile.TemporaryDirectory() as tmpdir:
        torch.multiprocessing.start_processes(
            functools.partial(_per_rank, tmpdir=tmpdir),
            (nprocs,),
            nprocs=nprocs,
            start_method="fork",
        )


def distributed(nprocs: int = 2):
    def wrapper(func):
        def _inner(*args, **kwargs):
            def _setup_ddp(rank: int, world_size: int, tmpdir: str) -> None:
                """
                Set up a DDP worker, run the wrapped function, and tear down the process group.
                """
                init_file = f"file://{os.path.join(tmpdir, 'init_file')}"
                dist.init_process_group(
                    backend="gloo",
                    rank=rank,
                    world_size=world_size,
                    init_method=init_file,
                )
                func(*args, **kwargs)
                dist.destroy_process_group()

            start_multi(_setup_ddp, nprocs=nprocs)

        return _inner

    return wrapper


@distributed(4)
def hello():
    print(f"{dist.get_rank()}, hello")


@distributed(1)  # record
def train(
    # num_embeddings: int = 10,
    num_embeddings: int = 1024,
    embedding_dim: int = 128,
    # embedding_dim: int = 8,
) -> None:
    device = torch.device('cpu')
    table_names = ["feature1", "feature2"]
    eb_configs = [
        EmbeddingBagConfig(
            name=f"t_{feature_name}",
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            feature_names=[feature_name],
            # data_type=DataType.FP16,
        )
        for feature_name in table_names
    ]
    original_ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device('meta'))
    original_ebc = DistributedModelParallel(module=original_ebc, device=device)

    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    features = KeyedJaggedTensor(
        keys=["feature1", "feature2"],
        values=torch.arange(num_embeddings),
        lengths=torch.ones(num_embeddings).int(),
    )
    original_lookups = original_ebc(features).values()
    torch.testing.assert_close(original_ebc(features).values(), original_lookups)
    print('before save lookup', original_ebc(features).values()[0][:10])
    '''
    # the quantized snapshot should be ~3.5 MB vs ~14 MB for the unquantized one
    # to generate the unquantized snapshot, remove the monkeypatching on ShardedTensorIOPreparer
    # print(original_ebc.state_dict())
    weights = torch.cat(
        [
            torch.flatten(_tensor.local_tensor()) for _tensor in original_ebc.state_dict().values()
        ]
    )
    # weights = original_ebc.state_dict()['embedding_bags.t_feature1.weight'].local_tensor()
    weight_min, weight_max = torch.min(weights), torch.max(weights)
    scale = (weight_max - weight_min) / (63 - -64)
    zero_point = -64 - (weight_min / scale)
    print(scale, zero_point)
    '''
    # Function that looks at the module path and decides whether it should be quantized.
    # In this simple example everything has "embedding" in the path, so it all gets quantized.
    def to_quantize(path: str) -> bool:
        """Returns `True` if `path` points at a module that should be quantized."""
        return "embedding" in path

    # This part would be shoved into torchsnapshot as a utility.
    def make_custom_tensor_prepare_func(to_quantize):
        quantized_dtype = torch.qint8

        def custom_tensor_prepare_func(path: str, tensor: torch.Tensor, tracing: bool = False):
            # In tracing mode, return a meta tensor with the correct dtype/size but
            # without allocating memory; otherwise, perform the quantization.
            if to_quantize(path):
                if tracing:
                    return torch.tensor(tensor, dtype=quantized_dtype, device='meta')
                else:
                    observer = torch.quantization.observer.MinMaxObserver(dtype=quantized_dtype)
                    observer(tensor)
                    scale, zero_point = observer.calculate_qparams()
                    print(scale, zero_point)
                    return torch.quantize_per_tensor(tensor, scale, zero_point, quantized_dtype)
            else:
                return tensor

        return custom_tensor_prepare_func
    import shutil

    if os.path.exists('./base'):
        shutil.rmtree('./base')
    if os.path.exists('./quant'):
        shutil.rmtree('./quant')

    base_snapshot = Snapshot.take(path="./base", app_state={"model": original_ebc})
    quant_snapshot = Snapshot.take(
        path="./quant",
        app_state={"model": original_ebc},
        _custom_tensor_prepare_func=make_custom_tensor_prepare_func(to_quantize),
    )
    # Snapshot.take(path="./quant", app_state={"model": original_ebc}, quantize=False)
    ebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device('meta'))
    ebc = DistributedModelParallel(module=ebc, device=device)
    qebc = EmbeddingBagCollection(tables=eb_configs, device=torch.device('meta'))
    qebc = DistributedModelParallel(module=qebc, device=device)

    # the freshly initialized model should not match the saved lookups yet
    try:
        torch.testing.assert_close(ebc(features).values(), original_lookups)
    except AssertionError:
        pass
    base_snapshot.restore(app_state={"model": ebc})
    print(ebc.module.state_dict()['embedding_bags.t_feature1.weight'].dtype)
    quant_snapshot.restore(app_state={"model": qebc})
    print(qebc.module.state_dict()['embedding_bags.t_feature1.weight'].dtype)
    print('old weight', original_ebc.state_dict()['embedding_bags.t_feature1.weight'].local_tensor())
    print('new weight', qebc.state_dict()['embedding_bags.t_feature1.weight'].local_tensor())
    print(qebc(features).values())
    print(original_lookups)

    torch.testing.assert_close(original_ebc(features).values(), original_lookups)
    torch.testing.assert_close(qebc(features).values(), original_lookups, rtol=1e-3, atol=1e-3)

    print("*******************")
    print("*******************")

    from pathlib import Path

    base_size = sum(f.stat().st_size for f in Path('./base').glob('**/*') if f.is_file())
    quant_size = sum(f.stat().st_size for f in Path('./quant').glob('**/*') if f.is_file())
    print(base_size)
    print(quant_size)
    print(quant_size / base_size)


if __name__ == "__main__":
    train()
On line 139, I think you want to move the tensor to CPU before quantizing.
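A minimal sketch of that suggestion, assuming the tensor handed to custom_tensor_prepare_func may live on an accelerator device; the helper name quantize_on_cpu is illustrative and not part of the gist, and the non-tracing branch would call it instead of quantizing in place:

def quantize_on_cpu(tensor: torch.Tensor, quantized_dtype=torch.qint8) -> torch.Tensor:
    # Assumption: copy the tensor to CPU first, since the MinMaxObserver /
    # torch.quantize_per_tensor flow used above is a CPU-side path.
    cpu_tensor = tensor.detach().cpu()
    observer = torch.quantization.observer.MinMaxObserver(dtype=quantized_dtype)
    observer(cpu_tensor)
    scale, zero_point = observer.calculate_qparams()
    return torch.quantize_per_tensor(cpu_tensor, scale, zero_point, quantized_dtype)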