
'''LRU cache implementation by Imri Goldberg, originally published at:

http://www.algorithm.co.il/blogs/index.php/programming/python/lru-cache-solution-a-case-for-linked-lists-in-python/

Bug fixes and test code by Aaron Digulla.
'''

from weakref import proxy
import unittest
import threading

class Link(object):
    """A node of a doubly linked list.

    ``data`` is the payload; ``prev`` and ``next`` are the neighbouring
    ``Link`` objects, or ``None`` at the respective end of the list.
    """

    def __init__(self, data, prev, next):
        self.data, self.prev, self.next = data, prev, next

    def __repr__(self):
        # Default object repr (type and id) followed by the payload's repr.
        return super(Link, self).__repr__() + ' ' + repr(self.data)

    def __str__(self):
        return repr(self.data)


class LinkList(object):
    """A doubly linked list of ``Link`` nodes.

    ``start`` is the first node and ``end`` the last; both are ``None`` when
    the list is empty. Insertion and removal at either end, or next to a node
    the caller already holds, are O(1); positional access is O(n).
    """

    def __init__(self, start_list=None):
        """Create a list, optionally filled from the iterable *start_list*."""
        self.start = None
        self.end = None
        if start_list is not None:
            self.add_list(start_list)

    def __iter__(self):
        """Yield the ``Link`` nodes (not their payloads) from start to end."""
        link = self.start
        while link is not None:
            yield link
            link = link.next

    def add_list(self, some_list):
        """Append every element of the iterable *some_list*, in order."""
        for element in some_list:
            self.append(element)

    def __getitem__(self, n):
        return self.get_link(n)

    def get_link(self, n):
        """Return the *n*-th node (0-based) by walking from ``start``.

        Raises KeyError when there is no such node, including for a negative
        *n*. KeyError is kept for backward compatibility, although IndexError
        would be more conventional.
        """
        for i, link in enumerate(self):
            if i == n:
                return link
        raise KeyError("no such element in list")

    def append(self, data):
        """Add *data* as a new last node."""
        new_link = Link(data, self.end, None)
        if self.end is not None:
            self.end.next = new_link
        self.end = new_link
        if self.start is None:
            # The list was empty: the new node is also the first one.
            self.start = new_link

    add_last = append

    def add_first(self, data):
        """Add *data* as a new first node."""
        new_link = Link(data, None, self.start)
        if self.start is not None:
            self.start.prev = new_link
        self.start = new_link
        if self.end is None:
            # The list was empty: the new node is also the last one.
            self.end = new_link

    def insert_after(self, new_data, other_link):
        """Insert *new_data* as a new node directly after *other_link*."""
        assert other_link is not None
        old_next = other_link.next
        new_link = Link(new_data, other_link, old_next)
        other_link.next = new_link
        if old_next is not None:
            old_next.prev = new_link
        else:
            self.end = new_link

    def insert_before(self, new_data, other_link):
        """Insert *new_data* as a new node directly before *other_link*."""
        assert other_link is not None
        old_prev = other_link.prev
        new_link = Link(new_data, old_prev, other_link)
        other_link.prev = new_link
        if old_prev is not None:
            # Previously referenced an undefined name (old_next) here.
            old_prev.next = new_link
        else:
            self.start = new_link

    def remove_last(self):
        """Drop the last node; do nothing if the list is empty."""
        if self.end is None:
            return
        before_last = self.end.prev
        if before_last is not None:
            before_last.next = None
        else:
            self.start = None
        self.end = before_last

    def remove_first(self):
        """Drop the first node; do nothing if the list is empty."""
        if self.start is None:
            return
        after_first = self.start.next
        if after_first is not None:
            # Previously referenced an undefined name (before_last) here.
            after_first.prev = None
        else:
            self.end = None
        self.start = after_first

    def remove(self, link):
        """Unlink the node *link*, which must belong to this list."""
        assert link is not None
        old_prev = link.prev
        old_next = link.next
        if old_next is not None:
            old_next.prev = old_prev
        else:
            self.end = old_prev
        if old_prev is not None:
            old_prev.next = old_next
        else:
            self.start = old_next

    def __len__(self):
        """Return the number of nodes; O(n) since no count is stored."""
        return sum(1 for _ in self)

class _NotInDict(object):
    pass

_NotInDict = _NotInDict()

class _CacheData(object):
    def __init__(self, data, link):
        self.data = data
        self.link = link

    def __str__(self):
        return '_CacheData(%s)' % repr(self.link)

class LRUCache(object):
    """A bounded mapping that evicts the least recently used key.

    Recency is tracked in ``age_list``, a ``LinkList`` of keys ordered most
    recently used first. ``dict`` maps each key to a ``_CacheData`` holding
    the value and that key's node in ``age_list``, so a key can be moved to
    the front without searching the list.

    NOTE(review): ``self.lock`` is created but never acquired anywhere in
    this class, so the cache is not thread-safe as written.
    """

    def __init__(self, max_size, check_invariants=True):
        """Create an empty cache holding at most *max_size* entries.

        :param max_size: maximum number of entries; must be >= 0.
        :param check_invariants: when true (the default, as before) every
            get/set/eviction calls ``verify()``, which costs O(len(self)).
            Pass False to skip those consistency checks.
        :raises ValueError: if *max_size* is negative. Such a size previously
            made the eviction loop in ``_fit_to_size`` spin forever.
        """
        if max_size < 0:
            raise ValueError('max_size must be >= 0, got %r' % (max_size,))
        self.max_size = max_size
        self.check_invariants = check_invariants
        self.dict = {}
        self.age_list = LinkList()
        self.lock = threading.Lock()

    def __len__(self):
        return len(self.dict)

    def __contains__(self, key):
        # Membership test only; does not change the key's recency.
        return key in self.dict

    def _check(self):
        """Run ``verify()`` if invariant checking is enabled."""
        if self.check_invariants:
            self.verify()

    def get(self, key, default=None):
        """Return the value for *key* and mark it most recently used.

        Returns *default*, without touching recency, when *key* is absent.
        """
        self._check()
        entry = self.dict.get(key, _NotInDict)
        if entry is _NotInDict:
            return default

        # Move the key's node to the front of the age list.
        self.age_list.remove(entry.link)
        self.age_list.add_first(key)
        entry.link = self.age_list.start
        self._check()

        return entry.data

    def verify(self):
        """Check that ``age_list`` and ``dict`` describe exactly the same keys.

        Each key must appear once in ``age_list``, be present in ``dict``, and
        ``dict`` must point at that very node, and vice versa. On any mismatch
        the problems and the contents of both structures are printed and an
        Exception is raised. Costs O(len(self)).
        """
        list_keys = set()
        dump = False
        for link in self.age_list:
            data = link.data
            if data in list_keys:
                print('!!! Duplicate item in age_list: %s' % (data,))
                dump = True
            list_keys.add(data)
            cd = self.dict.get(data, None)
            if cd is None:
                print('!!! Link not found in dict: %r' % (data,))
                dump = True
            elif cd.link is not link:
                print('!!! Link/data mismatch in list: %r %r'
                      % (link, cd.link))
                print('%s %s' % (link, cd.link))
                dump = True
        for key, cd in self.dict.items():
            if key != cd.link.data:
                print('!!! Link/data mismatch in dict: %s %s'
                      % (key, cd.link.data))
                dump = True

        dict_keys = set(self.dict)
        only_in_list = list_keys - dict_keys
        if only_in_list:
            print('!!! Elements only in age_list: %s' % (only_in_list,))
            dump = True
        only_in_dict = dict_keys - list_keys
        if only_in_dict:
            print('!!! Elements only in dict: %s' % (only_in_dict,))
            dump = True

        if dump:
            print('Dict:')
            # Sort by repr so that mixed key types cannot make the dump fail.
            for key in sorted(self.dict, key=repr):
                print(repr(key))
            print('List:')
            # Keep age order (most recent first); sorting Link objects is
            # meaningless (and a TypeError on Python 3).
            for link in self.age_list:
                print(repr(link))
            raise Exception('Verify failed!')

    def __getitem__(self, key):
        """Return the value for *key*, marking it most recently used.

        :raises KeyError: if *key* is not cached.
        """
        result = self.get(key, _NotInDict)
        if result is _NotInDict:
            raise KeyError(key)
        return result

    def remove_oldest(self):
        """Evict the least recently used entry; do nothing if empty."""
        oldest_link = self.age_list.end
        if oldest_link is None:
            return
        self._check()
        del self.dict[oldest_link.data]
        self.age_list.remove(oldest_link)
        self._check()

    def _fit_to_size(self):
        """Evict least recently used entries until within ``max_size``."""
        while len(self) > self.max_size:
            self.remove_oldest()

    def __setitem__(self, key, value):
        """Store *value* under *key* as the most recently used entry.

        Afterwards, least recently used entries are evicted while the cache
        holds more than ``max_size`` entries.
        """
        self._check()
        # Put the key at the front first; an existing entry's old node is
        # removed below, so the key appears exactly once in the age list.
        self.age_list.add_first(key)
        link = self.age_list.start

        entry = self.dict.get(key, _NotInDict)
        if entry is _NotInDict:
            entry = _CacheData(value, link)
        else:
            self.age_list.remove(entry.link)
            entry.data = value
            entry.link = link

        self.dict[key] = entry
        self._check()

        self._fit_to_size()

    def keys(self):
        """Return the cached keys as a new list (in no particular order)."""
        return list(self.dict)

