Last active
December 20, 2015 12:49
-
-
Save glciampaglia/6134388 to your computer and use it in GitHub Desktop.
A dictionary class that stores each distinct value only once: keys whose values compare equal share a single reference to the same object. Values must be hashable.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RefDict(dict):
    '''
    A dictionary that stores references to a set of values. It gives an
    advantage in terms of memory when one needs to store several large
    duplicate objects: all keys whose values compare equal are mapped to
    the same (first-inserted) value instance.

    Note
    ----
    Values must be hashable. Instead of lists, use tuples; instead of sets,
    use frozensets.

    All mutating ``dict`` methods (``update``, ``setdefault``, ``pop``,
    ``popitem``, ``clear``, ``|=``) are routed through the reference index,
    so the index always agrees with the dictionary contents.
    '''

    def __init__(self, *args, **kwargs):
        '''
        Accept the same arguments as ``dict``: an optional mapping or
        iterable of ``(key, value)`` pairs, plus keyword items.
        '''
        if len(args) > 1:
            raise TypeError(
                'RefDict expected at most 1 argument, got {}'.format(len(args)))
        super(RefDict, self).__init__()
        # Index: each distinct value -> list of keys k such that self[k] == value.
        # The first key of each list maps to the shared instance of the value.
        self._values = {}
        self.update(*args, **kwargs)

    def _unlink(self, key, value):
        '''Remove ``key`` from the index entry of ``value``; drop the entry
        once no key references that value any more.'''
        keys = self._values[value]
        keys.remove(key)
        if not keys:
            del self._values[value]

    def __setitem__(self, key, value):
        '''
        Map ``key`` to ``value``, reusing an already-stored equal instance.

        Raises
        ------
        TypeError
            If ``value`` is unhashable. Raised before any state is changed,
            so a failed assignment leaves the dictionary untouched.
        '''
        hash(value)  # fail fast: never leave the index half-updated
        if key in self:  # key exists, update value
            oldvalue = super(RefDict, self).__getitem__(key)
            if oldvalue == value:  # nothing to do
                return
            self._unlink(key, oldvalue)
        keys = self._values.get(value)
        if keys is None:
            # First occurrence: this instance becomes the shared one.
            self._values[value] = [key]
            shared = value
        else:
            if not keys:
                raise RuntimeError('Cannot retrieve value: {}'.format(value))
            # Reuse the instance already stored under the first key.
            shared = super(RefDict, self).__getitem__(keys[0])
            keys.append(key)
        super(RefDict, self).__setitem__(key, shared)

    def __delitem__(self, key):
        '''Delete ``key``; its value leaves the index with its last key.'''
        value = super(RefDict, self).__getitem__(key)  # KeyError if missing
        self._unlink(key, value)
        super(RefDict, self).__delitem__(key)  # standard del self[k]

    def update(self, *args, **kwargs):
        '''Like ``dict.update``, but every item goes through __setitem__.'''
        if len(args) > 1:
            raise TypeError(
                'update expected at most 1 argument, got {}'.format(len(args)))
        if args:
            other = args[0]
            if hasattr(other, 'keys'):
                for k in other.keys():
                    self[k] = other[k]
            else:
                for k, v in other:
                    self[k] = v
        for k, v in kwargs.items():
            self[k] = v

    def __ior__(self, other):
        '''``self |= other``: in-place update that keeps the index consistent.'''
        self.update(other)
        return self

    def setdefault(self, key, default=None):
        '''Return ``self[key]``, first storing ``default`` if ``key`` is absent.'''
        if key not in self:
            self[key] = default
        return super(RefDict, self).__getitem__(key)

    def pop(self, key, *default):
        '''
        Remove ``key`` and return its value. If ``key`` is missing, return
        ``default`` when given, otherwise raise KeyError.
        '''
        if len(default) > 1:
            raise TypeError(
                'pop expected at most 2 arguments, got {}'.format(1 + len(default)))
        if key not in self:
            if default:
                return default[0]
            raise KeyError(key)
        value = super(RefDict, self).__getitem__(key)
        del self[key]
        return value

    def popitem(self):
        '''Remove and return a ``(key, value)`` pair; KeyError when empty.'''
        key, value = super(RefDict, self).popitem()
        self._unlink(key, value)
        return key, value

    def clear(self):
        '''Remove all items and reset the reference index.'''
        super(RefDict, self).clear()
        self._values.clear()

    def __reduce__(self):
        # The default dict-subclass reduction replays items through
        # __setitem__ either after ``_values`` has been shared with the
        # source (copy.copy) or before ``_values`` exists (pickle). Rebuilding
        # from plain items gives the copy its own consistent index.
        return (self.__class__, (dict(self),))

    def __repr__(self):
        base_rep = super(RefDict, self).__repr__()
        return '<RefDict {} at 0x{:x} ({} keys, {} unique values)>'.format(
            base_rep, id(self), len(self), len(self._values))
from nose.tools import raises | |
def test_refdict_base():
    """A rebound key is re-indexed under its new value only."""
    d = RefDict()
    d[1] = 1
    assert d[1] == 1
    # Rebinding key 1 must drop the now-unreferenced value 1 from the index.
    d[1] = 2
    assert d[1] == 2
    index = d._values
    assert len(index) == 1
    assert len(index[2]) == 1
    assert 1 not in index
def test_refdict_dups():
    """Duplicate values are indexed once, with all their keys listed."""
    rd = RefDict()
    # range (not the Python-2-only xrange) keeps this runnable on Python 3.
    for i in range(10):
        rd[i] = i % 2
    assert len(rd) == 10
    assert len(rd._values) == 2
    # Keys are recorded under their shared value in insertion order.
    assert rd._values[0] == [0, 2, 4, 6, 8]
    assert rd._values[1] == [1, 3, 5, 7, 9]
def test_refdict_unhashable():
    """Assigning an unhashable value raises TypeError and stores nothing."""
    rd = RefDict()
    # Plain try/except instead of nose's @raises: no dependency on the
    # unmaintained nose runner for this check.
    try:
        rd[1] = list()
    except TypeError:
        pass
    else:
        raise AssertionError('assigning an unhashable value must raise TypeError')
    assert len(rd) == 0
    assert len(rd._values) == 0
def test_refdict_deletion():
    """Deleting the only key that references a value also drops the value."""
    d = RefDict()
    d[1] = 1
    assert len(d) == 1
    del d[1]
    assert not d
    assert not d._values
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment