import inspect
import chainer
class WrappedFunctionNode(chainer.FunctionNode):
"""Wrap the target function and operate as ``FunctionNode``
Arguments:
name (str): name of the function node
func (func): the target function
args (list): args for the function
kwargs (dict): kwargs for the function
arg_vars (list): list of `chainer.Variable`s in `args` and `kwargs`
attributes (list): parameters to be set node's attributes
"""
def __init__(self, name, func, args, kwargs, arg_vars, attributes=None):
self.custom_function_node_name = name
self.func = func
self.args = args
self.kwargs = kwargs
self.arg_vars = arg_vars
self.internal_results = None
if attributes is not None:
for k, v in attributes.items():
setattr(self, k, v)
def forward(self, xs):
assert len(xs) == len(self.arg_vars)
self.xs = xs
results = self.func(*self.args, **self.kwargs)
if isinstance(results, (tuple, list)):
dummy_results = tuple(_unwrap_var(ret) for ret in results)
if all([_is_var(ret) for ret in results]):
self.internal_results = tuple(results)
elif isinstance(results, dict):
dummy_results = tuple(_unwrap_var(ret) for ret in results.values())
if all([_is_var(ret) for ret in results.values()]):
self.internal_results = tuple(results.values())
else:
dummy_results = _unwrap_var(results)
dummy_results = dummy_results,
if _is_var(results):
self.internal_results = results,
if not chainer.is_arrays_compatible(dummy_results):
raise ValueError(
'returned values from the function wrapped by \'as_funcnode\' '
'must consist only array, function name: {}'.format(self.name))
return dummy_results
def backward(self, target_input_indexes, grad_outputs):
if self.internal_results is None:
raise ValueError(
'the target function does not support backward, propagation '
'is failed')
grad_inputs = chainer.grad(self.internal_results, self.arg_vars,
grad_outputs=grad_outputs)
assert len(self.arg_vars) == len(grad_inputs)
return tuple(grad_input if i in target_input_indexes else None
for i, grad_input in enumerate(grad_inputs))
[docs]def fake_as_funcnode(alt_func, name, rename_attributes=None):
"""The target function fakes FunctionNode
The target function is replaced to the alternative function to connect
variable node by acting function node. ``alt_func`` must satisfy the
following restrictions.
1. Inputs includes one or more ``chainer.Variable`` to trace variables.
2. Output consists nothing but ``ndarray`` or ``chainer.Variable``
Even if ``alt_func`` returns ``ndarray``, the value forced to be converted
to ``chainer.Variable``. A caller of the target function have to care
both cases, returning ``ndarray`` and ``chainer.Variable``.
When ``alt_func`` returns ``list`` of variable, the wrapped function will
also returns multiple variables as ``tuple``. However ``dict`` cannot
be return, the wrapped function breaks down the returned values as
``tuple`` of values, keys will be ignored.
Arguments of ``alt_func`` except for ``chainer.Variable`` are set as
function attributes. Attribute names are set ``argN`` (N is index
number) or keyword on default.
Example:
>>> def func(x, a, b, c=1, d=2): pass
>>> # x is variable
>>> func = onnx_chainer.replace_func.fake_as_funcnode(
... func, 'CustomNode',
... rename_attributes=[(1, 'value'), ('c', 'y')])
Then ``func`` will be operated as a function node named "CustomNode", and
``'value'``, ``'b'``, ``'y'``, ``'d'`` are set as function's attributes.
See tests/test_replace_func.py more details.
Args:
alt_func (func): actual called function. There are some constrains, see
the above documentation.
name (str): function name. This name is used for what ONNX operator
to be assigned.
rename_attributes (list or tuple): rename attribute name, set list
of ``tuple(index_of_args, new_name)`` or
``tuple(kwargs_name, new_name)``
Returns:
func: wrapped function, called on exporting.
"""
def _wrapper(*args, **kwargs):
inputs = []
attributes = {}
rename_attr_dict = {}
if rename_attributes is not None:
rename_attr_dict = {attr[0]: attr[1] for attr in rename_attributes}
# resolve default value for kwargs
arg_spec = inspect.signature(alt_func)
bound = arg_spec.bind(*args, **kwargs)
bound.apply_defaults()
# default values are set on `bound.arguments`, but cannot get them
# from `bound.kwargs`
for i, (k, v) in enumerate(bound.arguments.items()):
if i < len(args):
continue
kwargs[k] = v
def set_attr(key, value):
default_name = key if isinstance(key, str) else 'arg{}'.format(key)
attributes[rename_attr_dict.get(key, default_name)] = value
def expand_args(args_iter):
for i, a in args_iter:
if _is_var(a):
inputs.append(a)
elif isinstance(a, (tuple, list)):
# all elements are variable -> add flatten them to inputs
# all elements are not variable -> add them to attributes
# mixed variable and other type value -> error
flatten_arg = _flatten(a)
var_or_not = map(_is_var, flatten_arg)
if all(var_or_not):
inputs.extend(flatten_arg)
elif not any(var_or_not):
set_attr(i, a)
else:
raise ValueError(
'arguments mixed variable and other type are not '
'supported')
else:
set_attr(i, a)
expand_args(enumerate(args))
expand_args(kwargs.items())
if not inputs:
raise ValueError(
'arguments of the function wrapped by \'as_funcnode\' '
'must include at least one chainer.Variable, function name: '
'{}'.format(name))
wrapped = WrappedFunctionNode(
name, alt_func, args, kwargs, inputs, attributes=attributes)
ret = wrapped.apply(inputs)
if len(ret) > 1:
return ret
return ret[0]
chainer.utils.experimental('as_funcnode')
return _wrapper
[docs]def as_funcnode(name, rename_attributes=None):
"""The target function fakes FunctionNode
The target function is overwrapped to connect variable node by acting
function node. Expected to be used as decorator. More detail, see
``fake_as_funcnode`` documentation.
Example:
>>> @onnx_chainer.replace_func.as_funcnode(
... 'CustomNode', rename_attributes=[(1, 'value'), ('c', 'y')])
... def func(x, a, b, c=1, d=2): pass
Args:
name (str): function name. This name is used for what ONNX operator
to be assigned.
rename_attributes (list or tuple): rename attribute name, set list
of ``tuple(index_of_args, new_name)`` or
``tuple(kwargs_name, new_name)``
"""
def _wrapper(fn):
return fake_as_funcnode(fn, name, rename_attributes=rename_attributes)
return _wrapper
def _unwrap_var(var):
return var.array if _is_var(var) else var
def _is_var(array):
# alias for type checking
return isinstance(array, chainer.Variable)
def _is_array(v):
return not isinstance(v, (list, tuple))
def _flatten(xs):
if _is_array(xs):
return [xs]
o = []
for x in xs:
if _is_array(x):
o.append(x)
else:
o.extend(_flatten(x))
return o