301 lines
11 KiB
Python
301 lines
11 KiB
Python
|
#!/usr/bin/env python
|
||
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
from __future__ import absolute_import
|
||
|
from __future__ import division
|
||
|
from __future__ import print_function
|
||
|
from __future__ import unicode_literals
|
||
|
|
||
|
__all__ = ['singledispatch', 'singledispatchmethod']
|
||
|
|
||
|
from weakref import WeakKeyDictionary
|
||
|
|
||
|
from .helpers import MappingProxyType, get_cache_token, get_type_hints, update_wrapper
|
||
|
|
||
|
################################################################################
|
||
|
### singledispatch() - single-dispatch generic function decorator
|
||
|
################################################################################
|
||
|
|
||
|
def _c3_merge(sequences):
|
||
|
"""Merges MROs in *sequences* to a single MRO using the C3 algorithm.
|
||
|
|
||
|
Adapted from http://www.python.org/download/releases/2.3/mro/.
|
||
|
|
||
|
"""
|
||
|
result = []
|
||
|
while True:
|
||
|
sequences = [s for s in sequences if s] # purge empty sequences
|
||
|
if not sequences:
|
||
|
return result
|
||
|
for s1 in sequences: # find merge candidates among seq heads
|
||
|
candidate = s1[0]
|
||
|
for s2 in sequences:
|
||
|
if candidate in s2[1:]:
|
||
|
candidate = None
|
||
|
break # reject the current head, it appears later
|
||
|
else:
|
||
|
break
|
||
|
if candidate is None:
|
||
|
raise RuntimeError("Inconsistent hierarchy")
|
||
|
result.append(candidate)
|
||
|
# remove the chosen candidate
|
||
|
for seq in sequences:
|
||
|
if seq[0] == candidate:
|
||
|
del seq[0]
|
||
|
|
||
|
def _c3_mro(cls, abcs=None):
    """Computes the method resolution order using extended C3 linearization.

    Without *abcs* the result is what the built-in C3 linearization used for
    method resolution would produce.

    With *abcs* (a list of abstract base classes), every ABC related to *cls*
    is woven into the resulting MRO; unrelated ABCs are dropped.  An ABC is
    inserted at the class that introduces its functionality, i.e. where
    issubclass(cls, abc) is True for the class itself but False for all of
    its direct base classes.  Implicit ABCs of a class (registered, or
    inferred from a special method like __len__) go directly after the last
    ABC explicitly listed in the MRO of that class.  When two implicit ABCs
    end up adjacent, their relative order follows the order in *abcs*.

    """
    direct_bases = cls.__bases__

    # *boundary* is one past the index of the last direct base that carries
    # __abstractmethods__; bases up to the last explicit ABC are considered
    # first.  It stays 0 when no direct base is an explicit ABC.
    boundary = 0
    for position in range(len(direct_bases), 0, -1):
        if hasattr(direct_bases[position - 1], '__abstractmethods__'):
            boundary = position
            break

    # Private copy: entries claimed by *cls* are removed before recursing.
    pending = list(abcs) if abcs else []
    explicit = list(direct_bases[:boundary])
    others = list(direct_bases[boundary:])

    # ABCs whose behaviour is introduced by *cls* itself (none of its direct
    # bases already satisfies them) become part of its MRO.
    introduced = [
        abc for abc in pending
        if issubclass(cls, abc)
        and not any(issubclass(parent, abc) for parent in direct_bases)
    ]
    for abc in introduced:
        pending.remove(abc)

    def linearize(bases):
        # Recursively linearize each base against the ABCs still pending.
        return [_c3_mro(base, abcs=pending) for base in bases]

    return _c3_merge(
        [[cls]]
        + linearize(explicit) + linearize(introduced) + linearize(others)
        + [explicit, introduced, others]
    )
