Python使用技巧--python装饰器的使用

一、函数装饰器的运用

示例一:编写计时装饰器

1.简易版装饰器:该装饰器支持装饰普通方法,也支持类中的方法,但是不支持传入装饰器参数;

def timer(func):
    """
    用于对函数计时,不支持传入装饰器参数;
    该装饰器支持装饰普通方法,也支持类中的方法
    """
    def onCall(*args, **kargs):
        start = time.time()
        result = func(*args, **kargs)
        elapsed = time.time() - start
        onCall.alltime += elapsed
        format = '%s run time: %.5f s;total run time:%.5f s'
        values = (func.__name__, elapsed, onCall.alltime)
        print(format % values)
        return result
    onCall.alltime = 0
    return onCall

运行测试:

# 测试普通函数
@timer
def test():
    time.sleep(1)
    print("run test")

test()
test()

# 输出
run test
test run time: 1.00509 s;total run time:1.00509 s
run test
test run time: 1.00507 s;total run time:2.01015 s

#测试类中的方法
class Test():
    @timer
    def test_method(self):
        time.sleep(1)
        print("run test_method")

a = Test()
a.test_method()
a.test_method()
#输出
run test_method
test_method run time: 1.00507 s;total run time:1.00507 s
run test_method
test_method run time: 1.00509 s;total run time:2.01016 s

示例一 加强版

该装饰器支持装饰普通方法,支持类中的方法,同时也支持传入装饰器参数

def timer(label='', trace=True):
    """
    用于对函数计时,支持传入装饰器参数;
    该装饰器支持装饰普通方法,也支持类中的方法
    """
    def onDecorator(func):
        def onCall(*args, **kargs):
            start = time.time()
            result = func(*args, **kargs)
            elapsed = time.time() - start
            onCall.alltime += elapsed
            if trace:
                format = '%s%s run time: %.5f s;total run time:%.5f s'
                values = (label, func.__name__, elapsed, onCall.alltime)
                print(format % values)
            return result

        onCall.alltime = 0
        return onCall

    return onDecorator

运行测试:

# 测试普通函数
@timer(label="Fun test===>")
def test():
    time.sleep(1)
    print("run test")

test()
test()

#输出:
run test
Fun test===>test run time: 1.00525 s;total run time:1.00525 s
run test
Fun test===>test run time: 1.00097 s;total run time:2.00623 s


# 测试类中的方法
class Test():
    @timer(label="Class method test===>")
    def test_method(self):
        time.sleep(1)
        print("run test_method")

a = Test()
a.test_method()
a.test_method()

#输出:
run test_method
Class method test===>test_method run time: 1.00361 s;total run time:1.00361 s
run test_method
Class method test===>test_method run time: 1.00507 s;total run time:2.00868 s

示例二:参数验证装饰器

1.简易版:只有范围验证的功能

# -*-coding:utf-8-*-

def rangetest(trace=True, **argchecks):  # Validate ranges for both+defaults
    def onDecorator(func):  # onCall remembers func and argchecks
        if not __debug__:  # True if "python -O main.py args..."
            return func  # Wrap if debugging; else use original
        else:
            import sys
            code = func.__code__
            allargs = code.co_varnames[:code.co_argcount]
            funcname = func.__name__

            def onCall(*pargs, **kargs):
                # All pargs match first N expected args by position
                # The rest must be in kargs or be omitted defaults
                positionals = list(allargs)
                positionals = positionals[:len(pargs)]

                for (argname, (low, high)) in argchecks.items():
                    # For all args to be checked
                    if argname in kargs:
                        # Was passed by name
                        if kargs[argname] < low or kargs[argname] > high:
                            errmsg = '{0} argument "{1}" not in {2}..{3}'
                            errmsg = errmsg.format(funcname, argname, low, high)
                            raise TypeError(errmsg)

                    elif argname in positionals:
                        # Was passed by position
                        position = positionals.index(argname)
                        if pargs[position] < low or pargs[position] > high:
                            errmsg = '{0} argument "{1}" not in {2}..{3}'
                            errmsg = errmsg.format(funcname, argname, low, high)
                            raise TypeError(errmsg)
                    else:
                        # Assume not passed: default
                        if trace:
                            print('Argument "{0}" defaulted'.format(argname))

                return func(*pargs, **kargs)  # OK: run original call

            return onCall
    return onDecorator

运行测试:

测试普通函数

@rangetest(age=(0, 120))  # persinfo = rangetest(..)(persinfo)
def persinfo(name, age):
    print('%s is %s years old' % (name, age))


@rangetest(M=(1, 12), D=(1, 31), Y=(0, 2009))
def birthday(M, D, Y):
    print('birthday = {0}/{1}/{2}'.format(M, D, Y))


persinfo('Bob', 40)
birthday(5, D=1, Y=1963)
# 输出
Bob is 40 years old
birthday = 5/1/1963


persinfo('Bob', 150)
# 输出
Traceback (most recent call last):
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 64, in <module>
    persinfo('Bob', 150)
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 37, in onCall
    raise TypeError(errmsg)
TypeError: persinfo argument "age" not in 0..120

测试类中的方法

# 测试类中的方法
class Person:
    def __init__(self, name, job, pay):
        self.job = job
        self.pay = pay
        # giveRaise = rangetest(..)(giveRaise)

    @rangetest(percent=(0.0, 1.0))  # percent passed by name or position
    def giveRaise(self, percent):
        self.pay = int(self.pay * (1 + percent))


bob = Person('Bob Smith', 'dev', 100000)
sue = Person('Sue Jones', 'dev', 100000)
bob.giveRaise(0.10)
sue.giveRaise(percent=0.20)
print(bob.pay, sue.pay)
# 输出
110000 120000


bob.giveRaise(1.10)
# 输出
Traceback (most recent call last):
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 84, in <module>
    bob.giveRaise(1.10)
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 34, in onCall
    raise TypeError(errmsg)
TypeError: giveRaise argument "percent" not in 0.0..1.0

示例二加强版

可以处理范围测试,类型测试,值测试

def rangetest(trace=True, **argchecks):
   return argtest(argchecks, lambda arg, vals: arg < vals[0] or arg > vals[1], trace=trace)


def typetest(trace=True, **argchecks):
   return argtest(argchecks, lambda arg, type: not isinstance(arg, type), trace=trace)


def valuetest(trace=True, **argchecks):
   return argtest(argchecks, lambda arg, tester: not tester(arg), trace=trace)


def argtest(argchecks, failif, trace):  # Validate ranges for both+defaults
   def onDecorator(func):  # onCall remembers func and argchecks
       if not __debug__:  # True if "python -O main.py args..."
           return func  # Wrap if debugging; else use original
       else:
           code = func.__code__
           allargs = code.co_varnames[:code.co_argcount]
           funcname = func.__name__

           def onError(argname, criteria):
               errfmt = '%s argument "%s" not %s'
               raise TypeError(errfmt % (funcname, argname, criteria))

           def onCall(*pargs, **kargs):
               # All pargs match first N expected args by position
               # The rest must be in kargs or be omitted defaults
               positionals = list(allargs)
               positionals = positionals[:len(pargs)]
               for (argname, criteria) in argchecks.items():
                   # 关键字参数检查
                   if argname in kargs:
                       # Was passed by name
                       if failif(kargs[argname], criteria):
                           onError(argname, criteria)
                   # 位置参数检查
                   elif argname in positionals:
                       # Was passed by position
                       position = positionals.index(argname)
                       if failif(pargs[position], criteria):
                           onError(argname, criteria)
                   else:
                       # Assume not passed: default
                       if trace:
                           print('Argument "{0}" defaulted'.format(argname))

               return func(*pargs, **kargs)  # OK: run original call
           return onCall
   return onDecorator

运行测试:

import sys
def fails(test):
  try:
    result = test()
    except:
      print("[%s]" % sys.exc_info()[1])
      else:
        print('?%s?' % result)
       
@rangetest(M=(1, 12), D=(1, 31), Y=(1900, 2013))
def date(M, D, Y):
    print('date = {0}/{1}/{2}'.format(M, D, Y))

date(5, 1, 1960)
date(M=5, D=1, Y=1960)
fails(lambda: date(1, 2, 3))
print("---------------------------------------------")
@typetest(a=int, c=float)
def sum(a, b, c, d):
    print(a+b+c+d)

sum(1, 2, 3.0, 4)
fails(lambda: sum('spam', 2, 3, 4))
print("---------------------------------------------")

@valuetest(word1=str.islower, word2=(lambda x: x[0].isupper()))
def msg(word1, word2):
    print("%s %s" % (word1, word2))

msg('edwin', 'Edwin')
fails(lambda: msg('Edwin', 'EdWin'))
print("---------------------------------------------")

@rangetest(X=(1, 10))
@typetest(Z=str)
def nester(X, Y, Z):
    print("%s %s %s" % (X, Y, Z))

nester(1, 2, "edwin")
fails(lambda: nester(1, 2, 1))

输出:

date = 5/1/1960
date = 5/1/1960
[date argument "Y" not (1900, 2013)]
---------------------------------------------
10.0
[sum argument "a" not <class 'int'>]
---------------------------------------------
edwin Edwin
[msg argument "word1" not <method 'islower' of 'str' objects>]
---------------------------------------------
Argument "X" defaulted
1 2 edwin
Argument "X" defaulted
[nester argument "Z" not <class 'str'>]

二、类装饰器的运用

示例一:实现单例功能的类装饰器

只适合用于python3环境,因为nonlocal语句仅在python3.x中可用

def singleton(aClass):
    """
    管理一个类只能创建一个实例
    只适合用于python3环境,因为nonlocal语句仅在python3.x中可用
    :param aClass: 装饰的类
    :return:
    """
    instance = None
    def onCall(*args):
        nonlocal instance
        if instance == None:
            instance = aClass(*args)
        return instance
    return onCall

适合用于python2和python3环境

def singleton(aClass):
    """
    管理一个类只能创建一个实例
    适合用于python2和python3环境
    :param aClass: 装饰的类
    :return:
    """

    def onCall(*args):
        if onCall.instance == None:
            onCall.instance = aClass(*args)
        return onCall.instance

    onCall.instance = None
    return onCall

运行测试:

if __name__ == '__main__':
    @singleton  # Person = singleton(Person)
    class Person:
        def __init__(self, name, hours, rate):
            self.name = name
            self.hours = hours
            self.rate = rate

        def pay(self):
            return self.hours * self.rate



    bob = Person('Bob', 40, 10)  # Really calls onCall
    print(bob.name, bob.pay())

    sue = Person('Sue', 50, 20)  # Same, single object
    print(sue.name, sue.pay())

输出:

Bob 400
Bob 400

示例二:类属性访问装饰器

只适合用于python2的装饰器

traceMe = False
def trace(*args):
    if traceMe: print('[' + ' '.join(map(str, args)) + ']')

def accessControl(failIf):
    def onDecorator(aClass):
        if not __debug__:
            return aClass
        else:
            class onInstance:
                def __init__(self, *args, **kargs):
                    self.__wrapped = aClass(*args, **kargs)
                def __getattr__(self, attr):
                    trace('get:', attr)
                    if failIf(attr):
                        raise TypeError('private attribute fetch: ' + attr)
                    else:
                        return getattr(self.__wrapped, attr)
                def __setattr__(self, attr, value):
                    trace('set:', attr, value)
                    if attr == '_onInstance__wrapped':
                        self.__dict__[attr] = value
                    elif failIf(attr):
                        raise TypeError('private attribute change: ' + attr)
                    else:
                        setattr(self.__wrapped, attr, value)
            return onInstance
    return onDecorator


def Private(*attributes):
    return accessControl(failIf=(lambda attr: attr in attributes))


def Public(*attributes):
    return accessControl(failIf=(lambda attr: attr not in attributes))

测试:

if __name__ == '__main__':
    import sys
    @Private('age')  # Person = Private('age')(Person)
    class Person:  # Person = onInstance with state
        def __init__(self, name, age):
            self.name = name
            self.age = age  # Inside accesses run normally

        def __add__(self, other):
            self.age += other

        def __str__(self):
            return '%s: %s' % (self.name, self.age)

    X = Person('Bob', 40)
    print(X.name)  # Outside accesses validated
    X.name = 'Sue'
    print(X.name)
    X + 10
    print(X)
    try:
        t = X.age
    except:
        print("Error:[%s]" % sys.exc_info()[1])

