Delete any layer of ONNX
1. Environment
2. Procedure
$ python3 -m pip install onnx_graphsurgeon \
--index-url https://pypi.ngc.nvidia.com
import onnx_graphsurgeon as gs
import onnx
import argparse
# Remove one named node from an ONNX graph, connect its producer directly to
# its output tensors, patch the declared shape of the first graph output, and
# overwrite the model file in place.
parser = argparse.ArgumentParser()
parser.add_argument("--onnx_file_path", required=True, type=str)
parser.add_argument("--remove_node_name", required=True, type=str)
args = parser.parse_args()

graph = gs.import_onnx(onnx.load(args.onnx_file_path))

# List every node name so the available names can be checked on the console.
for i in graph.nodes:
    print(i.name)

# Locate the node to remove. Exit with a readable message instead of a bare
# IndexError when no node carries the requested name.
remove_node = next(
    (node for node in graph.nodes if node.name == args.remove_node_name),
    None,
)
if remove_node is None:
    parser.error(
        "no node named {!r} found in {}".format(
            args.remove_node_name, args.onnx_file_path
        )
    )

# Node provides i() and o() convenience functions that take an optional index
# (default 0). They replace fetching the input/output tensor first and then
# the node attached to that tensor: node.i() is equivalent to
# node.inputs[0].inputs[0].
inp_node = remove_node.i()

# Hand the removed node's output tensors to its producer, so that everything
# downstream is fed by inp_node and the removed node is skipped over.
inp_node.outputs = remove_node.outputs
remove_node.outputs.clear()

# The node no longer has outputs; cleanup() removes it from the graph.
graph.cleanup()

# Rewrite the declared shape of the first graph output from the dimensions of
# the first graph input.
# NOTE(review): assumes a 4-D input whose dims 2 and 3 are concrete integers
# (height, width) and that the 2x / 3x factor suits this particular model --
# confirm before reusing on another network.
h = graph.inputs[0].shape[2]
w = graph.inputs[0].shape[3]
scale = 2 if graph.inputs[0].shape[1] == 4 else 3
graph.outputs[0].shape = [1, 3, h * scale, w * scale]
print(graph.outputs)

# The result is written back over the input model file.
onnx.save(gs.export_onnx(graph), args.onnx_file_path)
python3 remove_transpose.py \
--onnx_file_path saved_model_sony_240x320/model_float32.onnx \
--remove_node_name output__80
import onnx_graphsurgeon as gs
import onnx
import argparse
# Remove one named node and make the node that follows it in graph.nodes read
# directly from the first graph input, then overwrite the model file in place.
parser = argparse.ArgumentParser()
parser.add_argument("--onnx_file_path", required=True, type=str)
parser.add_argument("--remove_node_name", required=True, type=str)
args = parser.parse_args()

graph = gs.import_onnx(onnx.load(args.onnx_file_path))

# Find the node to remove together with its position in the node list.
remove_node = None
remove_node_idx = -1
for idx, node in enumerate(graph.nodes):
    if node.name == args.remove_node_name:
        remove_node = node
        remove_node_idx = idx
        break

# Without this guard a missing name leaves remove_node_idx at -1, so the code
# below would rewire graph.nodes[0] and then fail on None.outputs.
if remove_node is None:
    parser.error(
        "no node named {!r} found in {}".format(
            args.remove_node_name, args.onnx_file_path
        )
    )

# The rewiring below needs a node after the removed one.
if remove_node_idx + 1 >= len(graph.nodes):
    parser.error(
        "node {!r} is the last node in the graph; "
        "there is no following node to reconnect".format(args.remove_node_name)
    )

# NOTE(review): this relies on the node stored right after remove_node in
# graph.nodes being its consumer -- confirm for the model at hand.
next_node = graph.nodes[remove_node_idx + 1]

# The graph input adopts the dtype the following node currently receives, and
# that node's first input is then replaced by the graph input itself.
graph.inputs[0].dtype = next_node.inputs[0].dtype
next_node.inputs[0] = graph.inputs[0]

# Detach the removed node's outputs so cleanup() can remove it from the graph.
remove_node.outputs.clear()
graph.cleanup()

# The result is written back over the input model file.
onnx.save(gs.export_onnx(graph), args.onnx_file_path)
Reference
이 문제에 관하여(Delete any layer of ONNX), 우리는 이곳에서 더 많은 자료를 발견하고 링크를 클릭하여 보았다: https://zenn.dev/pinto0309/articles/8cb106569c9c3e 텍스트를 자유롭게 공유하거나 복사할 수 있습니다. 하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
우수한 개발자 콘텐츠 발견에 전념 (Collection and Share based on the CC Protocol.)