|
||
|
|
||
|
def _compose_mro(cls, types):
    """Calculates the method resolution order for a given class *cls*.

    Relevant abstract base classes from the *types* iterable (together with
    their respective bases) are included.  The final ordering is produced by
    the modified C3 linearization in ``_c3_mro``.

    """
    own_mro = set(cls.__mro__)

    # Keep only entries that are related to *cls* and not already present in
    # its __mro__.
    related = [
        typ for typ in types
        if typ not in own_mro and hasattr(typ, '__mro__')
        and issubclass(cls, typ)
    ]

    # Drop entries which are strict bases of other entries (they will end up
    # in the MRO anyway).
    candidates = [
        typ for typ in related
        if not any(typ != other and typ in other.__mro__ for other in related)
    ]

    # Subclasses of the ABCs in *candidates* which are also implemented by
    # *cls* can be used to stabilize ABC ordering.
    candidate_set = set(candidates)
    ordered = []
    for typ in candidates:
        chains = [
            [entry for entry in sub.__mro__ if entry in candidate_set]
            for sub in typ.__subclasses__()
            if _safe(sub) and sub not in own_mro and issubclass(cls, sub)
        ]
        if chains:
            # Favor subclasses with the biggest number of useful bases.
            chains.sort(key=len, reverse=True)
            for chain in chains:
                for entry in chain:
                    if entry not in ordered:
                        ordered.append(entry)
        else:
            ordered.append(typ)
    return _c3_mro(cls, abcs=ordered)
|
||
|
|
||
|
|
||
|
def _safe(class_):
|
||
|
"""
|
||
|
Return if the class is safe for testing as subclass. Ref #2.
|
||
|
"""
|
||
|
return not getattr(class_, '__origin__', None)
|
||
|
|
||
|
|
||
|
def _find_impl(cls, registry):
    """Returns the best matching implementation from *registry* for type *cls*.

    When nothing is registered for the specific type, the composed method
    resolution order of *cls* is walked to find a more generic
    implementation.

    Raises RuntimeError when two unrelated implicit ABCs match equally well.

    Note: if *registry* does not contain an implementation for the base
    *object* type, this function may return None.

    """
    own_mro = cls.__mro__
    best = None
    for typ in _compose_mro(cls, registry.keys()):
        registered = typ in registry
        if best is None:
            # Still looking for the first registered entry.
            if registered:
                best = typ
            continue
        # Only the entry right after the first match is inspected.  If *best*
        # is an implicit ABC but there is another unrelated, equally matching
        # implicit ABC, refuse the temptation to guess.
        if (registered and typ not in own_mro and best not in own_mro
                and not issubclass(best, typ)):
            raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
                best, typ))
        break
    return registry.get(best)
