Delete an arbitrary layer from an ONNX model

13019 단어 onnxtech

1. Environment

  • Python
  • onnx_graphsurgeon
  • https://github.com/cchen156/Learning-to-See-in-the-Dark
  • https://github.com/cchen156/Learning-to-See-in-the-Dark

2. Procedure


    $ python3 -m pip install onnx_graphsurgeon \
    --index-url https://pypi.ngc.nvidia.com
    
    import onnx_graphsurgeon as gs
    import onnx
    import argparse
    
    # Remove one named node from an ONNX model, reconnect its producer to its
    # output tensors, rewrite the graph output shape, and overwrite the model
    # file in place.
    #
    # Arguments:
    #   --onnx_file_path    path of the ONNX file to load and overwrite
    #   --remove_node_name  name of the node to bypass and delete
    parser = argparse.ArgumentParser()
    parser.add_argument("--onnx_file_path", required=True, type=str)
    parser.add_argument("--remove_node_name", required=True, type=str)
    args = parser.parse_args()
    
    graph = gs.import_onnx(onnx.load(args.onnx_file_path))
    
    # Print every node name so the caller can see the valid values for
    # --remove_node_name.
    for node in graph.nodes:
        print(node.name)
    
    # Pick the first node whose name matches; fail with a readable message
    # instead of a bare IndexError when nothing matches.
    remove_node = next(
        (node for node in graph.nodes if node.name == args.remove_node_name),
        None,
    )
    if remove_node is None:
        parser.error(
            "node '{}' was not found in '{}'".format(
                args.remove_node_name, args.onnx_file_path
            )
        )
    
    # Node.i() returns the producer of the node's first input tensor;
    # node.i() is equivalent to node.inputs[0].inputs[0].
    inp_node = remove_node.i()
    
    # Make the producer write directly to the removed node's output tensors,
    # then detach those tensors from the removed node so it is left dangling.
    inp_node.outputs = remove_node.outputs
    remove_node.outputs.clear()
    
    # cleanup() drops the now-disconnected node from the graph.
    graph.cleanup()
    
    # Height and width are read from axes 2 and 3 of the first graph input
    # (presumably an NCHW layout -- TODO confirm against the model).
    h = graph.inputs[0].shape[2]
    w = graph.inputs[0].shape[3]
    
    # NOTE(review): a 4-channel input selects a 2x output scale, anything else
    # 3x; presumably this distinguishes the Sony and Fuji models of the
    # referenced repository -- confirm.
    scale = 2 if graph.inputs[0].shape[1] == 4 else 3
    
    graph.outputs[0].shape = [1, 3, h * scale, w * scale]
    print(graph.outputs)
    
    # The result overwrites the input file.
    onnx.save(gs.export_onnx(graph), args.onnx_file_path)
    
    $ python3 remove_transpose.py \
    --onnx_file_path saved_model_sony_240x320/model_float32.onnx \
    --remove_node_name output__80
    
    import onnx_graphsurgeon as gs
    import onnx
    import argparse
    
    # Remove one named node by feeding the first graph input directly into
    # the node that follows it in graph.nodes, then overwrite the model file
    # in place.
    #
    # Arguments:
    #   --onnx_file_path    path of the ONNX file to load and overwrite
    #   --remove_node_name  name of the node to delete
    parser = argparse.ArgumentParser()
    parser.add_argument("--onnx_file_path", required=True, type=str)
    parser.add_argument("--remove_node_name", required=True, type=str)
    args = parser.parse_args()
    
    graph = gs.import_onnx(onnx.load(args.onnx_file_path))
    
    # Locate the first node with the requested name. Without this check a
    # missing name would leave the index at -1 and silently modify
    # graph.nodes[0] before crashing.
    for remove_node_idx, remove_node in enumerate(graph.nodes):
        if remove_node.name == args.remove_node_name:
            break
    else:
        parser.error(
            "node '{}' was not found in '{}'".format(
                args.remove_node_name, args.onnx_file_path
            )
        )
    
    # The rewiring below needs a node after the removed one.
    if remove_node_idx + 1 >= len(graph.nodes):
        parser.error(
            "node '{}' is the last node in the graph; "
            "there is no following node to reconnect".format(
                args.remove_node_name
            )
        )
    
    # NOTE(review): assumes the next node in graph.nodes order consumes the
    # removed node's output through its first input -- confirm for the model.
    next_node = graph.nodes[remove_node_idx + 1]
    
    # Copy the dtype before rewiring: after the assignment below,
    # next_node.inputs[0] is the graph input itself.
    graph.inputs[0].dtype = next_node.inputs[0].dtype
    next_node.inputs[0] = graph.inputs[0]
    
    # Detach the removed node's outputs so cleanup() discards it.
    remove_node.outputs.clear()
    graph.cleanup()
    
    # The result overwrites the input file.
    onnx.save(gs.export_onnx(graph), args.onnx_file_path)
    

    좋은 웹페이지 즐겨찾기