Source code for conex.behaviors.synapses.specs

"""
Synapse-related behaviors.
"""

from pymonntorch import Behavior
import random
import torch


class SynapseInit(Behavior):
    """
    Create the shape and delay attributes that other synapse behaviors read.

    For each connected side (``src`` and ``dst``) that is not ``None`` the
    synapse receives a ``<side>_shape`` tuple and a ``<side>_delay`` tensor.

    **WARNING:** ``src_delay`` and ``dst_delay`` start as one shared zero value
    expanded over all neurons of the side. Per-neuron delays should be
    initialized by other behaviors.
    """

    def initialize(self, synapse):
        """
        Attach shape and default delay attributes to ``synapse``.

        Args:
            synapse: The synapse group being initialized. Its ``src`` / ``dst``
                may be ``None``, in which case that side is skipped.
        """
        for side in ("src", "dst"):
            population = getattr(synapse, side)
            if population is None:
                continue

            # Populations exposing ``depth`` are described by their
            # (depth, height, width); any other population is treated as a
            # flat row of ``size`` neurons.
            if hasattr(population, "depth"):
                shape = (population.depth, population.height, population.width)
            else:
                shape = (1, 1, population.size)
            setattr(synapse, f"{side}_shape", shape)

            # A single long zero, expanded (not copied) to the population size.
            shared_delay = synapse.tensor(mode="zeros", dim=(1,), dtype=torch.long)
            setattr(synapse, f"{side}_delay", shared_delay.expand(population.size))
class DelayInitializer(Behavior):
    """
    Initialize the delay of the axon entering the synapse or of the dendrite
    conveying current.

    Either ``mode`` or ``delays`` must be provided.

    Args:
        mode (str or number): Fill mode forwarded to the neuron group's
            ``vector`` method. A string should name a torch function that fills
            a tensor, like: "random", "normal", "zeros", "ones", ... . In the
            number case the delays are filled with that number.
        scale (int): Factor multiplied into the delays generated from ``mode``.
        offset (int): Value added to the delays generated from ``mode``.
        destination (bool): True to set the delays of the destination neurons
            (``dst_delay``). Defaults to False (``src_delay``).
        delays (Tensor(int), optional): A tensor of delays for each neuron of
            the source or destination, given directly. When provided, ``mode``,
            ``scale`` and ``offset`` are not applied.
    """

    def __init__(
        self,
        *args,
        mode=None,
        scale=1,
        offset=0,
        destination=False,
        delays=None,
        **kwargs,
    ):
        super().__init__(
            *args,
            mode=mode,
            scale=scale,
            offset=offset,
            destination=destination,
            delays=delays,
            **kwargs,
        )

    def initialize(self, synapse):
        """
        Compute the delays and store them on the synapse as ``src_delay`` or
        ``dst_delay`` (cast to ``torch.long``).

        Args:
            synapse (SynapseGroup): The synapse whose delay is initialized.

        Raises:
            ValueError: If neither ``mode`` nor ``delays`` was given.
        """
        init_mode = self.parameter("mode", None)
        delays = self.parameter("delays", None)
        scale = self.parameter("scale", 1)
        offset = self.parameter("offset", 0)
        is_destination = self.parameter("destination", False)

        if is_destination:
            neurons, attribute = synapse.dst, "dst"
        else:
            neurons, attribute = synapse.src, "src"

        if delays is None:
            if init_mode is None:
                # Without this check the method would fail further down with
                # an opaque AttributeError on ``None``.
                raise ValueError(
                    "DelayInitializer requires either `mode` or `delays` to be set."
                )
            delays = neurons.vector(mode=init_mode)
            delays *= scale
            delays += offset

        setattr(synapse, f"{attribute}_delay", delays.to(torch.long))
