From 7b2cf5c6a8e121f2009b416efa6ef1807f4163ae Mon Sep 17 00:00:00 2001 From: Vadim Pushtaev Date: Thu, 11 Oct 2018 01:32:28 +0300 Subject: [PATCH] [issue #123] decorated property with new name support --- cached_property.py | 20 +++++++++++++----- tests/test_cached_property.py | 39 +++++++++++++++++++++++++---------- 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/cached_property.py b/cached_property.py index 125f619..55fa1ce 100644 --- a/cached_property.py +++ b/cached_property.py @@ -24,6 +24,10 @@ class cached_property(object): def __init__(self, func): self.__doc__ = getattr(func, "__doc__") self.func = func + self.name = func.__name__ + + def __set_name__(self, owner, name): + self.name = name def __get__(self, obj, cls): if obj is None: @@ -32,7 +36,7 @@ def __get__(self, obj, cls): if asyncio and asyncio.iscoroutinefunction(self.func): return self._wrap_in_coroutine(obj) - value = obj.__dict__[self.func.__name__] = self.func(obj) + value = obj.__dict__[self.name] = self.func(obj) return value def _wrap_in_coroutine(self, obj): @@ -40,7 +44,7 @@ def _wrap_in_coroutine(self, obj): @asyncio.coroutine def wrapper(): future = asyncio.ensure_future(self.func(obj)) - obj.__dict__[self.func.__name__] = future + obj.__dict__[self.name] = future return future return wrapper() @@ -56,21 +60,24 @@ def __init__(self, func): self.__doc__ = getattr(func, "__doc__") self.func = func self.lock = threading.RLock() + self.name = func.__name__ + + def __set_name__(self, owner, name): + self.name = name def __get__(self, obj, cls): if obj is None: return self obj_dict = obj.__dict__ - name = self.func.__name__ with self.lock: try: # check if the value was computed before the lock was acquired - return obj_dict[name] + return obj_dict[self.name] except KeyError: # if not, do the calculation and release the lock - return obj_dict.setdefault(name, self.func(obj)) + return obj_dict.setdefault(self.name, self.func(obj)) class cached_property_with_ttl(object): @@ -89,6 +96,9 @@ def __init__(self, ttl=None): self.ttl = ttl self._prepare_func(func) + def __set_name__(self, owner, name): + self.__name__ = name + def __call__(self, func): self._prepare_func(func) return self diff --git a/tests/test_cached_property.py b/tests/test_cached_property.py index 5d8ea92..3ef2daa 100644 --- a/tests/test_cached_property.py +++ b/tests/test_cached_property.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +import sys import time import unittest from threading import Lock, Thread @@ -8,12 +9,23 @@ import cached_property -def CheckFactory(cached_property_decorator, threadsafe=False): +def CheckFactory(cached_property_decorator, threadsafe=False, change_name=False): """ Create dynamically a Check class whose add_cached method is decorated by the cached_property_decorator. """ + def add_cached_func(self): + if threadsafe: + time.sleep(1) + # Need to guard this since += isn't atomic. + with self.lock: + self.cached_total += 1 + else: + self.cached_total += 1 + + return self.cached_total + class Check(object): def __init__(self): @@ -26,17 +38,15 @@ def add_control(self): self.control_total += 1 return self.control_total - @cached_property_decorator - def add_cached(self): - if threadsafe: - time.sleep(1) - # Need to guard this since += isn't atomic. - with self.lock: - self.cached_total += 1 - else: - self.cached_total += 1 + if change_name: + def add_cached_orig(self): + return add_cached_func(self) - return self.cached_total + add_cached = cached_property_decorator(add_cached_orig) + else: + @cached_property_decorator + def add_cached(self): + return add_cached_func(self) def run_threads(self, num_threads): threads = [] @@ -124,6 +134,13 @@ def test_set_cached_property(self): self.assertEqual(check.add_cached, "foo") self.assertEqual(check.cached_total, 0) + @unittest.skipUnless(sys.version_info >= (3, 6), 'No __set_name__ support until Python 3.6') + def test_cached_property_change_name(self): + Check = CheckFactory(self.cached_property_factory, change_name=True) + check = Check() + self.assertEqual(check.add_cached, 1) + self.assertEqual(check.add_cached_orig(), 2) + def test_threads(self): Check = CheckFactory(self.cached_property_factory, threadsafe=True) check = Check()