Source code for onnx_chainer.export

from __future__ import print_function

from collections import OrderedDict
import warnings

import chainer

try:
    import onnx
    from onnx import checker
    from onnx import helper
    from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
    from onnx import numpy_helper
    from onnx import shape_inference

    from onnx_chainer.context import Context
    from onnx_chainer.graph import Graph
    from onnx_chainer import mapping
    from onnx_chainer.onnx_helper import is_support_non_standard_domain

    _available = True
except ImportError:
    _available = False

MINIMUM_OPSET_VERSION = 7
MAXIMUM_OPSET_VERSION = 11


def _check_available():
    if not _available:
        raise ImportError(
            'ONNX is not installed on your environment. Exporting your model '
            'in ONNX format needs the onnx package.\n\n'
            '\t$ pip install \'onnx<1.7.0\'\n\n')


def convert_parameter(parameter, context):
    if isinstance(parameter, chainer.Parameter):
        array = parameter.array
    elif isinstance(parameter, chainer.Variable):
        array = parameter.array
    elif isinstance(parameter, chainer.get_array_types()):
        array = parameter
    else:
        raise ValueError(
            'The type of parameter is unknown. It should be either Parameter '
            'or Variable or ndarray, but the type was {}.'.format(
                type(parameter)))
    array = chainer.cuda.to_cpu(array)
    tensor = numpy_helper.from_array(array, context.get_name(parameter))
    return tensor


def rename_variable_name(
        context, variables, named_vars, new_names, prefix='Input'):
    # Update ``named_vars`` keys to ``new_names``
    if isinstance(variables, (list, tuple)):
        if new_names is None:
            new_names = ['{}_{}'.format(prefix, i)
                         for i in range(len(named_vars))]
        if not isinstance(new_names, (list, tuple)) or\
                len(variables) != len(new_names):
            raise ValueError(
                'Replacing name list is not match with input (or output) '
                'variables')
        for i, var in enumerate(variables):
            del named_vars[context.get_name(var)]
            new_name = new_names[i]
            named_vars[new_name] = var
            context.set_name(var, new_name, pinned=True)
    elif isinstance(variables, dict):
        if new_names is None:
            new_names = {k: '{}_{}'.format(prefix, i)
                         for i, k in enumerate(variables.keys())}
        if not isinstance(new_names, (list, tuple, dict)) or\
                len(variables) != len(new_names):
            raise ValueError(
                'Replacing name dict is not match with input (or output) '
                'variables')
        if isinstance(new_names, (list, tuple)):
            new_names = {k: v for k, v in zip(variables.keys(), new_names)}
        for k, v in variables.items():
            if k not in new_names:
                raise ValueError(
                    'Key of replacing name is not found in variables')
            del named_vars[context.get_name(v)]
            new_name = new_names[k]
            named_vars[new_name] = v
            context.set_name(v, new_name, pinned=True)
    elif isinstance(variables, chainer.Variable):
        if not new_names:
            new_names = prefix + '_0'
        if isinstance(new_names, (list, tuple)):
            if len(new_names) != 1:
                raise ValueError('Replacing name must be single')
            new_name = new_names[0]
        elif isinstance(new_names, str):
            new_name = new_names
        else:
            raise ValueError(
                'Type {} is not supported for single variable'.format(
                    type(new_name)))
        del named_vars[context.get_name(variables)]
        named_vars[new_name] = variables
        context.set_name(variables, new_name, pinned=True)


def format_customized_shapes(args, shapes):
    if isinstance(args, (list, tuple)):
        if not isinstance(shapes, list) or len(args) != len(shapes):
            raise ValueError('Customized shapes cannot fit for input list')
        for i, (arg, shape) in enumerate(zip(args, shapes)):
            if len(arg.shape) != len(shape):
                raise ValueError(
                    'Index-{} shape length must be same as input'.format(i))
        return shapes
    elif isinstance(args, dict):
        if not isinstance(shapes, (list, dict)) or\
                len(args) != len(shapes):
            raise ValueError('Customized shapes cannot fit for input dict')
        if isinstance(shapes, list):
            shapes = {k: v for k, v in zip(args.keys(), shapes)}
        formatted_shapes = []
        for k, arg in args.items():
            if k not in shapes:
                raise ValueError(
                    'Key "{}" is not found in customized shapes'.format(k))
            if len(arg.shape) != len(shapes[k]):
                raise ValueError(
                    'Key "{}" shape length must be same as input'.format(k))
            formatted_shapes.append(shapes[k])
        return formatted_shapes
    else:
        assert isinstance(args, (chainer.Variable, chainer.get_array_types()))
        if isinstance(shapes, list):
            if len(shapes) != 1:
                raise ValueError('Customized shape must be single')
        elif not isinstance(shapes, tuple):
            raise ValueError(
                'Type {} is not supported for single input'.format(
                    type(shapes)))
        else:
            shapes = [shapes]
        if len(args.shape) != len(shapes[0]):
            raise ValueError('Shape length must be same as input')
        return shapes


