Commit 413b3fe (parent b8cac23)

First implementation of linear regression synaptic Ca: it gives a 4x speedup, and learning performance is consistently _improved_ relative to the prior method across ra25, deep_fsa, and objrec, which usually means it is genuinely better. Lots more work remains to explore the space, but this is an encouraging start!

21 files changed

Lines changed: 125 additions & 65 deletions
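
For context on what "linear regression synaptic Ca" means here: instead of integrating the kinase Ca cascade per cycle for every synapse, each trial is split into four spike-count bins (SpkBin0..SpkBin3), and the final Ca values are predicted directly from the per-bin sender*receiver products via regression-fitted weights (see the DWtSynCortex changes below). A minimal self-contained sketch of the idea; `finalCaSketch` is a hypothetical stand-in for the `KinaseCa.FinalCa` call in the diff, and all weights here are invented for illustration:

```go
package main

import "fmt"

// finalCaSketch predicts trial-final CaM/CaP/CaD from four binned
// sender*receiver spike products, using linear regression weights fitted
// offline against the standard per-cycle cascade. All weight values below
// are invented placeholders, not the fitted values in the kinase package.
func finalCaSketch(b0, b1, b2, b3 float32, caM, caP, caD *float32) {
	*caM = 0.2*b0 + 0.4*b1 + 0.7*b2 + 1.0*b3
	*caP = 0.1*b0 + 0.3*b1 + 0.8*b2 + 0.9*b3
	*caD = 0.6*b0 + 0.7*b1 + 0.5*b2 + 0.3*b3
}

func main() {
	// per-bin spike counts for receiver (rb) and sender (sb) over one trial
	rb := [4]float32{2, 3, 5, 4}
	sb := [4]float32{1, 4, 4, 3}

	// coincidence products, scaled by 0.1 as in DWtSynCortex below
	var b [4]float32
	for i := range b {
		b[i] = 0.1 * rb[i] * sb[i]
	}

	var caM, caP, caD float32
	finalCaSketch(b[0], b[1], b[2], b[3], &caM, &caP, &caD)
	fmt.Printf("CaM=%.3f CaP=%.3f CaD=%.3f  CaP-CaD=%.3f\n", caM, caP, caD, caP-caD)
}
```

Because the spike bins are accumulated per neuron and only combined per synapse once at trial end, the per-cycle per-synapse integration drops out, which is presumably the main source of the reported 4x speedup.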

axon/enumgen.go

Lines changed: 43 additions & 0 deletions
(Generated file; diff not rendered.)

axon/learn.go

Lines changed: 20 additions & 0 deletions

```diff
@@ -682,9 +682,26 @@ func (ls *LRateParams) Init() {
 	ls.UpdateEff()
 }
 
+// SynCaFuns are different ways of computing synaptic calcium (experimental)
+type SynCaFuns int32 //enums:enum
+
+const (
+	// StdSynCa uses standard synaptic calcium integration method
+	StdSynCa SynCaFuns = iota
+
+	// LinearSynCa uses linear regression generated calcium integration (much faster)
+	LinearSynCa
+
+	// NeurSynCa uses simple product of separately-integrated neuron values (much faster)
+	NeurSynCa
+)
+
 // TraceParams manages parameters associated with temporal trace learning
 type TraceParams struct {
 
+	// how to compute the synaptic calcium (experimental)
+	SynCa SynCaFuns
+
 	// time constant for integrating trace over theta cycle timescales -- governs the decay rate of synaptic trace
 	Tau float32 `default:"1,2,4"`
@@ -696,9 +713,12 @@ type TraceParams struct {
 
 	// rate = 1 / tau
 	Dt float32 `view:"-" json:"-" xml:"-" edit:"-"`
+
+	pad, pad1, pad2 float32
 }
 
 func (tp *TraceParams) Defaults() {
+	tp.SynCa = LinearSynCa
 	tp.Tau = 1
 	tp.SubMean = 0
 	tp.LearnThr = 0
```
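
Since `SynCa` is now a plain field on `TraceParams` (defaulting to `LinearSynCa` above), switching methods for an A/B comparison is just a field assignment. A hypothetical sim-side sketch, assuming access to the relevant `*axon.PathParams` values (how to collect them, and the exact import path, depend on the sim and the axon version):

```go
// setSynCaFun switches the synaptic Ca method on a set of pathways.
// The paths slice is an assumed sim-side collection, not an API
// confirmed by this commit.
func setSynCaFun(paths []*axon.PathParams, fn axon.SynCaFuns) {
	for _, pj := range paths {
		pj.Learn.Trace.SynCa = fn // e.g. axon.StdSynCa restores the old behavior
	}
}
```

With `LinearSynCa` or `NeurSynCa` selected, `DoSynCa` (next file) returns false, so per-cycle synaptic Ca updating is skipped entirely.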

axon/pathparams.go

Lines changed: 31 additions & 24 deletions

DoSynCa now also returns false when a non-standard SynCa function is selected (or Hebbian learning is on), since LinearSynCa and NeurSynCa compute calcium at DWt time rather than via per-cycle synaptic updating:

```diff
@@ -277,8 +277,7 @@ func (pj *PathParams) GatherSpikes(ctx *Context, ly *LayerParams, ni, di uint32,
 // DoSynCa returns false if should not do synaptic-level calcium updating.
 // Done by default in Cortex, not for some other special pathway types.
 func (pj *PathParams) DoSynCa() bool {
-	if pj.PathType == RWPath || pj.PathType == TDPredPath || pj.PathType == VSMatrixPath ||
-		pj.PathType == DSMatrixPath || pj.PathType == VSPatchPath || pj.PathType == BLAPath {
+	if pj.Learn.Trace.SynCa != StdSynCa || pj.PathType == RWPath || pj.PathType == TDPredPath || pj.PathType == VSMatrixPath || pj.PathType == DSMatrixPath || pj.PathType == VSPatchPath || pj.PathType == BLAPath || pj.Learn.Hebb.On.IsTrue() {
 		return false
 	}
 	return true
@@ -338,28 +337,36 @@ func (pj *PathParams) DWtSyn(ctx *Context, syni, si, ri, di uint32, layPool, sub
 // Uses synaptically integrated spiking, computed at the Theta cycle interval.
 // This is the trace version for hidden units, and uses syn CaP - CaD for targets.
 func (pj *PathParams) DWtSynCortex(ctx *Context, syni, si, ri, di uint32, layPool, subPool *Pool, isTarget bool) {
-	// credit assignment part
-	caUpT := SynCaV(ctx, syni, di, CaUpT) // time of last update
-	syCaM := SynCaV(ctx, syni, di, CaM)   // fast time scale
-	syCaP := SynCaV(ctx, syni, di, CaP)   // slower but still fast time scale, drives Potentiation
-	syCaD := SynCaV(ctx, syni, di, CaD)   // slow time scale, drives Depression (one trial = 200 cycles)
-	pj.Learn.KinaseCa.CurCa(ctx.SynCaCtr, caUpT, &syCaM, &syCaP, &syCaD) // always update, getting current Ca (just optimization)
-
-	rb0 := NrnV(ctx, ri, di, SpkBin0)
-	sb0 := NrnV(ctx, si, di, SpkBin0)
-	rb1 := NrnV(ctx, ri, di, SpkBin1)
-	sb1 := NrnV(ctx, si, di, SpkBin1)
-	rb2 := NrnV(ctx, ri, di, SpkBin2)
-	sb2 := NrnV(ctx, si, di, SpkBin2)
-	rb3 := NrnV(ctx, ri, di, SpkBin3)
-	sb3 := NrnV(ctx, si, di, SpkBin3)
-
-	b0 := 0.1 * (rb0 * sb0)
-	b1 := 0.1 * (rb1 * sb1)
-	b2 := 0.1 * (rb2 * sb2)
-	b3 := 0.1 * (rb3 * sb3)
-
-	pj.Learn.KinaseCa.FinalCa(b0, b1, b2, b3, &syCaM, &syCaP, &syCaD)
+	var syCaM, syCaP, syCaD, caUpT float32
+	switch pj.Learn.Trace.SynCa {
+	case StdSynCa:
+		caUpT = SynCaV(ctx, syni, di, CaUpT) // time of last update
+		syCaM = SynCaV(ctx, syni, di, CaM)   // fast time scale
+		syCaP = SynCaV(ctx, syni, di, CaP)   // slower but still fast time scale, drives Potentiation
+		syCaD = SynCaV(ctx, syni, di, CaD)   // slow time scale, drives Depression (one trial = 200 cycles)
+		pj.Learn.KinaseCa.CurCa(ctx.SynCaCtr, caUpT, &syCaM, &syCaP, &syCaD) // always update, getting current Ca (just optimization)
+	case LinearSynCa:
+		rb0 := NrnV(ctx, ri, di, SpkBin0)
+		sb0 := NrnV(ctx, si, di, SpkBin0)
+		rb1 := NrnV(ctx, ri, di, SpkBin1)
+		sb1 := NrnV(ctx, si, di, SpkBin1)
+		rb2 := NrnV(ctx, ri, di, SpkBin2)
+		sb2 := NrnV(ctx, si, di, SpkBin2)
+		rb3 := NrnV(ctx, ri, di, SpkBin3)
+		sb3 := NrnV(ctx, si, di, SpkBin3)
+
+		b0 := 0.1 * (rb0 * sb0)
+		b1 := 0.1 * (rb1 * sb1)
+		b2 := 0.1 * (rb2 * sb2)
+		b3 := 0.1 * (rb3 * sb3)
+
+		pj.Learn.KinaseCa.FinalCa(b0, b1, b2, b3, &syCaM, &syCaP, &syCaD)
+	case NeurSynCa:
+		gain := float32(1.0)
+		syCaM = gain * NrnV(ctx, si, di, CaSpkM) * NrnV(ctx, ri, di, CaSpkM)
+		syCaP = gain * NrnV(ctx, si, di, CaSpkP) * NrnV(ctx, ri, di, CaSpkP)
+		syCaD = gain * NrnV(ctx, si, di, CaSpkD) * NrnV(ctx, ri, di, CaSpkD)
+	}
 
 	SetSynCaV(ctx, syni, di, CaM, syCaM)
 	SetSynCaV(ctx, syni, di, CaP, syCaP)
```
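
As the comments above note, CaP drives potentiation and CaD drives depression, and targets use syn CaP - CaD directly. A toy sketch of that sign logic (illustrative only, not the actual DWt code in this commit):

```go
// dwtToy: the error-driven term is the difference between the faster CaP
// (potentiation) and slower CaD (depression) integrations. Illustrative only.
func dwtToy(caP, caD, lrate float32) float32 {
	return lrate * (caP - caD) // > 0 potentiates, < 0 depresses
}
```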

axon/rand.go

Lines changed: 1 addition & 3 deletions

```diff
@@ -2,7 +2,6 @@ package axon
 
 import (
 	"cogentcore.org/core/vgpu/gosl/slrand"
-	"cogentcore.org/core/vgpu/gosl/sltype"
 )
 
 //gosl:hlsl axonrand
@@ -32,8 +31,7 @@ func GetRandomNumber(index uint32, counter slrand.Counter, funIndex RandFunIndex
 	var randCtr slrand.Counter
 	randCtr = counter
 	randCtr.Add(uint32(funIndex))
-	var ctr sltype.Uint2
-	ctr = randCtr.Uint2()
+	ctr := randCtr.Uint2()
 	return slrand.Float(&ctr, index)
 }
```

axon/shaders/Makefile

Lines changed: 1 addition & 1 deletion

```diff
@@ -2,7 +2,7 @@
 # The go generate command does this automatically.
 
 all:
-	cd ../; gosl -exclude=Update,UpdateParams,Defaults,AllParams,ShouldShow cogentcore.org/core/math32/v2/fastexp.go cogentcore.org/core/etable/v2/minmax ../chans/chans.go ../chans ../kinase ../fsfffb/inhib.go ../fsfffb github.com/emer/emergent/v2/etime github.com/emer/emergent/v2/ringidx rand.go avgmax.go neuromod.go globals.go context.go neuron.go synapse.go pool.go layervals.go act.go act_prjn.go inhib.go learn.go layertypes.go layerparams.go deep_layers.go rl_layers.go pvlv_layers.go pcore_layers.go prjntypes.go prjnparams.go deep_prjns.go rl_prjns.go pvlv_prjns.go pcore_prjns.go hip_prjns.go gpu_hlsl
+	cd ../; gosl -exclude=Update,UpdateParams,Defaults,AllParams,ShouldShow cogentcore.org/core/math32/fastexp.go cogentcore.org/core/math32/minmax ../chans/chans.go ../chans ../kinase ../fsfffb/inhib.go ../fsfffb github.com/emer/emergent/v2/etime github.com/emer/emergent/v2/ringidx rand.go avgmax.go neuromod.go globals.go context.go neuron.go synapse.go pool.go layervals.go act.go act_prjn.go inhib.go learn.go layertypes.go layerparams.go deep_layers.go rl_layers.go pvlv_layers.go pcore_layers.go prjntypes.go prjnparams.go deep_prjns.go rl_prjns.go pvlv_prjns.go pcore_prjns.go hip_prjns.go gpu_hlsl
 
 # note: gosl automatically compiles the hlsl files using this command:
 %.spv : %.hlsl
```

Binary shader files changed (not shown):

axon/shaders/gpu_dwt.spv (7.86 KB)
axon/shaders/gpu_dwtfmdi.spv (172 Bytes)
axon/shaders/gpu_dwtsubmean.spv (188 Bytes)
axon/shaders/gpu_gather.spv (172 Bytes)
axon/shaders/gpu_newstate_pool.spv (172 Bytes)
