
Commit 88e9ebe

[PythonAPI] Adding input and output name options to the wrapping function
These two options help work around PyTorch's inconsistent naming of layers in its exported ONNX models. This fix allows the user to provide two lists of names for the inputs and outputs of the network.
2 parents e1bed92 + 56ffa3f commit 88e9ebe
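When `input_names`/`output_names` are omitted, PyTorch's ONNX exporter falls back to autogenerated tensor names (often opaque, e.g. "input.1" or bare numeric node ids), which is the inconsistency this commit works around. As a rough illustration of the override behaviour, here is a minimal pure-Python sketch; `resolve_io_names` is a hypothetical helper, not part of N2D2 or PyTorch:

```python
def resolve_io_names(autogenerated, user_names=None):
    """Keep the exporter's autogenerated names unless the user supplies
    an explicit list, which must match the number of tensors."""
    if user_names is None:
        return list(autogenerated)
    if len(user_names) != len(autogenerated):
        raise ValueError("expected {} names, got {}".format(
            len(autogenerated), len(user_names)))
    return list(user_names)

# Autogenerated ONNX-style default vs. an explicit user-provided name:
print(resolve_io_names(["input.1"]))            # ['input.1']
print(resolve_io_names(["input.1"], ["data"]))  # ['data']
```

With explicit names, downstream tools can address the network's inputs and outputs by stable identifiers instead of exporter-dependent ones.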

3 files changed

Lines changed: 15 additions & 9 deletions


docs/export/CPP_STM32.rst

Lines changed: 0 additions & 2 deletions
@@ -1,8 +1,6 @@
 Export: C++/STM32
 =================

-**N2D2-IP only: available upon request.**
-
 Export type: ``CPP_STM32``
 C++ export for STM32.

docs/quant/pruning.rst

Lines changed: 7 additions & 7 deletions
@@ -15,16 +15,16 @@ Example with Python
     :members:
     :inherited-members:

-Example of code to use the *PruneCell* in your scripts:
+Example of code to use the :py:class:`n2d2.quantizer.PruneCell` in your scripts:

 .. code-block:: python

     for cell in model:
-        ### Add Pruning ###
-        if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc):
-            cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct")
+        ### Add Pruning ###
+        if isinstance(cell, n2d2.cells.Conv) or isinstance(cell, n2d2.cells.Fc):
+            cell.quantizer = n2d2.quantizer.PruneCell(prune_mode="Static", threshold=0.3, prune_filler="IterNonStruct")

-Some explanations with the differents options of the *PruneCell*:
+Some explanations with the different options of the :py:class:`n2d2.quantizer.PruneCell`:

 Pruning mode
 ^^^^^^^^^^^^

@@ -42,7 +42,7 @@ For example, to update each two epochs, write:

     n2d2.quantizer.PruneCell(prune_mode="Gradual", threshold=0.3, stepsize=2*DATASET_SIZE)

-Where *DATASET_SIZE* is the size of the dataset you are using.
+Where ``DATASET_SIZE`` is the size of the dataset you are using.

 Pruning filler
 ^^^^^^^^^^^^^^

@@ -53,7 +53,7 @@ Pruning filler
 - IterNonStruct: all weights below the ``delta`` factor are pruned. If this is not enough to reach ``threshold``, all the weights below 2 ``delta`` are pruned and so on...


-**Important**: With *PruneCell*, ``quant_mode`` and ``range`` are not used.
+**Important**: With :py:class:`n2d2.quantizer.PruneCell`, ``quant_mode`` and ``range`` are not used.


 Example with INI file

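The "IterNonStruct" filler described in the pruning docs above (prune every weight below ``delta``, then below 2 ``delta``, and so on, until the ``threshold`` fraction of weights is zero) can be sketched in plain Python. This is a hypothetical illustration of the iteration scheme, not N2D2's actual implementation:

```python
def iter_non_struct_prune(weights, threshold, delta):
    """Zero out the smallest-magnitude weights in increasing cutoffs of
    delta, 2*delta, 3*delta, ... until at least `threshold` (a fraction
    in [0, 1]) of the weights are zero."""
    w = list(weights)
    n = len(w)
    cut = delta
    while sum(1 for x in w if x == 0) / n < threshold:
        w = [0 if abs(x) < cut else x for x in w]
        cut += delta
    return w

# Prune until half of the weights are zero, stepping the cutoff by 0.1:
print(iter_non_struct_prune([0.05, -0.2, 0.4, -0.6], threshold=0.5, delta=0.1))
# [0, 0, 0.4, -0.6]
```

The first pass (cutoff 0.1) removes only 0.05; since that leaves less than half of the weights pruned, the cutoff grows until -0.2 is removed as well.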
python/pytorch_to_n2d2/pytorch_interface.py

Lines changed: 8 additions & 0 deletions
@@ -331,6 +331,8 @@ def __exit__(self, exc_type, exc_value, traceback):
 def wrap(torch_model:torch.nn.Module,
          input_size: Union[list, tuple],
          opset_version:int=11,
+         in_names:list=None,
+         out_names:list=None,
          verbose:bool=False) -> Block:
     """Function generating a ``torch.nn.Module`` which embed a :py:class:`n2d2.cells.DeepNetCell`.
     The torch_model is exported to N2D2 via ONNX.

@@ -341,6 +343,10 @@ def wrap(torch_model:torch.nn.Module,
     :type input_size: ``list``
     :param opset_version: Opset version used to generate the intermediate ONNX file, default=11
     :type opset_version: int, optional
+    :param in_names: Specify names for the network inputs
+    :type in_names: list, optional
+    :param out_names: Specify names for the network outputs
+    :type out_names: list, optional
     :param verbose: Enable the verbose output of torch onnx export, default=False
     :type verbose: bool, optional
     :return: A custom ``torch.nn.Module`` which embed a :py:class:`n2d2.cells.DeepNetCell`.

@@ -361,6 +367,8 @@ def wrap(torch_model:torch.nn.Module,
         dummy_in,
         raw_model_path,
         verbose=verbose,
+        input_names=in_names,
+        output_names=out_names,
         export_params=True,
         opset_version=opset_version,
         training=torch.onnx.TrainingMode.TRAINING,

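The last hunk simply forwards the new user-facing options to the exporter's keywords. A minimal stand-in sketch of that forwarding, using a fake exporter instead of the real ``torch.onnx.export`` so the keyword mapping is visible (all names here are illustrative):

```python
# Records the keyword arguments the sketch forwards to the "exporter".
captured = {}

def fake_onnx_export(model, dummy_in, path,
                     input_names=None, output_names=None, **kwargs):
    # Stand-in for torch.onnx.export: only captures the name keywords.
    captured["input_names"] = input_names
    captured["output_names"] = output_names

def wrap_sketch(model, dummy_in, path, in_names=None, out_names=None):
    # Mirrors the patched call site: user-facing ``in_names``/``out_names``
    # map onto the exporter's ``input_names``/``output_names`` keywords.
    fake_onnx_export(model, dummy_in, path,
                     input_names=in_names, output_names=out_names)

wrap_sketch(None, None, "model.onnx",
            in_names=["image"], out_names=["logits"])
print(captured)  # {'input_names': ['image'], 'output_names': ['logits']}
```

Leaving both options at ``None`` preserves the previous behaviour, since the exporter then falls back to its autogenerated names.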