class RetainInputHook(chainer.LinkHook):
    """Retain temporary inputs

    Function nodes manage inputs variable nodes using weak reference. When
    variable is made as temporary value, exporter cannot get the corresponded
    variable from the variable node because the reference is collected. To
    resolve it, retain all inputs and will use when make computational graph.

    To reduce memory size, this hook retains only variables not showed in link
    inputs. To enable this feature, links are required to use ``forward``, not
    ``__call__``.
    """

    def __init__(self):
        self.link_inputs = set()
        self.retain_inputs = []
        self.replaced_inputs = []

        self.org_apply = chainer.function_node.FunctionNode.apply

        def hooked_apply(_self, inputs):
            ret = self.org_apply(_self, inputs)
            func_inodes = list(_self.inputs)
            for i, inode in enumerate(func_inodes):
                referenced_var = inode.get_variable_or_none()
                if referenced_var is None:
                    # This variable is created within function node and weakref
                    # is lost. Make temporary variable and retain it.
                    temp_var = chainer.as_variable(inputs[i])
                    func_inodes[i] = temp_var.node
                    self.retain_inputs.append(temp_var)
                else:
                    if id(referenced_var) not in self.link_inputs:
                        # This variable is created within link forward, outside
                        # of function node. To avoid to lose reference out
                        # of the forward, retain the variable.
                        self.retain_inputs.append(referenced_var)
            self.replaced_inputs.append((_self, _self.inputs))
            _self.inputs = tuple(func_inodes)
            return ret
        self.hooked_apply = hooked_apply

    def _extract_inputs(self, args):
        # Retain only chainer.Variable (and its collection)
        # Other type args are ignored and not checked instance IDs
        # If these variable are used in FunctionNode, they will be retained
        ret = set()
        if isinstance(args, chainer.Variable):
            ret.add(id(args))
        elif isinstance(args, (list, tuple)):
            for arg in args:
                ret |= self._extract_inputs(arg)
        elif isinstance(args, dict):
            for arg in args.values():
                ret |= self._extract_inputs(arg)
        return ret

    def forward_preprocess(self, args):
        self.link_inputs |= self._extract_inputs(args.args)
        self.link_inputs |= self._extract_inputs(args.kwargs)

    def forward_postprocess(self, args):
        self.link_inputs.clear()

    def __enter__(self):
        chainer.function_node.FunctionNode.apply = self.hooked_apply
        return super().__enter__()

    def __exit__(self, *exc_details):
        chainer.function_node.FunctionNode.apply = self.org_apply
        for _self, inputs in self.replaced_inputs:
            _self.inputs = inputs
        super().__exit__(*exc_details)


