Source code for onnx_chainer.onnx_helper

import onnx


# Name of the Chainer function currently being converted. It is set by
# set_func_name() and is not required to be unique across the whole graph.
__func_name = None


def set_func_name(func_name):
    """Record which Chainer function is currently being converted.

    The value is kept in the module-level ``__func_name`` variable and
    replaces whatever was stored there before.

    Args:
        func_name (str): Name of the Chainer function.
    """
    global __func_name
    __func_name = func_name
def get_func_name():
    """Return the name of the function that is currently being processed.

    Returns:
        The value stored in the module-level ``__func_name`` variable.

    Raises:
        AssertionError: If no name has been stored yet (the variable is
            still ``None``).
    """
    current_name = __func_name
    assert current_name is not None
    return current_name
def make_node(*args, **kwargs):
    """Build an ONNX node named after the function being converted.

    This is a thin wrapper of `onnx.helper.make_node`: every argument is
    forwarded unchanged and the ``name`` keyword is filled in with the
    value returned by ``get_func_name()``.

    Args:
        *args (tuple): ONNX node parameters of the node.
        **kwargs (dict): ONNX attributes of the node.

    Returns:
        An `onnx.NodeProto` object.
    """
    node_name = get_func_name()
    return onnx.helper.make_node(*args, name=node_name, **kwargs)
class GraphBuilder(object):
    """A helper class to build consecutive ONNX nodes.

    Nodes are collected in creation order and handed back by ``nodes()``.
    Generated names are prefixed with the function name returned by
    ``get_func_name()`` at construction time.
    """

    def __init__(self):
        self._nodes = []
        self._func_name = get_func_name()

    def node_name(self):
        """Return the name generated for the next node.

        Returns:
            A str of the form ``'<func_name>_tmp_<n>'`` where ``n`` is the
            number of nodes created so far.
        """
        return '{}_tmp_{}'.format(self._func_name, len(self._nodes))

    def op(self, op_name, input_names, num_outputs=1, **kwargs):
        """Creates a new ONNX node and returns its outputs.

        Output names are generated from ``node_name()``; when there is
        more than one output, an index suffix ``_<i>`` is appended.

        Args:
            op_name (str): The name of an ONNX op.
            input_names (list of str): The names of input values.
            num_outputs (int): The number of output values.
            **kwargs (dict): ONNX attributes of the node.

        Returns:
            A str of the output name when `num_outputs` is 1.
            A tuple of str of the output names otherwise.
        """
        if num_outputs == 1:
            output_names = [self.node_name()]
        else:
            output_names = ['{}_{}'.format(self.node_name(), i)
                            for i in range(num_outputs)]

        return self.op_output_named(
            op_name, input_names, output_names, **kwargs)

    def op_output_named(
            self, op_name, input_names, output_names, **kwargs):
        """Creates a new ONNX node with output names, and returns its outputs.

        Args:
            op_name (str): The name of an ONNX op.
            input_names (list of str): The names of input values.
            output_names (list of str): The names of output values.
            **kwargs (dict): ONNX attributes of the node.

        Returns:
            A str of the output name when number of output is 1.
            A tuple of str of the output names otherwise.

        Raises:
            AssertionError: If `input_names` or `output_names` is a bare
                str instead of a sequence of names.
        """
        # Prevent a common mistake. `input_names="input"` creates a
        # node with 5 inputs.
        assert not isinstance(input_names, str), \
            'input_names must be a sequence of names, not a str'
        # The same mistake on the output side would silently create one
        # output per character.
        assert not isinstance(output_names, str), \
            'output_names must be a sequence of names, not a str'

        node = onnx.helper.make_node(
            op_name, input_names, output_names,
            name=self.node_name(), **kwargs)
        self._nodes.append(node)
        if len(output_names) == 1:
            return node.output[0]
        else:
            return tuple(node.output)

    def nodes(self, output_names=None):
        """Returns all nodes created so far.

        Args:
            output_names (list of str): The names of output values to be
                set at the last node. When given, at least one node must
                already exist and the number of names must match the
                number of outputs of the last node.

        Returns:
            A tuple of `onnx.NodeProto` objects, suitable as the return
            value of converter functions.
        """
        if output_names is not None:
            assert len(self._nodes[-1].output) == len(output_names)
            # Rename the outputs of the last node in place.
            self._nodes[-1].output[:] = output_names
        return tuple(self._nodes)
def write_tensor_pb(filename, name, value):
    """Serialize an array as an ONNX tensor and write it to a file.

    Args:
        filename (str): Path of the file to write. It is opened in binary
            write mode, so an existing file is overwritten.
        name (str): Tensor name handed to `onnx.numpy_helper.from_array`.
        value: Array handed to `onnx.numpy_helper.from_array`
            (presumably a ``numpy.ndarray`` -- confirm against callers).
    """
    # Convert and serialize before touching the filesystem so that a
    # failing conversion does not leave an empty or truncated file behind.
    tensor = onnx.numpy_helper.from_array(value, name)
    payload = tensor.SerializeToString()
    with open(filename, 'wb') as f:
        f.write(payload)
def cleanse_param_name(name):
    """Derive an ONNX name from a Chainer parameter name.

    Every ``/`` in the name is replaced by ``_`` and the result is
    prefixed with ``param``. No other character is altered.

    Note:
        The original author notes that ONNX identifiers must be valid C
        identifiers; only the ``/`` separator is handled here.

    Args:
        name (str): A Chainer parameter name (e.g., /l/W).

    Returns:
        The converted name (e.g., param_l_W).
    """
    flattened = name.replace('/', '_')
    return 'param{}'.format(flattened)
def is_support_non_standard_domain():
    """Tell whether the installed ONNX accepts ops in non-standard domains.

    From ONNX 1.5, the schema check is skipped on ops in a non-standard
    domain, see: https://github.com/onnx/onnx/pull/1876

    Returns:
        bool: ``True`` when ``onnx.__version__`` is 1.5 or newer,
        ``False`` otherwise.

    Raises:
        RuntimeError: If the major and minor components cannot be read
            from ``onnx.__version__``.
    """
    # This checker expects that ONNX adopts semantic versioning.
    components = onnx.__version__.split('.')
    if len(components) < 2 or not all(
            part.isdecimal() for part in components[:2]):
        raise RuntimeError(
            'ONNX-Chainer cannot get major and minor version ONNX module: '
            '{}'.format(onnx.__version__))
    major, minor = int(components[0]), int(components[1])
    # Compare as an integer pair so that later major versions (2.0, ...)
    # are recognized as newer than 1.5 as well.
    return (major, minor) >= (1, 5)