Tensorflow 회전 ONNX

2731 단어 Coooding
참고 글:
https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb
https://www.jianshu.com/p/8ec3a6c9c453
TensorFlow 프레임 워 크 트 레이 닝 모델 을 사용 하고 ONNX 형식 으로 내 보 내 려 면 다음 과 같은 몇 가지 절 차 를 거 쳐 야 합 니 다.
훈련(훈련)
PB 파일 돌리 기(그래프 동결)
모델 형식 변환(모델 변환)
훈련(훈련)
TF model 을 ONNX model 로 성공 적 으로 전환 하기 위해 서 는 3 가지 정 보 를 준비 해 야 합 니 다.
1.TF 의 그림 정의,즉 네트워크 토폴로지 정보 만 포함 합 니 다(가중치 정보 포함 하지 않 음).가 져 오 는 방법 은 inference 코드 에 다음 코드 를 삽입 하여 출력 하 는 것 입 니 다.
with open("net.proto", "wb") as file:
    graph = tf.get_default_graph().as_graph_def(add_shapes=True)
    file.write(graph.SerializeToString())

2.형상 정보.기본 상황 에서 asgraph_def()방법 은 모양 정 보 를 내 보 내지 않 습 니 다.지정 한 add 를 통 해shapes 인 자 는 True 로 출력 을 강제 합 니 다.방법 은 상기 코드 를 참고 하 십시오.
3.파일 검사(checkpoint,ckpt).즉,우리 가 평소에 말 하 는 가중치 파일 입 니 다.TensorFlow 는 보통 checkpoint 파일 로 내 보 냅 니 다.
한 마디 로 하면 상기 몇 단 계 를 통 해 우 리 는 네트워크 토폴로지 정보 와 모양 정 보 를 기록 하 는*.proto 파일,그리고 가중치 의 checkpoint 파일 을 기록 할 수 있 습 니 다.
PB 파일 돌리 기(그래프 동결)
위 와 같이 일반적인 TF 의 모델 내 보 내기 방법 은 네트워크 정보 와 가중치 정 보 를 분리 하여 서로 다른 파일 에 저장 하 는데 이것 은 배치 할 때 편리 하지 않다.공식 적 으로 Freeze Graph 방식 을 제공 하여 모델 관련 정 보 를 모두*.pb 파일 에 포장 하 는 데 사용 합 니 다.
정부 에서 관련 도 구 를 제공 하 였 습 니 다 freezegraph,일반적으로 TensorFlow 를 설치 한 후 사용자 PATH 에 해당 하 는 bin 디 렉 터 리 에 자동 으로 추 가 됩 니 다.찾 지 못 하면 TensorFlow 소스 코드 tensorflow/python/tools/free 로 이동 할 수 있 습 니 다.graph.py 이 위 치 를 찾 아 보 거나 명령 행 을 통 해 module 를 가 져 오 는 방식 으로 호출 합 니 다.
예 를 들 어 다음 과 같이 출력 노드 가 여러 개 있 으 면 쉼표 로 구분 합 니 다.
# 1.    
freeze_graph --input_graph=/home/mnist-tf/graph.proto \
    --input_checkpoint=/home/mnist-tf/ckpt/model.ckpt \
    --output_graph=/tmp/frozen_graph.pb \
    --output_node_names=fc2/add \
    --input_binary=True

# 2.     module   
python -m tensorflow.python.tools.freeze_graph \
    --input_graph=my_checkpoint_dir/graphdef.pb \
    --input_binary=true \
    --output_node_names=output \
    --input_checkpoint=my_checkpoint_dir \
    --output_graph=tests/models/fc-layers/frozen.pb

그 중에서 가장 곤란 한 점 은 출력 노드 의 노드 이름 을 정확하게 알 아야 한 다 는 것 이다.나의 방법 은 tf.get 을 통 해default_graph().as_graph_def().node 는 각 노드 정 보 를 얻 은 다음 에 구체 적 인 출력 노드 이름 을 봅 니 다.
print([tensor.name for tensor in tf.get_default_graph().as_graph_def().node])

모델 형식 변환(모델 변환)
쓰다https://github.com/onnx/tensorflow-onnxtf2onx 도 구 를 제공 합 니 다.예 는 다음 과 같 습 니 다.
python -m tf2onnx.convert\
    --input tests/models/fc-layers/frozen.pb\
    --inputs X:0\
    --outputs output:0\
    --output tests/models/fc-layers/model.onnx\
    --verbose

모델 입 출력 이름 은 nodename:port_id 의 형식 입 니 다.그렇지 않 으 면 나중에 오류 가 발생 할 수 있 습 니 다.
큰 성 과 를 거두다.

좋은 웹페이지 즐겨찾기