[docs]def export(model, args, filename=None, export_params=True, graph_name='Graph', save_text=False, opset_version=None, input_names=None, output_names=None, train=False, return_named_inout=False, external_converters=None, external_opset_imports=None, input_shapes=None): """Export function for chainer.Chain in ONNX format. This function performs a forward computation of the given :class:`~chainer.Chain`, ``model``, by passing the given arguments ``args`` directly. It means, the output :class:`~chainer.Variable` object ``y`` to make the computational graph will be created by: ``y = model(*args)`` ``external_converters`` and ``external_opset_imports`` are for external custom operator. When some ~chainer.FunctionNode are expected to convert to own customized operator, set converter function with ~chainer.FunctionNode name. >>> import onnx >>> def custom_converter(param): ... return onnx.helper.make_node( ... 'CustomizedRelu', param.input_names, param.output_names, ... domain='chainer'), >>> >>> external_converters = {'ReLU': custom_converter} >>> external_imports = {'chainer': 0} >>> >>> model = chainer.Sequential(F.relu) # set the target model >>> args = chainer.Variable(np.random.rand(1,10)) # set dummy input >>> onnx_graph = onnx_chainer.export( ... model, args, ... external_converters=external_converters, ... external_opset_imports=external_imports) Returned model has ``CustomizedRelu`` node. Args: model (~chainer.Chain): The model object you want to export in ONNX format. It should have :meth:`__call__` method because the second argument ``args`` is directly given to the model by the ``[]`` accessor. args (list or dict): The arguments which are given to the model directly. filename (str or file-like object): The filename used for saving the resulting ONNX model. If None, nothing is saved to the disk. export_params (bool): If True, this function exports all the parameters included in the given model at the same time. If False, the exported ONNX model doesn't include any parameter values. graph_name (str): A string to be used for the ``name`` field of the graph in the exported ONNX model. save_text (bool): If True, the text format of the output ONNX model is also saved with ``.txt`` extention. opset_version (int): The operator set version of ONNX. If not specified or ``None`` is given, the latest opset version of the onnx module is used. If an integer is given, it will be ensured that all the operator version in the exported ONNX file is less than this value. input_names (str, list or dict): Customize input names of the graph. Number of ``input_names`` must be same as number of ``args``. When set dict type, keys must be same as ``args``'s keys. output_names (str, list or dict): Customize output name of the graph. Number of ``output_names`` must be same as actual outputs from ``model``. When set dict type, keys must be same as the key of ``model`` output. train (bool): If True, output computational graph with train mode. return_named_inout (bool): If set True, return ONNX model with named inputs, and named outputs. external_converters (dict): Add-on converter. Convert functions keyed by ~chainer.FunctionNode name. external_opset_imports (dict): Import external opset. opset version number keyed by domain name. input_shapes (tuple, list, dict): Input shape of output graph follows the customized shapes if set. When input are collection type, set list or dict. Tuple of tuple is not allowed. Returns: ~onnx.ModelProto or tuple: When ``return_named_inout`` is ``False``, return ModelProto as an ONNX model. Otherwise return the tuple of ModelProto, named inputs and outputs, both inputs and outputs are list of ~chainer.Variable. """ _check_available() with chainer.using_config('train', train),\ chainer.using_config('in_recomputing', True),\ chainer.using_config('enable_backprop', True): return _export( model, args, filename, export_params, graph_name, save_text, opset_version, input_names, output_names, return_named_inout, external_converters, external_opset_imports, input_shapes)
def _export(model, args, filename, export_params, graph_name, save_text, opset_version, input_names, output_names, return_named_inout, external_converters, external_opset_imports, input_shapes): if opset_version is None: opset_version = min( int(onnx.defs.onnx_opset_version()), MAXIMUM_OPSET_VERSION) elif opset_version < MINIMUM_OPSET_VERSION or \ opset_version > MAXIMUM_OPSET_VERSION: warnings.warn( 'ONNX-Chainer has been tested only with opset_version {} ~ {}' 'The ONNX file exported with your requested opset_version ({}) ' 'may cause some problems because the converters used for the ' 'opset_version have not been tested.'.format( MINIMUM_OPSET_VERSION, MAXIMUM_OPSET_VERSION, opset_version)) if input_shapes is not None: # if input shapes are invalid, raise exception before forwarding. input_shapes = format_customized_shapes(args, input_shapes) with RetainInputHook(): # Forward computation context = Context(model) network_inputs = OrderedDict() if isinstance(args, tuple): args = list(args) if isinstance(args, list): for i, arg in enumerate(args): if isinstance(arg, chainer.get_array_types()): args[i] = chainer.Variable(arg) network_inputs[context.get_name(args[i])] = args[i] outputs = model(*args) elif isinstance(args, dict): for key, arg in args.items(): if isinstance(arg, chainer.get_array_types()): args[key] = chainer.Variable(arg) network_inputs[context.get_name(args[key])] = args[key] outputs = model(**args) elif isinstance(args, chainer.get_array_types()): args = chainer.Variable(args) network_inputs[context.get_name(args)] = args outputs = model(args) elif isinstance(args, chainer.Variable): network_inputs[context.get_name(args)] = args outputs = model(args) else: raise ValueError( 'The \'args\' argument should be a list, tuple, dict, ' 'numpy array, or Chainer Variable. But a {} object was ' 'given.'.format(type(args))) rename_variable_name(context, args, network_inputs, input_names) initializers = [] input_tensors = [] param_names = set() for org_name, param in model.namedparams(): # `model.namedparams()` has `include_uninit` flag but not use, to # output user warning if param.array is None: warnings.warn( 'The parameter \'{}\' is not initialized, skip setting to ' 'ONNX graph'.format(org_name)) continue name = context.get_name(param) param_names.add(name) tensor = convert_parameter(param, context) initializers.append(tensor) input_tensors.append(helper.make_tensor_value_info( name, tensor.data_type, tensor.dims)) for i, (name, var) in enumerate(network_inputs.items()): shape = var.shape if input_shapes is None else input_shapes[i] input_tensors.append(helper.make_tensor_value_info( name, NP_TYPE_TO_TENSOR_TYPE[var.dtype], shape)) if external_converters: chainer.utils.experimental('external_converters') converters = dict(mapping.converters, **external_converters) else: converters = mapping.converters if isinstance(outputs, (list, tuple)): flat_outputs = outputs elif isinstance(outputs, dict): flat_outputs = list(outputs.values()) elif isinstance(outputs, chainer.Variable): flat_outputs = [outputs] else: raise RuntimeError( 'Unexpected output type from the model: {}'.format( type(outputs))) if not all([isinstance(o, chainer.Variable) for o in flat_outputs]): raise ValueError('The all \'outputs\' must be Chainer Variable') network_outputs = OrderedDict( [(context.get_name(var), var) for var in flat_outputs]) if output_names: rename_variable_name( context, outputs, network_outputs, output_names) o = Graph(context, converters, opset_version, param_names | set(network_inputs.keys()), network_outputs) o.to_onnx_graph() implicit_input_names = set(context.implicit_inputs.keys()) for name in implicit_input_names: tensor = convert_parameter(context.implicit_inputs[name], context) initializers.append(tensor) input_tensors.append(helper.make_tensor_value_info( name, tensor.data_type, tensor.dims)) # If additional parameters are created during conversion for param in context.parameters: tensor = convert_parameter(param, context) initializers.append(tensor) input_tensors.append(helper.make_tensor_value_info( context.get_name(param), tensor.data_type, tensor.dims)) # Convert output tensors output_tensors = [] for name, var in network_outputs.items(): output_tensors.append(helper.make_tensor_value_info( name, NP_TYPE_TO_TENSOR_TYPE[var.dtype], var.shape)) if not export_params: initializers = [] onnx_graph = helper.make_graph( o.graph, graph_name, input_tensors, output_tensors, initializer=initializers) opset_imports = [helper.make_operatorsetid('', opset_version)] if external_opset_imports: chainer.utils.experimental('external_opset_imports') for domain, version in external_opset_imports.items(): opset_imports.append(helper.make_operatorsetid(domain, version)) model = helper.make_model( onnx_graph, producer_name='Chainer', producer_version=chainer.__version__, opset_imports=opset_imports ) model.ir_version = onnx.IR_VERSION check_onnx_model(model, external_converters, external_opset_imports) if input_shapes is not None: for output in model.graph.output: for d in output.type.tensor_type.shape.dim: d.Clear() model = shape_inference.infer_shapes(model) check_onnx_model(model, external_converters, external_opset_imports) if filename is not None and isinstance(filename, str): with open(filename, 'wb') as fp: fp.write(model.SerializeToString()) if save_text: with open(filename + '.txt', 'w') as fp: print(model, file=fp) elif hasattr(filename, 'write'): filename.write(model.SerializeToString()) if return_named_inout: chainer.utils.experimental('return_named_inout') return model, network_inputs, network_outputs return model def check_onnx_model(onnx_model, external_converters, external_opset_imports): try: checker.check_model(onnx_model) except onnx.checker.ValidationError as e: if external_converters is None: raise e else: # ONNX version >= 1.5: default checker skips schema check when # non standard domain is set. In ONNX-Chainer, external ops without # doamin is also accepted, but show warning. # ONNX version < 1.5: the checker does not skip schema check # regardless domain is set or not. In ONNX-Chainer, ignore # errors when external ops are set. if is_support_non_standard_domain(): if external_opset_imports: raise e else: warnings.warn( 'ValidationError is occurred but ignored. ' 'ONNX-Chainer recommends to set ' '`external_opset_imports` when using ' '`external_converters` on exporting. Please take care ' 'about ONNX format check is insufficient. Error ' 'message:\n{}'.format(str(e)), UserWarning) else: warnings.warn( 'ValidationError is occurred but ignored because ' 'exporting with `external_converters`. Please take care ' 'about ONNX format check is insufficient. Error ' 'message:\n{}'.format(str(e)), UserWarning)