Tensorflow 에 사용자 정의 Op 추가
13797 단어 TensorFlow딥 러 닝 프레임 워 크딥 러 닝python
예비 기능
그러나 기 존의 라 이브 러 리 에 원 하 는 동작 이 포함 되 어 있 지 않 으 면 스스로 하 나 를 맞 출 수 있 습 니 다. 맞 춤 형 Op 가 기 존의 라 이브 러 리 를 호 환 할 수 있 도록 다음 과 같은 작업 을 해 야 합 니 다.
Op 인터페이스 정의
먼저 Tensorflow 를 설치 합 니 다. 이 명령
pip install tensorflow
을 설치 하거나 원본 코드 에 따라 최신 Tensorflow 를 컴 파일 하여 설치 할 수 있 습 니 다.사용자 정의 Op 은 tensorflow
소스 코드 를 수정 해 야 합 니 다. 먼저 TensorFlow
시스템 에 등록 하여 Op
인 터 페 이 스 를 정의 해 야 합 니 다. 등록 할 때 Op
의 이름, 입력 (유형 과 이름) 과 출력 (유형 과 이름), 그리고 필요 한 속성 에 대한 문서 설명 을 지정 해 야 합 니 다.직관 적 인 인식 을 가지 기 위해 간단 한
Op
을 만 드 는 것 을 예 로 들 자. 이 Op
는 두 가지 int32
유형 tensor
을 입력 으로 받 아들 이 고 이 두 가지 tensor
를 출력 하 는
과
의 유일한 차이 점 은 첫 번 째 요소 가 0
로 설정 되 어 있 는 것 이다. 파일 tensorflow/core/user_ops/my_add.cc
을 만 들 고 호출 REGISTER_OP
하 는 것 이다.매크로 는 Op 의 인 터 페 이 스 를 정의 합 니 다.#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("MyAdd")
.Input("x: int32")
.Input("y: int32")
.Output("z: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output(0, c->input(1));
return Status::OK();
});
#include "tensorflow/core/framework/op_kernel.h"
using namespace tensorflow;
class MyAddOp : public OpKernel {
public:
explicit MyAddOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// Grab the input tensor
const Tensor& a = context->input(0);
const Tensor& b = context->input(1);
auto A = a.flat();
auto B = b.flat();
// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, a.shape(),
&output_tensor));
auto output_flat = output_tensor->flat();
// Set all but the first element of the output tensor to 0.
const int N = A.size();
for (int i = 1; i < N; i++) {
output_flat(i) = A(i)+B(i);
}
output_flat(0) = 0;
}
};
REGISTER_KERNEL_BUILDER(Name("MyAdd").Device(DEVICE_CPU), MyAddOp);
cmake 컴 파일
그리고
tensorflow/core/user_ops/
디 렉 터 리 에서 다음 명령 을 실행 하여 파일 을 컴 파일 합 니 다 .so
.TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared my_add.cc -o my_add.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
호스트
GCC
버 전이 5.X
라면 마지막 컴 파일 명령 을 다음 으로 바 꾸 십시오.g++ -std=c++11 -shared my_add.cc -o my_add.so -fPIC -D_GLIBCXX_USE_CXX11_ABI=0 ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2
다른 유형의 컴 파일 명령 은 point - net tf 에서 나 왔 습 니 다.op:
/usr/local/cuda-8.0/bin/nvcc tf_grouping_g.cu -o tf_grouping_g.cu.o -c -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
g++ -std=c++11 tf_grouping.cpp tf_grouping_g.cu.o -o tf_grouping_so.so -shared -fPIC -I /usr/local/lib/python2.7/dist-packages/tensorflow/include -I /usr/local/cuda-8.0/include -I /usr/local/lib/python2.7/dist-packages/tensorflow/include/external/nsync/public -lcudart -L /usr/local/cuda-8.0/lib64/ -L/usr/local/lib/python2.7/dist-packages/tensorflow -ltensorflow_framework -O2 -D_GLIBCXX_USE_CXX11_ABI=0
Python 3.5 컴 파일 스 크 립 트:
#/bin/bash
CUDA_ROOT=/usr/local/cuda-9.2
TF_ROOT=/home/user/.local/lib/python3.5/site-packages/tensorflow
/usr/local/cuda-9.2/bin/nvcc -std=c++11 -c -o tf_sampling_g.cu.o tf_sampling_g.cu -O2 -DGOOGLE_CUDA=1 -x cu -Xcompiler -fPIC
#TF 1.8
g++ -std=c++11 tf_sampling.cpp tf_sampling_g.cu.o -o tf_sampling_so.so -shared -fPIC -I ${TF_ROOT}/include -I ${CUDA_ROOT}/include -I ${TF_ROOT}/include/external/nsync/public -lcudart -L ${CUDA_ROOT}/lib64/ -L ${TF_ROOT} -ltensorflow_framework -O2 #-D_GLIBCXX_USE_CXX11_ABI=0
bazel 컴 파일
bazel build -c opt //tensorflow/core/user_ops:my_add.so
호스트
GCC
버 전이 5.X
라면 컴 파일 명령 을 다음 으로 바 꾸 십시오.bazel build -c opt --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" //tensorflow/core/user_ops:my_add.so
이때 발생 한
my_add.so
은 tensorflow/bazel-bin/tensorflow/core/user_ops
디 렉 터 리 아래 에 있 습 니 다.다음은 필요 할 때 이
Op
를 사용 할 수 있 습 니 다.import tensorflow as tf
so_file = 'your_add_so_file_path/my_add.so'
class MyAddTest(tf.test.TestCase):
def testMyAdd(self):
my_add_module = tf.load_op_library(so_file)
with self.test_session():
result = my_add_module.my_add([5, 4, 3, 2, 1],[1, 2, 3, 4, 5])
self.assertAllEqual(result.eval(), [0, 6, 6, 6, 6])
if __name__ == "__main__":
#tf.test.main()
my_add_module = tf.load_op_library(so_file)
out = my_add_module.my_add([5, 4, 3, 2, 1],[1, 2, 3, 4, 5])
sess = tf.Session()
result = sess.run(out)
print(result)
#output [0, 6, 6, 6, 6]
참고:
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
EMNIST에서 알파벳 필기 인식EMNIST-letters를 배웠습니다. CODE: DEMO: — mbotsu (@mb_otsu) 은 2017년에 NIST가 공개한 데이터세트입니다. EMNIST ByClass: 814,255 characters. ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.