这里当运行到X+10这条语句时,便会报错:

Traceback (most recent call last):
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 75, in <module>
    X + 10
TypeError: unsupported operand type(s) for +: 'onInstance' and 'int'

异常分析:

当在python2下运行时,代理类(onInstance)是一个经典类,但是当在python3下运行时,代理类是一个新式类。(python3中只有新式类,没有经典类)。当通过内置操作隐式地运行时(X+10),在经典类中,会触发代理类(onInstance)中__getattr__的调用,在新式类中,不会触发代理类(onInstance)中__getattr__的调用,从而不会调用到Person类中的__add__。详细细节请参考另一文章<<Python使用技巧--拦截内置运算属性>>。。

注意:在python2的默认经典类中,__getattr__会拦截内置函数对__add__和__str__这样的运算符重载方法的隐式访问,但是在python3的新式类中不会拦截(包括python2的新式类)。

解决方法:

在代理类中重新定义__add__这些运算符重载方法。

示例二改进版

适合用于python2和python3的装饰器

以下装饰器使用了一个混合技巧来为包装器类添加一些运算符重载方法的重定义,这样在python3.x中它会正确地将内置操作委托到使用这些方法的主体类上。

# -*-coding:utf-8-*-

traceMe = False
def trace(*args):
    if traceMe: print('[' + ' '.join(map(str, args)) + ']')

def accessControl(failIf):
    def onDecorator(aClass):
        if not __debug__:
            return aClass
        else:
            class onInstance(BuiltinsMixin):
                def __init__(self, *args, **kargs):
                    self.__wrapped = aClass(*args, **kargs)
                def __getattr__(self, attr):
                    trace('get:', attr)
                    if failIf(attr):
                        raise TypeError('private attribute fetch: ' + attr)
                    else:
                        return getattr(self.__wrapped, attr)
                def __setattr__(self, attr, value):
                    trace('set:', attr, value)
                    if attr == '_onInstance__wrapped':
                        self.__dict__[attr] = value
                    elif failIf(attr):
                        raise TypeError('private attribute change: ' + attr)
                    else:
                        setattr(self.__wrapped, attr, value)
            return onInstance
    return onDecorator


def Private(*attributes):
    return accessControl(failIf=(lambda attr: attr in attributes))


def Public(*attributes):
    return accessControl(failIf=(lambda attr: attr not in attributes))

class BuiltinsMixin():
    def reroute(self, attr, *args, **kargs):
        return self.__class__.__getattr__(self, attr)(*args, **kargs)

    def __add__(self, other):
        return self.reroute('__add__', other)

    def __str__(self):
        return self.reroute('__str__')

    def __getitem__(self, index):
        return self.reroute('__getitem__', index)

    def __call__(self, *args, **kargs):
        return self.reroute('__call__', *args, **kargs)

测试:

if __name__ == '__main__':
    import sys
    @Private('age')  # Person = Private('age')(Person)
    class Person:  # Person = onInstance with state
        def __init__(self, name, age):
            self.name = name
            self.age = age  # Inside accesses run normally

        def __add__(self, other):
            self.age += other

        def __str__(self):
            return '%s: %s' % (self.name, self.age)

    X = Person('Bob', 40)
    print(X.name)  # Outside accesses validated
    X.name = 'Sue'
    print(X.name)
    X + 10
    print(X)
    try:
        t = X.age
    except:
        print("Error:[%s]" % sys.exc_info()[1])

输出:

Bob
Sue
Sue: 50
Error:[private attribute fetch: age]

三、编写装饰器的注意事项

1.保持多个装饰的实例

我们都是知道,编写装饰器的时候,可以使用函数,也可以使用类来编写,但是当使用类来编写的时候,我们需要注意装饰的实例被覆盖。

以下装饰器实现属性调用的追踪。

class Tracer:
    def __init__(self, aClass):
      	self.fetches = 0
        self.aClass = aClass
    def __call__(self, *args):
        self.wrapped = self.aClass(*args)
        return self
    def __getattr__(self, attrname):
        print('Trace: ' + attrname)
        self.fetches += 1
        return getattr(self.wrapped, attrname)

测试

