(Caffe, LeNet) 가중치 업데이트 (7)
Solver:: ApplyUpdate () 함수 에서 역방향 전파 단계 에 따라 계 산 된 loss 가 네트워크 가중치 에 대한 편향 에 따라 설정 한 학습 전략 을 사용 하여 네트워크 가중치 를 업데이트 하여 이번 학습 을 완성 합 니 다.
1 모델 최적화
1.1 손실 함수
손실 함수 L (W) 은 경험 치 손실 에 정규 화 항목 을 추가 하여 얻 을 수 있 습 니 다. 다음 과 같 습 니 다. 그 중에서 X (i) 는 입력 견본 입 니 다.fW 는 특정한 견본 의 손실 함수 이다.N 은 mini - batch 의 견본 수량 입 니 다.r (W) 를 가중치 로 합 니 다.λ 의 정규 항.
L(W)≈1N∑NifW(X(i))+λr(W)
caffe 에서 세 단계 로 나 눌 수 있 습 니 다.
lenet 에서 solver 의 유형 은 SGD (Stochastic gradient descent) 입 니 다.
SGD 는 다음 과 같은 공식 을 통 해 가중치 를 업데이트 합 니 다.
Wt+1=Wt+Vt+1 Vt+1=μVt−α∇L(Wt)
그 중에서 Wt + 1 은 제 t + 1 라운드 의 가중치 이다.Vt + 1 은 t + 1 라운드 업데이트 (쓰기 도 가능)ΔWt+1 ); μ 지난 라운드 업 데 이 트 를 위 한 가중치;α 학습 율 을 위 하여;∇ L (Wt) 은 loss 의 가중치 에 대한 가이드 입 니 다.
2 코드 분석
2.1 ApplyUpdate
void SGDSolver<Dtype>::ApplyUpdate() {
// (learning rate)
Dtype rate = GetLearningRate();
//
// lenet , `conv1`,`conv2`,`ip1`,`ip2`
//
// `learnable_params_` size 8.
for (int param_id = 0; param_id < this->net_->learnable_params().size();
++param_id) {
// ,iter_size 1 , lenet 。
// , iter_size 1 iter_size
Normalize(param_id);
//
Regularize(param_id);
// \delta w
ComputeUpdateValue(param_id, rate);
}
//
this->net_->Update();
}
설명:
lenet_solver.prototxt
에서 찾 을 수 있 습 니 다
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
inv
의 전략 으로 교체 학 습 률 이 모두 바 뀌 지 않 는 전략 이다. // The learning rate decay policy. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
2.2 Regularize
이 함 수 는 실제 다음 과 같은 공식 을 집행 한다.
∂loss∂wij=decay∗wij+∂loss∂wij
코드 는 다음 과 같 습 니 다:
void SGDSolver::Regularize(int param_id) {
const vector *>& net_params = this->net_->learnable_params();
const vector<float>& net_params_weight_decay =
this->net_->params_weight_decay();
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
// local_decay = 0.0005 in lenet
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
...
if (regularization_type == "L2") {
// axpy means ax_plus_y. i.e., y = a*x + y
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
}
...
}
2.3 ComputeUpdateValue
이 함수 실제 실행 아래 공식 vij = lrrate∗∂loss∂wij+momentum∗vij ∂loss∂wij=vij
코드 는 다음 과 같 습 니 다:
void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) {
const vector *>& net_params = this->net_->learnable_params();
const vector<float>& net_params_lr = this->net_->params_lr();
// momentum = 0.9 in lenet
Dtype momentum = this->param_.momentum();
// local_rate = lr_mult * global_rate
// lr_mult , lenet_train_test.prototxt
Dtype local_rate = rate * net_params_lr[param_id];
// Compute the update to history, then copy it to the parameter diff.
...
// axpby means ax_plus_by. i.e., y = ax + by
// \delta w,
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
// \delta w diff
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
...
}
2.4 net_->Update
실제 실행 다음 공식: wij = wij + (− 1)
caffe_axpy(count_, Dtype(-1),
static_cast<const Dtype*>(diff_->cpu_data()),
static_cast(data_->mutable_cpu_data()));
참고 문헌:
[1]. http://caffe.berkeleyvision.org/tutorial/solver.html
이 내용에 흥미가 있습니까?
현재 기사가 여러분의 문제를 해결하지 못하는 경우 AI 엔진은 머신러닝 분석(스마트 모델이 방금 만들어져 부정확한 경우가 있을 수 있음)을 통해 가장 유사한 기사를 추천합니다:
Windows 10의 Ubuntu에 Caffe를 배치하고 Single Shot MultiBox Detector 실행예전에는 Ubuntu on Windows 10에서 카페가 LMDB의 동작에 문제가 있었지만 지금은 동작의 상태인 것 같다.저는 버클리 대학의 카페입니다. Intel은 자체 CPU에서 빠른 컴퓨팅을 위한 컴파일러 및 ...
텍스트를 자유롭게 공유하거나 복사할 수 있습니다.하지만 이 문서의 URL은 참조 URL로 남겨 두십시오.
CC BY-SA 2.5, CC BY-SA 3.0 및 CC BY-SA 4.0에 따라 라이센스가 부여됩니다.