class WeightInitializer(Behavior):
    """
    Initialize the weights of a synapse.

    Note: Either ``mode`` or ``weights`` should be not None, or else the weight
    matrix will be None. ``function``, ``scale`` and ``offset`` are applied
    only to weights generated from ``mode``.

    Args:
        mode (str or number): Fill mode. A string should name a torch function
            that fills a tensor, like: "random", "normal", "zeros", "ones", ... .
            In the number case the weights are filled with that number.
        scale (float): Scaling factor to apply on the weight.
        offset (float): An offset to add to the weight.
        function (callable): A function to apply on the weight.
        density (float): The density of the weights. Default is one.
        true_sparsity (bool): If False, ``density`` is forwarded to
            ``synapse.tensor`` instead of building a sparse tensor here.
            Defaults to True.
        weights (tensor): Optional parameter to specify the weights matrix directly.
        weight_shape (tuple): Optional parameter to specify the shape of the
            weights tensor. Falls back to ``synapse.matrix_dim()``.
        kernel_shape (tuple): Optional parameter to specify the shape of the kernel.
    """

    def __init__(
        self,
        *args,
        mode=None,
        scale=1,
        offset=0,
        function=None,
        density=1,
        true_sparsity=True,
        weight_shape=None,
        kernel_shape=None,
        weights=None,
        **kwargs,
    ):
        settings = dict(
            mode=mode,
            scale=scale,
            offset=offset,
            function=function,
            density=density,
            true_sparsity=true_sparsity,
            weight_shape=weight_shape,
            kernel_shape=kernel_shape,
            weights=weights,
        )
        super().__init__(*args, **settings, **kwargs)

    def initialize(self, synapse):
        """
        Set ``synapse.weights`` and ``synapse.kernel_shape``.

        When a sparse matrix is built, ``synapse.src_idx`` and
        ``synapse.dst_idx`` are set as well, from the row indices and the
        per-column entry counts of the CSC tensor.

        Args:
            synapse: The synapse group whose weights are initialized.
        """
        init_mode = self.parameter("mode", None)
        scale = self.parameter("scale", 1)
        offset = self.parameter("offset", 0)
        function = self.parameter("function", None)
        density = self.parameter("density", 1)
        true_sparsity = self.parameter("true_sparsity", True)
        synapse.weights = self.parameter("weights", None)
        weight_shape = self.parameter("weight_shape", None)
        synapse.kernel_shape = self.parameter("kernel_shape", None)

        if weight_shape is None:
            weight_shape = synapse.matrix_dim()

        # Nothing to generate: either no mode was given or explicit weights
        # were supplied and are used untouched.
        if init_mode is None or synapse.weights is not None:
            return

        if density == 1 or not true_sparsity:
            synapse.weights = synapse.tensor(
                mode=init_mode, dim=weight_shape, density=density
            )
        else:
            n_row, n_col = synapse.matrix_dim()
            n_cells = n_row * n_col
            nnz = int(n_cells * density)
            # Draw ``nnz`` distinct flat positions of the (n_row, n_col) matrix.
            flat_positions = torch.tensor(
                random.sample(range(n_cells), nnz), device=synapse.device
            )  # TODO pytorch alternative
            coordinates = torch.stack(
                [flat_positions // n_col, flat_positions % n_col]
            )
            values = synapse.tensor(mode=init_mode, dim=(nnz,))
            sparse_weights = (
                torch.sparse_coo_tensor(coordinates, values, synapse.matrix_dim())
                .coalesce()
                .to_sparse_csc()
            )
            synapse.weights = sparse_weights
            # One column index per stored entry, derived from the difference of
            # consecutive compressed column pointers.
            synapse.dst_idx = torch.arange(
                n_col, device=synapse.device
            ).repeat_interleave(sparse_weights.ccol_indices().diff())
            synapse.src_idx = sparse_weights.row_indices()

        if function is not None:
            synapse.weights = function(synapse.weights)

        synapse.weights = synapse.weights * scale

        if synapse.weights.layout == torch.strided:
            synapse.weights += offset
        else:  # Pytorch should fix is_sparse
            # Non-strided layout: add the offset to the stored values only.
            synapse.weights.values()[:] += offset
class WeightNormalization(Behavior):
    """
    Rescale the weights so that the sum over the reduced dimensions equals
    ``norm``, intended to give each destination neuron a fixed total weight.

    Supporting `Simple`, `Local2d`, 'Conv2d'.

    Args:
        norm (int): Desired sum of weights for each neuron.
    """

    def __init__(self, *args, norm=1, **kwargs):
        super().__init__(*args, norm=norm, **kwargs)

    def initialize(self, synapse):
        """
        Read ``norm`` and choose the dimensions to sum over from the rank of
        ``synapse.weights``: dim 0 for 2-D weights, dim 2 for 3-D weights,
        every dimension but the first otherwise.
        """
        self.norm = self.parameter("norm", 1)
        rank = len(synapse.weights.shape)
        if rank == 3:
            self.dims = [2]
        elif rank == 2:
            self.dims = [0]
        else:
            self.dims = list(range(1, rank))

    def forward(self, synapse):
        """
        Normalize ``synapse.weights`` in place on every step.
        """
        totals = synapse.weights.sum(dim=self.dims, keepdim=True)
        # Replace zero sums by one so the division below leaves those weights
        # unchanged instead of dividing by zero.
        totals[totals == 0] = 1
        synapse.weights *= self.norm / totals
class CurrentNormalization(Behavior):
    """
    Scale the synaptic current ``I`` by ``norm`` divided by the summed weights,
    intended to bound the maximum input current of each destination neuron
    to ``norm``.

    Supporting `Simple`, `Local2d`, 'Conv2d'.

    Args:
        norm (int): Desired maximum of current for each neuron.
    """

    def __init__(self, *args, norm=1, **kwargs):
        super().__init__(*args, norm=norm, **kwargs)

    def initialize(self, synapse):
        """
        Read ``norm`` and choose the dimensions to sum over from the rank of
        ``synapse.weights``: dim 0 for 2-D weights, dim 2 for 3-D weights,
        every dimension but the first otherwise.
        """
        self.norm = self.parameter("norm", 1)
        rank = len(synapse.weights.shape)
        if rank == 3:
            self.dims = [2]
        elif rank == 2:
            self.dims = [0]
        else:
            self.dims = list(range(1, rank))

    def forward(self, synapse):
        """
        Multiply ``synapse.I`` in place by the per-sum normalization factors.
        """
        totals = synapse.weights.sum(dim=self.dims).view(-1)
        # Zero sums become one so their factor is simply ``norm``.
        totals[totals == 0] = 1
        factors = self.norm / totals
        # Each factor is repeated so the factor vector matches the number of
        # elements of ``I``.
        repeats = synapse.I.numel() // factors.numel()
        synapse.I *= factors.repeat_interleave(repeats)
class WeightClip(Behavior):
    """
    Clip the synaptic weights in a range.

    Args:
        w_min (float): Minimum weight constraint. Must satisfy ``0 <= w_min``.
        w_max (float): Maximum weight constraint. Must satisfy ``w_min < w_max``.
    """

    def __init__(self, *args, w_min=0, w_max=1, **kwargs):
        super().__init__(*args, w_min=w_min, w_max=w_max, **kwargs)

    def initialize(self, synapse):
        """
        Read and validate the weight bounds, storing them on this behavior.

        Args:
            synapse (SynapseGroup): The synapses whose weight should be bound.

        Raises:
            ValueError: If the bounds do not satisfy ``0 <= w_min < w_max``.
        """
        self.w_min = self.parameter("w_min", 0)
        self.w_max = self.parameter("w_max", 1)

        # Explicit check instead of ``assert`` so the validation is not
        # stripped when Python runs with optimizations (-O).
        if not (0 <= self.w_min < self.w_max):
            raise ValueError(
                "Invalid Interval for Weight Clip: expected 0 <= w_min < w_max, "
                f"got w_min={self.w_min}, w_max={self.w_max}"
            )

    def forward(self, synapses):
        """
        Clip the synaptic weights in each time step.

        Args:
            synapses (SynapseGroup): The synapses whose weight should be bound.
        """
        synapses.weights = torch.clip(synapses.weights, self.w_min, self.w_max)
class PreSpikeCatcher(Behavior):
    """
    Fetch the spikes of the pre-synaptic neuron group and store them on the
    synapse group as the ``pre_spike`` attribute.

    Note: An axon should be added to the pre-synaptic neuron group.
    """

    # NOTE(review): presumably asks the framework to initialize this behavior
    # after the others, so ``src_delay`` is already set; confirm in pymonntorch.
    initialize_last = True

    def initialize(self, synapse):
        """
        Set ``synapse.pre_spike`` once at initialization.
        """
        presynaptic = synapse.src
        synapse.pre_spike = presynaptic.axon.get_spike(
            presynaptic, synapse.src_delay
        )

    def forward(self, synapse):
        """
        Refresh ``synapse.pre_spike`` on every step.
        """
        presynaptic = synapse.src
        synapse.pre_spike = presynaptic.axon.get_spike(
            presynaptic, synapse.src_delay
        )
class PostSpikeCatcher(Behavior):
    """
    Fetch the spikes of the post-synaptic neuron group and store them on the
    synapse group as the ``post_spike`` attribute.

    Note: An axon should be added to the post-synaptic neuron group.
    """

    # NOTE(review): presumably asks the framework to initialize this behavior
    # after the others, so ``dst_delay`` is already set; confirm in pymonntorch.
    initialize_last = True

    def initialize(self, synapse):
        """
        Set ``synapse.post_spike`` once at initialization.
        """
        postsynaptic = synapse.dst
        synapse.post_spike = postsynaptic.axon.get_spike(
            postsynaptic, synapse.dst_delay
        )

    def forward(self, synapse):
        """
        Refresh ``synapse.post_spike`` on every step.
        """
        postsynaptic = synapse.dst
        synapse.post_spike = postsynaptic.axon.get_spike(
            postsynaptic, synapse.dst_delay
        )
class PreTrace(Behavior):
    """
    Maintain the pre-synaptic spike trace on the synapse as ``pre_trace``.

    Note: should be placed after the spike catcher behavior, since it reads
    ``synapse.pre_spike``.

    Args:
        tau_s (float): Decay term for the spike trace. Required.
        spike_scale (float): The increase effect of spikes on the trace.
    """

    def __init__(self, tau_s, *args, spike_scale=1.0, **kwargs):
        super().__init__(*args, tau_s=tau_s, spike_scale=spike_scale, **kwargs)

    def initialize(self, synapse):
        """
        Read the parameters and create a zero trace with one entry per
        source neuron.
        """
        self.tau_s = self.parameter("tau_s", None, required=True)
        self.spike_scale = self.parameter("spike_scale", 1.0)
        synapse.pre_trace = synapse.src.vector(0.0)

    def forward(self, synapse):
        """
        Add the scaled current spikes to the trace, then decay the trace by
        ``trace / tau_s * dt``.
        """
        trace = synapse.pre_trace
        trace += synapse.pre_spike * self.spike_scale
        trace -= (trace / self.tau_s) * synapse.network.dt
        synapse.pre_trace = trace
class PostTrace(Behavior):
    """
    Maintain the post-synaptic spike trace on the synapse as ``post_trace``.

    Note: should be placed after the spike catcher behavior, since it reads
    ``synapse.post_spike``.

    Args:
        tau_s (float): Decay term for the spike trace. Required.
        spike_scale (float): The increase effect of spikes on the trace.
    """

    def __init__(self, tau_s, *args, spike_scale=1.0, **kwargs):
        super().__init__(*args, tau_s=tau_s, spike_scale=spike_scale, **kwargs)

    def initialize(self, synapse):
        """
        Read the parameters and create a zero trace with one entry per
        destination neuron.
        """
        self.tau_s = self.parameter("tau_s", None, required=True)
        self.spike_scale = self.parameter("spike_scale", 1.0)
        synapse.post_trace = synapse.dst.vector(0.0)

    def forward(self, synapse):
        """
        Add the scaled current spikes to the trace, then decay the trace by
        ``trace / tau_s * dt``.
        """
        trace = synapse.post_trace
        trace += synapse.post_spike * self.spike_scale
        trace -= (trace / self.tau_s) * synapse.network.dt
        synapse.post_trace = trace