Variance-regularized SSL maps neatly onto classical learning rules

Consider a one-layer neural network with $N$ input units and $M$ output units. Let $x(t) \in \mathbb{R}^N$ be the input to the network at time $t$, $W \in \mathbb{R}^{M \times N}$ the feedforward matrix to be learned, $a(t) = W x(t) \in \mathbb{R}^M$ the pre-activations, and $z_i(t) = f(a_i(t))$ the activity of the $i$th output neuron at time $t$.
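As a running example, this setup can be sketched in a few lines of NumPy (the dimensions and the choice of $f$ are arbitrary assumptions, not taken from the text):

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 8, 4                       # input and output dimensions (arbitrary)

W = rng.standard_normal((M, N))   # feedforward weights W, shape M x N
f = np.tanh                       # example nonlinearity

x = rng.standard_normal(N)        # input x(t)
a = W @ x                         # pre-activations a(t) = W x(t)
z = f(a)                          # output activities z_i(t) = f(a_i(t))
```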

Variance term

$$\mathcal{L}_{\mathrm{var}}(t) = \sum_{i=1}^{M} \mathrm{ReLU}\left(1 - \sigma_{z_i}\right) = \sum_{i=1}^{M} \mathrm{ReLU}\left(1 - \sqrt{\alpha\left[(z_i(t) - \bar{z}_i)^2 + \Phi\right] + \epsilon}\right)$$

where $\Phi$ represents the sum of the corresponding terms $(z_i(t') - \bar{z}_i)^2$ for all other samples $t'$ in a minibatch, and $\alpha$ the corresponding averaging factor ($\frac{1}{B-1}$ for a minibatch of size $B$). Assuming we can estimate the variance, and hence the standard deviation, online with a slow-moving filter, we set $\Phi = 0$ so that the term under the square root is now the one-sample contribution to an estimate of the variance of the unit's activity (assuming also that a reliable estimate of the mean activity $\bar{z}_i$ is available).
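One way to realize such a slow-moving filter is an exponential moving average. A minimal sketch (the function name and time constant `tau` are my assumptions, not from the text):

```python
import numpy as np

def ema_stats(z, z_mean, z_var, tau=0.01):
    """One step of slow online estimates of the per-unit mean and
    variance, standing in for the minibatch statistics."""
    z_mean = (1.0 - tau) * z_mean + tau * z
    z_var = (1.0 - tau) * z_var + tau * (z - z_mean) ** 2
    return z_mean, z_var

rng = np.random.default_rng(0)
z_mean, z_var = np.zeros(3), np.ones(3)
for _ in range(20000):
    z = 2.0 + 0.5 * rng.standard_normal(3)   # toy stationary activity
    z_mean, z_var = ema_stats(z, z_mean, z_var)
# z_mean approaches 2.0 and z_var approaches 0.25 per unit
```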

$$\frac{\partial \mathcal{L}_{\mathrm{var}}(t)}{\partial W_{ij}} = -\frac{\alpha\,\Theta(1 - \sigma_{z_i})}{\sigma_{z_i}}\,(z_i(t) - \bar{z}_i)\,f'(a_i(t))\,x_j(t)$$

where $\Theta$ is the Heaviside step function. Importantly, $\sigma_{z_i}$ should be a reliable estimate of the standard deviation of the unit's activity, calculated over a timescale long enough to reflect responses to several diverse inputs (this corresponds to estimating the mean and standard deviation over the set of inputs in a minibatch). The contribution of the current sample $z_i(t)$ to the gradient estimate $\partial \mathcal{L}_{\mathrm{var}}(t)/\partial W_{ij}$, however, is unchanged apart from the scaling factor $\alpha$.

With the understanding that $\bar{z}_i$ and $\sigma_{z_i}$ are long-term estimates of the mean and standard deviation of the output activities, we drop the explicit time argument, assuming all quantities correspond to the current time step unless specified otherwise.

$$\frac{\partial \mathcal{L}_{\mathrm{var}}}{\partial W_{ij}} = -\frac{\alpha\,\Theta(1 - \sigma_{z_i})}{\sigma_{z_i}}\,(z_i - \bar{z}_i)\,f'(a_i)\,x_j$$
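The gradient above translates directly into a per-sample update. A sketch (the function name is mine; `fprime` is the derivative of the nonlinearity, and `z_mean`, `z_std` are the slow online estimates):

```python
import numpy as np

def grad_var(x, a, z, z_mean, z_std, alpha, fprime):
    """Per-sample gradient of the hinge variance loss w.r.t. W.
    z_mean and z_std are slow online estimates treated as constants."""
    theta = (z_std < 1.0).astype(float)                    # Theta(1 - sigma_zi)
    g = -alpha * theta / z_std * (z - z_mean) * fprime(a)  # per-unit factor
    return np.outer(g, x)                                  # dL/dW_ij = g_i x_j

# toy check with a linear unit: descending this gradient pushes z
# away from its mean, i.e. it increases the unit's variance
x = np.array([1.0, -2.0])
a = np.array([0.5])
z, z_mean, z_std = np.array([0.5]), np.array([0.0]), np.array([0.4])
dW = grad_var(x, a, z, z_mean, z_std, alpha=1.0,
              fprime=lambda a: np.ones_like(a))
```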

Log variance loss is equivalent to Oja's rule

Consider a simpler alternative functional form for the variance regularization loss: the negative log variance of the output activity (so that minimizing the loss maximizes the variance).

$$\mathcal{L}_{\mathrm{var}}(t) = -\sum_{i=1}^{M} \log\left(\sigma_{z_i}^2\right) = -\sum_{i=1}^{M} \log\left(\alpha\left[(z_i(t) - \bar{z}_i)^2 + \Phi\right] + \epsilon\right)$$

which yields the gradient

$$\frac{\partial \mathcal{L}_{\mathrm{var}}(t)}{\partial W_{ij}} = -\frac{2\alpha}{\sigma_{z_i}^2}\,(z_i(t) - \bar{z}_i)\,f'(a_i(t))\,x_j(t)$$

We now consider the case of a single output neuron ($M = 1$) with a linear activation ($f(a) = a$, so $f'(a) = 1$), along with the assumption that the input is zero-centered ($\bar{x}_j = 0$). Consequently, $\bar{z} = \sum_j W_j \bar{x}_j = 0$ and $\sigma_z^2 = \langle (z - \bar{z})^2 \rangle = \langle z^2 \rangle$, which yields a very simple update rule for the variance term:

$$\Delta W_j = -\frac{\partial \mathcal{L}_{\mathrm{var}}}{\partial W_j} = \frac{2\alpha\, z\, x_j}{\langle z^2 \rangle}$$

This update rule combined with a weight decay (with coefficient $\eta$) yields a learning rule that, on average, is equivalent to Oja's rule up to a scaling factor, and in fact has exactly the same fixed points if $\frac{\eta}{2\alpha} = 1$.

$$\Delta W_j = \frac{2\alpha\, z\, x_j}{\langle z^2 \rangle} - \eta W_j = \frac{2\alpha}{\langle z^2 \rangle}\left(z\, x_j - \frac{\eta}{2\alpha}\, W_j \langle z^2 \rangle\right)$$

Oja's rule

$$\Delta W_j^{\mathrm{Oja}} = z\, x_j - W_j z^2$$
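The fixed-point equivalence is easy to verify numerically: at a unit-norm principal eigenvector of the input covariance, the average of both updates vanishes when $\eta = 2\alpha$. A sketch (the data and constants are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((5000, 6)) * np.array([3.0, 2.0, 1.0, 0.5, 0.5, 0.5])
C = X.T @ X / len(X)                 # input covariance (inputs are zero-mean)

w = np.linalg.eigh(C)[1][:, -1]      # unit-norm principal eigenvector of C
alpha, eta = 0.1, 0.2                # eta = 2 * alpha

z2 = w @ C @ w                       # <z^2> for a linear unit with weights w
dW_avg = 2 * alpha * (C @ w) / z2 - eta * w   # avg log-variance + decay update
dW_oja = C @ w - w * z2                       # avg Oja update
# both average updates vanish at this fixed point
```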

Invariance term

The invariance term is simply the squared L2 distance between the output activities at two consecutive time steps, which can be expressed as the sum of unit-wise squared differences across time.

$$\mathcal{L}_{\mathrm{pull}}(t) = \frac{1}{2}\left\lVert z(t) - \mathrm{SG}(z(t-1))\right\rVert^2 = \frac{1}{2}\sum_{i=1}^{M}\left(z_i(t) - \mathrm{SG}(z_i(t-1))\right)^2$$

Here SG is the stop-gradient function, reflecting the fact that we do not evaluate the gradient with respect to quantities in the past. This gives us the gradient

$$\frac{\partial \mathcal{L}_{\mathrm{pull}}(t)}{\partial W_{ij}} = (z_i(t) - z_i(t-1))\,f'(a_i(t))\,x_j(t)$$

Dropping the time argument for the current time step $t$,

$$\frac{\partial \mathcal{L}_{\mathrm{pull}}}{\partial W_{ij}} = (z_i - z_i(t-1))\,f'(a_i)\,x_j$$
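In code, with $z(t-1)$ held as a constant (the stop-gradient), the pull gradient is just as direct. A sketch following the conventions above (the function name is mine):

```python
import numpy as np

def grad_pull(x, a, z, z_prev, fprime):
    """Per-sample gradient of the invariance (pull) term w.r.t. W.
    z_prev = z(t-1) carries no gradient (stop-gradient)."""
    g = (z - z_prev) * fprime(a)
    return np.outer(g, x)

x = np.array([2.0, 0.0, 1.0])
a = np.array([1.0, -1.0])
z = np.tanh(a)                       # current activities z(t)
z_prev = np.array([0.5, -0.5])       # previous activities z(t-1)
dW = grad_pull(x, a, z, z_prev, fprime=lambda a: 1.0 - np.tanh(a) ** 2)
```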

Covariance term

The covariance objective is the sum of the squared off-diagonal terms of the covariance matrix of the output units.

$$\mathcal{L}_{\mathrm{decorr}} = \frac{\beta}{4}\sum_{i=1}^{M}\sum_{k\neq i}(z_i - \bar{z}_i)^2 (z_k - \bar{z}_k)^2$$

$$\frac{\partial \mathcal{L}_{\mathrm{decorr}}}{\partial W_{ij}} = \beta\,(z_i - \bar{z}_i)\,f'(a_i)\,x_j\sum_{k\neq i}(z_k - \bar{z}_k)^2$$

Here, $\beta = \frac{1}{M-1}$ is a scaling factor that keeps the objective invariant to the number of units in the population. The sum is over all other units' variance estimates and represents a non-local, unit-specific quantity. However, we can make a useful approximation, $\sum_{k\neq i}(z_k - \bar{z}_k)^2 \approx \sum_{k=1}^{M}(z_k - \bar{z}_k)^2$, which turns this sum into a population-level measure that is common to all units in the population and can be seen as a global (within a given sub-population) third factor.
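The quality of this approximation improves with population size, since the error for unit $i$ is exactly its own term $(z_i - \bar{z}_i)^2$ relative to the population sum. A quick numerical check (sizes and statistics are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
M = 256
sq = rng.standard_normal(M) ** 2   # (z_k - z_mean_k)^2 for each unit

exact = sq.sum() - sq              # per-unit sum over k != i (non-local)
approx = np.full(M, sq.sum())      # shared population-level third factor
rel_err = np.max(np.abs(exact - approx)) / sq.sum()
# rel_err = max_i sq_i / sum_k sq_k, small for large M
```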

Total Loss

Combining the three gradients, we can write the weight updates in a single-layer VICReg model as

$$\Delta W_{ij} = -\frac{\partial \mathcal{L}_{\mathrm{pull}}}{\partial W_{ij}} - \lambda_1\frac{\partial \mathcal{L}_{\mathrm{var}}}{\partial W_{ij}} - \lambda_2\frac{\partial \mathcal{L}_{\mathrm{decorr}}}{\partial W_{ij}}$$

where $\lambda_1$ and $\lambda_2$ are loss coefficients that here also absorb the scaling factors $\alpha$ and $\beta$.

$$\Delta W_{ij} = \left(-(z_i - z_i(t-1)) + \lambda_1\,\frac{\Theta(1 - \sigma_{z_i})}{\sigma_{z_i}}\,(z_i - \bar{z}_i) - \lambda_2\,(z_i - \bar{z}_i)\sum_{k\neq i}(z_k - \bar{z}_k)^2\right) f'(a_i)\,x_j$$
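Putting the pieces together, the combined update can be sketched as a single function (all names are mine; `z_mean` and `z_std` are the slow online estimates, with $\alpha$ and $\beta$ absorbed into `lam1` and `lam2`):

```python
import numpy as np

def delta_W(x, a, z, z_prev, z_mean, z_std, lam1, lam2, fprime):
    """Single-sample weight update for the one-layer model: pull toward
    z(t-1), push variance up where sigma < 1, and decorrelate units."""
    theta = (z_std < 1.0).astype(float)   # Theta(1 - sigma_zi)
    sq = (z - z_mean) ** 2
    third = sq.sum() - sq                 # sum over k != i
    g = (-(z - z_prev)
         + lam1 * theta / z_std * (z - z_mean)
         - lam2 * (z - z_mean) * third) * fprime(a)
    return np.outer(g, x)                 # shape (M, N)
```

With `lam1 = lam2 = 0` this reduces to the pure pull update, which gives a quick sanity check of the sign conventions.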