@Tracer
class Person:                                 # Person = Tracer(Person)
    def __init__(self, name):                 # Wrapper bound to Person
        self.name = name

bob = Person('Bob')
print(bob.name)
Sue = Person('Sue') #bob实例被sue实例
print(sue.name)
print(bob.name) # bob实例的name='Sue'!

分析:

每个实例构建调用会触发__call__,这会覆盖前面的实例。直接效果是Tracer只保存了一个实例,即最后创建的那个实例。

改进:基于函数的装饰器可用于多个实例,因为每个实例构造调用都会创建一个新的Wrapper实例,而不是覆盖一个单个共享的Tracer实例的状态。

def Tracer(aClass):
    class Wrapper:
        def __init__(self):
            self.fetches = 0
            self.aClass = aClass

        def __call__(self, *args, **kwargs):
            self.wrapped = self.aClass(*args, **kwargs)
            return self

        def __getattr__(self, attrname):
            print('Trace: ' + attrname)
            self.fetches += 1
            return getattr(self.wrapped, attrname)
    return Wrapper

2.对类方法进行装饰

我们编写以下一个装饰器:

class tracer:
    def __init__(self, func):
        self.calls = 0
        self.func = func
    def __call__(self, *args):
        self.calls += 1
        print('call %s to %s' % (self.calls, self.func.__name__))
        self.func(*args)

装饰普通函数没问题:

@tracer
def spam(a, b, c):           # spam = tracer(spam)
    print(a + b + c)         # Wraps spam in a decorator object
spam(1, 2, 3) 
spam('a', 'b', 'c') 

输出:

call 1 to spam
6
call 2 to spam
abc

当装饰类中的方法,就失效了。

if __name__ == '__main__':
    class Person:
        def __init__(self, name, pay):
            self.name = name
            self.pay = pay

        @tracer
        def giveRaise(self, percent):  # giveRaise = tracer(giverRaise)
            self.pay *= (1.0 + percent)

        @tracer
        def lastName(self):  # lastName = tracer(lastName)
            return self.name.split()[-1]
    bob = Person('Bob Smith', 50000)
    bob.giveRaise(0.25)  # Runs tracer.__call__(???, .25)
    print(bob.lastName())  # Runs tracer.__call__(???)

输出:

Traceback (most recent call last):
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 26, in <module>
    bob.giveRaise(0.25)
  File "/Users/edwin/PycharmProjects/testProject/test.py", line 10, in __call__
    self.func(*args)
TypeError: giveRaise() missing 1 required positional argument: 'percent'
call 1 to giveRaise

分析:

这里用__call__把被装饰方法名称重绑定到一个类实例对象的时候,python只向self传递了tracer实例,它根本没有在参数列表中传递Person主体。因此tracer不知道我们要利用方法调用处理的Person实例的任何信息,导致没办法创建一个带有实例的绑定方法,也没办法正确地分发调用。这是一个非常值得注意的细节。

改进:使用嵌套函数装饰方法

def tracer(func):
    calls = 0
    def onCall(*args, **kwargs):
        nonlocal calls
        calls += 1
        print('call %s to %s' % (calls, func.__name__))
        return func(*args, **kwargs)
    return onCall

测试:

if __name__ == '__main__':
    class Person:
        def __init__(self, name, pay):
            self.name = name
            self.pay = pay

        @tracer
        def giveRaise(self, percent):  # giveRaise = tracer(giverRaise)
            self.pay *= (1.0 + percent)

        @tracer
        def lastName(self):  # lastName = tracer(lastName)
            return self.name.split()[-1]


    print('methods...')
    bob = Person('Bob Smith', 50000)
    sue = Person('Sue Jones', 100000)
    print(bob.name, sue.name)
    sue.giveRaise(.10)  # Runs onCall(sue, .10)
    print(sue.pay)
    print(bob.lastName(), sue.lastName())  # Runs onCall(bob), lastName in scopes

输出:

methods...
Bob Smith Sue Jones
call 1 to giveRaise
110000.00000000001
call 1 to lastName
call 2 to lastName
Smith Jones

posted on 2021-11-07 15:54  xufat  阅读(254)  评论(0)    收藏  举报

导航

/* 返回顶部代码 */ TOP