|
||
|
|
||
|
def _validate_annotation(annotation):
|
||
|
"""Determine if an annotation is valid for registration.
|
||
|
|
||
|
An annotation is considered valid for use in registration if it is an
|
||
|
instance of ``type`` and not a generic type from ``typing``.
|
||
|
"""
|
||
|
try:
|
||
|
# In Python earlier than 3.7, the classes in typing are considered
|
||
|
# instances of type, but they invalid for registering single dispatch
|
||
|
# functions so check against GenericMeta instead.
|
||
|
from typing import GenericMeta
|
||
|
valid = not isinstance(annotation, GenericMeta)
|
||
|
except ImportError:
|
||
|
# In Python 3.7+, classes in typing are not instances of type.
|
||
|
valid = isinstance(annotation, type)
|
||
|
return valid
|
||
|
|
||
|
def singledispatch(func):
    """Single-dispatch generic function decorator.

    Turns *func* into a generic function whose behaviour depends on the type
    of its first argument.  *func* itself is stored as the implementation for
    ``object`` (the default), and further implementations are added through
    the register() attribute of the returned generic function.

    The returned wrapper also exposes ``dispatch``, a read-only ``registry``
    view and ``_clear_cache``.

    """
    registry = {}
    dispatch_cache = WeakKeyDictionary()

    # A bare function object serves as a mutable namespace that the nested
    # closures can assign to.
    def ns():
        pass
    ns.cache_token = None

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.

        """
        stored_token = ns.cache_token
        if stored_token is not None:
            # A token is only tracked once a class with __abstractmethods__
            # has been registered; a changed token drops every cached lookup.
            fresh_token = get_cache_token()
            if stored_token != fresh_token:
                dispatch_cache.clear()
                ns.cache_token = fresh_token
        try:
            return dispatch_cache[cls]
        except KeyError:
            pass
        # Exact registrations win; otherwise fall back to the MRO search.
        if cls in registry:
            impl = registry[cls]
        else:
            impl = _find_impl(cls, registry)
        dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.

        """
        if func is None:
            if isinstance(cls, type):
                # ``@register(some_class)``: hand back the real decorator.
                return lambda f: register(cls, f)
            # Plain ``@register``: *cls* is actually the function and must
            # carry annotations to take the dispatch type from.
            if not getattr(cls, '__annotations__', {}):
                raise TypeError(
                    "Invalid first argument to `register()`: {cls!r}. "
                    "Use either `@register(some_class)` or plain `@register` "
                    "on an annotated function.".format(cls=cls)
                )
            func = cls

            # The first resolved type hint supplies the dispatch type.
            hints = get_type_hints(func)
            argname, cls = next(iter(hints.items()))
            if not _validate_annotation(cls):
                raise TypeError(
                    "Invalid annotation for {argname!r}. "
                    "{cls!r} is not a class.".format(argname=argname, cls=cls)
                )
        registry[cls] = func
        if ns.cache_token is None and hasattr(cls, '__abstractmethods__'):
            # First class with __abstractmethods__: start tracking the token.
            ns.cache_token = get_cache_token()
        dispatch_cache.clear()
        return func

    def wrapper(*args, **kw):
        if args:
            return dispatch(args[0].__class__)(*args, **kw)
        raise TypeError('{0} requires at least '
                        '1 positional argument'.format(funcname))

    funcname = getattr(func, '__name__', 'singledispatch function')
    registry[object] = func

    wrapper.registry = MappingProxyType(registry)
    wrapper.dispatch = dispatch
    wrapper.register = register
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper
|
||
|
|
||
|
|
||
|
# Descriptor version
|
||
|
class singledispatchmethod(object):
    """Single-dispatch generic method descriptor.

    Supports wrapping existing descriptors and handles non-descriptor
    callables as instance methods.
    """

    def __init__(self, func):
        """Wrap *func* as the default implementation.

        Raises TypeError when *func* is neither callable nor a descriptor.
        """
        if not callable(func) and not hasattr(func, "__get__"):
            raise TypeError("{!r} is not callable or a descriptor".format(func))

        self.dispatcher = singledispatch(func)
        self.func = func

    def register(self, cls, method=None):
        """generic_method.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_method*.
        """
        return self.dispatcher.register(cls, func=method)

    def __get__(self, obj, cls=None):
        """Return a callable that dispatches on the type of its first argument.

        The selected implementation is bound to *obj* / *cls* through its own
        ``__get__`` before being called.  The returned callable raises
        TypeError when invoked without a positional argument.
        """
        def _method(*args, **kwargs):
            if not args:
                # Report a missing dispatch argument the same way the
                # function-level generic does, not as a bare IndexError.
                funcname = getattr(self.func, '__name__',
                                   'singledispatchmethod method')
                raise TypeError('{0} requires at least '
                                '1 positional argument'.format(funcname))
            method = self.dispatcher.dispatch(args[0].__class__)
            return method.__get__(obj, cls)(*args, **kwargs)

        _method.__isabstractmethod__ = self.__isabstractmethod__
        _method.register = self.register
        update_wrapper(_method, self.func)
        return _method

    @property
    def __isabstractmethod__(self):
        """Mirror ``__isabstractmethod__`` of the wrapped callable (default False)."""
        return getattr(self.func, '__isabstractmethod__', False)
|