Batch Normalization

Given samples $\{x_i\}$ over a mini-batch $\mathcal{B} = \{x_{1 \dots N}\}$, batch normalization is

$$
y_i = f(x_i, \mathcal{B})_{\gamma,\beta} = \gamma \hat{x}_i + \beta = \gamma \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
$$

where:

  • $y_i$ is the output of the batch normalization layer.
  • $\hat{x}_i$ is the normalized input $x_i$.
  • $\epsilon$ is a small constant added to avoid division by zero.
  • $\gamma$ and $\beta$ are parameters learned during training for each feature, representing the scale and shift applied after normalization, respectively.

The detailed process is:

$$
\begin{aligned}
\mu_B &\leftarrow \frac{1}{m}\sum_{i=1}^m x_i && \text{// mini-batch mean} \\
\sigma_B^2 &\leftarrow \frac{1}{m-1}\sum_{i=1}^m (x_i - \mu_B)^2 && \text{// mini-batch variance} \\
\hat{x}_i &\leftarrow \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} && \text{// normalize} \\
y_i &\leftarrow \gamma \hat{x}_i + \beta \equiv \mathrm{BN}_{\gamma,\beta}(x_i) && \text{// scale and shift}
\end{aligned}
$$
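
As a rough illustration of these four steps, here is a minimal NumPy sketch of the forward pass for a single feature; the function name `batchnorm_forward` and the toy inputs are my own choices, not from any particular library. Note it uses the unbiased variance with $m-1$ in the denominator, matching the formula above, whereas many frameworks divide by $m$.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Batch-norm forward pass for a single feature over a mini-batch x of shape (m,)."""
    m = x.shape[0]
    mu = x.mean()                          # mini-batch mean
    var = ((x - mu) ** 2).sum() / (m - 1)  # mini-batch variance (unbiased, as above)
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize
    y = gamma * x_hat + beta               # scale and shift
    return y, (x_hat, var)

x = np.array([1.0, 2.0, 4.0, 7.0])
y, _ = batchnorm_forward(x, gamma=2.0, beta=0.5)
print(y)
```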

Derivation of the batch norm gradient

We have:

$$
\frac{\partial L}{\partial x_i} = \sum_{j=1}^N \frac{\partial L}{\partial y_j}\frac{\partial y_j}{\partial x_i}.
$$

$\frac{\partial L}{\partial y_j}$ is the upstream gradient, so it's already given here.

Since $y_j = \gamma \hat{x}_j + \beta$, $y_j$ is a function of $\hat{x}_j$. We have $\frac{\partial y_j}{\partial x_i} = \frac{\partial y_j}{\partial \hat{x}_j}\frac{\partial \hat{x}_j}{\partial x_i}$, and $\frac{\partial y_j}{\partial \hat{x}_j} = \gamma$. So $\frac{\partial y_j}{\partial x_i} = \gamma \frac{\partial \hat{x}_j}{\partial x_i}$. This means we only need to calculate $\frac{\partial \hat{x}_j}{\partial x_i}$.

Recall that $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$; we obtain

$$
\frac{\partial \hat{x}_j}{\partial x_i} = \frac{\partial}{\partial x_i}\left(\frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}}\right) = \left(\delta_{ij} - \frac{\partial \mu}{\partial x_i}\right)(\sigma^2 + \epsilon)^{-\frac{1}{2}} + (x_j - \mu)\left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-\frac{3}{2}}\frac{\partial \sigma^2}{\partial x_i},
$$

where $\delta_{ij}$ is the Kronecker delta: $\delta_{ij} = \begin{cases} 0 & \text{if } i \neq j \\ 1 & \text{if } i = j \end{cases}$.

Therefore, we need to compute the derivatives $\frac{\partial \mu}{\partial x_i}$ and $\frac{\partial \sigma^2}{\partial x_i}$.

For the former:

$$
\frac{\partial \mu}{\partial x_i} = \frac{\partial}{\partial x_i}\left(\frac{1}{N}(x_1 + x_2 + \cdots + x_i + \cdots + x_N)\right) = \frac{1}{N}(0 + \cdots + 1 + \cdots + 0) = \frac{1}{N}.
$$

For the latter:

$$
\begin{aligned}
\frac{\partial \sigma^2}{\partial x_i} &= \frac{\partial}{\partial x_i}\left(\frac{1}{N-1}\sum_{k=1}^N (x_k - \mu)^2\right) \\
&= \frac{1}{N-1}\,\frac{\partial}{\partial x_i}\left((x_1 - \mu)^2 + \cdots + (x_i - \mu)^2 + \cdots\right) \\
&= \frac{1}{N-1}\left[2(x_1 - \mu)\frac{\partial(-\mu)}{\partial x_i} + 2(x_2 - \mu)\frac{\partial(-\mu)}{\partial x_i} + \cdots + 2(x_i - \mu)\frac{\partial(x_i - \mu)}{\partial x_i} + \cdots\right] \\
&= \frac{1}{N-1}\left[-2(x_1 - \mu)\frac{1}{N} - 2(x_2 - \mu)\frac{1}{N} + \cdots + 2(x_i - \mu)\left(1 - \frac{1}{N}\right) + \cdots\right] \\
&= \frac{1}{N-1}\cdot 2\left[\underbrace{\sum_{k=1}^N (x_k - \mu)}_{=\,0}\left(-\frac{1}{N}\right) + x_i - \mu\right] \\
&= \frac{1}{N-1}\cdot 2\,(x_i - \mu) = \frac{2}{N-1}(x_i - \mu)
\end{aligned}
$$

The transition from the third line to the fourth line holds because $\frac{\partial \mu}{\partial x_i} = \frac{1}{N}$.
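
These two derivatives can be sanity-checked numerically. The sketch below is my own illustration (arbitrary sample values), comparing the closed forms against central finite differences:

```python
import numpy as np

# Finite-difference check of dmu/dx_i = 1/N and dsigma^2/dx_i = 2/(N-1) * (x_i - mu).
x = np.array([1.0, 2.0, 4.0, 7.0])   # arbitrary sample values
N, i, h = x.size, 2, 1e-6

mean = lambda v: v.mean()
var = lambda v: ((v - v.mean()) ** 2).sum() / (N - 1)

for fn, analytic in [(mean, 1 / N), (var, 2 / (N - 1) * (x[i] - x.mean()))]:
    xp, xm = x.copy(), x.copy()
    xp[i] += h
    xm[i] -= h
    numeric = (fn(xp) - fn(xm)) / (2 * h)
    print(f"numeric={numeric:.6f}  analytic={analytic:.6f}")  # should agree closely
```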

So we have:

$$
\begin{aligned}
\frac{\partial \hat{x}_j}{\partial x_i} &= \left(\delta_{ij} - \frac{1}{N}\right)(\sigma^2 + \epsilon)^{-\frac{1}{2}} + (x_j - \mu)\left(-\frac{1}{2}\right)(\sigma^2 + \epsilon)^{-\frac{3}{2}}\cdot\frac{2}{N-1}(x_i - \mu) \\
&= \left(\delta_{ij} - \frac{1}{N}\right)(\sigma^2 + \epsilon)^{-\frac{1}{2}} - \frac{1}{N-1}(\sigma^2 + \epsilon)^{-\frac{1}{2}}\,\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\cdot\frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
&= (\sigma^2 + \epsilon)^{-\frac{1}{2}}\left[\delta_{ij} - \frac{1}{N} - \frac{1}{N-1}\,\hat{x}_i\hat{x}_j\right] \\
&= \frac{(\sigma^2 + \epsilon)^{-\frac{1}{2}}}{N}\left[N\delta_{ij} - 1 - \frac{N}{N-1}\,\hat{x}_i\hat{x}_j\right].
\end{aligned}
$$
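
The closed form for $\frac{\partial \hat{x}_j}{\partial x_i}$ can also be verified numerically. This is an illustrative sketch (names like `x_hat_of` are my own): it builds the $N \times N$ Jacobian from the expression above and compares it against central finite differences.

```python
import numpy as np

def x_hat_of(x, eps=1e-5):
    """Normalization step only: (x - mu) / sqrt(var + eps), with unbiased variance."""
    mu = x.mean()
    var = ((x - mu) ** 2).sum() / (x.size - 1)
    return (x - mu) / np.sqrt(var + eps)

x = np.array([1.0, 2.0, 4.0, 7.0])
N, eps, h = x.size, 1e-5, 1e-6
var = ((x - x.mean()) ** 2).sum() / (N - 1)
xh = x_hat_of(x, eps)

# Analytic Jacobian J[j, i] = d x_hat_j / d x_i from the closed form above.
J = (var + eps) ** -0.5 / N * (N * np.eye(N) - 1 - N / (N - 1) * np.outer(xh, xh))

# Numerical Jacobian via central differences.
J_num = np.zeros((N, N))
for i in range(N):
    xp, xm = x.copy(), x.copy()
    xp[i] += h
    xm[i] -= h
    J_num[:, i] = (x_hat_of(xp, eps) - x_hat_of(xm, eps)) / (2 * h)

print(np.max(np.abs(J - J_num)))  # should be close to zero
```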

As a result:

$$
\begin{aligned}
\frac{\partial L}{\partial x_i} &= \sum_{j=1}^N \frac{\partial L}{\partial y_j}\frac{\partial y_j}{\partial x_i} = \sum_{j=1}^N \frac{\partial L}{\partial y_j}\frac{\partial y_j}{\partial \hat{x}_j}\frac{\partial \hat{x}_j}{\partial x_i} \\
&= \sum_{j=1}^N \frac{\partial L}{\partial y_j}\,\gamma\,\frac{(\sigma^2 + \epsilon)^{-\frac{1}{2}}}{N}\left[N\delta_{ij} - 1 - \frac{N}{N-1}\,\hat{x}_i\hat{x}_j\right] \\
&= \frac{\gamma(\sigma^2 + \epsilon)^{-\frac{1}{2}}}{N}\sum_{j=1}^N \frac{\partial L}{\partial y_j}\left[N\delta_{ij} - 1 - \frac{N}{N-1}\,\hat{x}_i\hat{x}_j\right] \\
&= \frac{\gamma(\sigma^2 + \epsilon)^{-\frac{1}{2}}}{N}\left[\left(N\sum_{j=1}^N \frac{\partial L}{\partial y_j}\delta_{ij}\right) - \left(\sum_{j=1}^N \frac{\partial L}{\partial y_j}\right) - \left(\frac{N}{N-1}\,\hat{x}_i\sum_{j=1}^N \frac{\partial L}{\partial y_j}\hat{x}_j\right)\right] \\
&= \frac{\gamma(\sigma^2 + \epsilon)^{-\frac{1}{2}}}{N}\left(N\frac{\partial L}{\partial y_i} - \sum_{j=1}^N \frac{\partial L}{\partial y_j} - \frac{N}{N-1}\,\hat{x}_i\sum_{j=1}^N \frac{\partial L}{\partial y_j}\hat{x}_j\right).
\end{aligned}
$$
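
Finally, this expression translates directly into code. The sketch below is a minimal illustration (my own function names and a toy loss, not a reference implementation), with a finite-difference check that the analytic gradient matches:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean()
    var = ((x - mu) ** 2).sum() / (x.size - 1)   # unbiased variance, as above
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta, (x_hat, var, gamma, eps)

def batchnorm_backward(dy, cache):
    """dL/dx_i from the final formula; dy holds the upstream gradients dL/dy_j."""
    x_hat, var, gamma, eps = cache
    N = dy.size
    return gamma * (var + eps) ** -0.5 / N * (
        N * dy - dy.sum() - N / (N - 1) * x_hat * (dy * x_hat).sum()
    )

x = np.array([1.0, 2.0, 4.0, 7.0])
w = np.array([0.3, -1.0, 0.7, 0.2])   # toy loss L = sum(w * y), so dL/dy = w
gamma, beta, h = 2.0, 0.5, 1e-6

y, cache = batchnorm_forward(x, gamma, beta)
dx = batchnorm_backward(w, cache)

# Finite-difference gradient of L with respect to each x_i.
dx_num = np.zeros_like(x)
for i in range(x.size):
    xp, xm = x.copy(), x.copy()
    xp[i] += h
    xm[i] -= h
    Lp = np.sum(w * batchnorm_forward(xp, gamma, beta)[0])
    Lm = np.sum(w * batchnorm_forward(xm, gamma, beta)[0])
    dx_num[i] = (Lp - Lm) / (2 * h)

print(dx)
print(dx_num)   # should agree closely with dx
```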