def cached(max_size, func = None):
    """Memoize a function's results in an ``LRUCache`` of *max_size* entries.

    Usable as a decorator factory (``@cached(10)``) or directly
    (``cached(10, f)``). Only positional arguments are supported, and they
    must be hashable because the argument tuple is the cache key. The
    underlying cache is exposed as the wrapper's ``cache`` attribute.
    """
    import functools

    def decorator(func):
        cache = LRUCache(max_size)

        def wrapper_func(*args):
            prev_result = cache.get(args, _NotInDict)
            if prev_result is not _NotInDict:
                return prev_result
            result = func(*args)
            cache[args] = result
            return result

        try:
            # Preserve __name__/__doc__/__module__ of the wrapped function.
            wrapper_func = functools.wraps(func)(wrapper_func)
        except AttributeError:
            # Python 2's functools.wraps requires every attribute it copies;
            # some callables (e.g. functools.partial objects) lack __name__.
            pass
        wrapper_func.cache = cache
        return wrapper_func

    if func is None:
        return decorator
    return decorator(func)

class TestLRUCache(unittest.TestCase):
    """Behavioural tests for LRUCache eviction order and the cached decorator."""

    def testPut(self):
        """Inserting past capacity evicts the least recently inserted key."""
        lru = LRUCache(4)
        lru[1] = 'a'
        lru[2] = 'b'
        lru[3] = 'c'
        lru[4] = 'd'
        lru[5] = 'e'

        self.assertEqual(4, len(lru))
        self.assertEqual([2, 3, 4, 5], sorted(lru.keys()))

    def testGet(self):
        """Reading a key refreshes it, so the next-oldest key is evicted."""
        lru = LRUCache(4)
        lru[1] = 'a'
        lru[2] = 'b'
        lru[3] = 'c'
        lru[4] = 'd'
        lru[1]
        lru[5] = 'e'

        self.assertEqual(4, len(lru))
        self.assertEqual([1, 3, 4, 5], sorted(lru.keys()))

    def testManyPut(self):
        """Re-setting an existing key does not grow the cache."""
        lru = LRUCache(4)
        for i in range(1000):
            lru[i] = i

        lru[1] = 1
        lru[1] = 1
        lru[1] = 1
        lru[1] = 1

        self.assertEqual(4, len(lru))
        self.assertEqual([1, 997, 998, 999], sorted(lru.keys()))

    def testGetMissing(self):
        """Missing keys yield the default from get() and KeyError from []."""
        lru = LRUCache(2)
        self.assertEqual(None, lru.get('x'))
        self.assertEqual(0, lru.get('x', 0))
        self.assertRaises(KeyError, lambda: lru['x'])
        self.assertFalse('x' in lru)

    def testCached(self):
        """The decorator calls the function only once per argument tuple."""
        calls = []

        @cached(2)
        def square(x):
            calls.append(x)
            return x * x

        self.assertEqual(4, square(2))
        self.assertEqual(4, square(2))
        self.assertEqual(9, square(3))
        self.assertEqual([2, 3], calls)

if __name__ == '__main__':
    # Script entry point: discover and run the TestCase classes defined in
    # this module (module='__main__' is unittest.main's default).
    unittest.main(module='__main__')
