# Source code for conex.behaviors.neurons.specs

"""
General specifications needed for spiking neurons.
"""

from pymonntorch import Behavior
import torch


class InherentNoise(Behavior):
    """
    Adds a noise term to the voltage of the neurons in the population.

    On every forward step a tensor is requested from ``neurons.vector`` using
    the configured ``mode`` and ``scale``; ``offset`` is added to that tensor
    and the sum is accumulated into ``neurons.v``.

    Args:
        mode (str): Initialization mode forwarded to ``neurons.vector``.
            Defaults to "rand".
        scale (float): Scale forwarded to ``neurons.vector``. Defaults to 1.
        offset (float): Value added to the generated tensor before it is
            applied to the voltage. Defaults to 0.
    """

    def __init__(self, *args, mode="rand", scale=1, offset=0, **kwargs):
        # Route the explicit keyword arguments through the Behavior parameter
        # mechanism so that `initialize` can read them back.
        settings = {"mode": mode, "scale": scale, "offset": offset}
        super().__init__(*args, **settings, **kwargs)

    def initialize(self, neurons):
        """Read ``mode``, ``scale`` and ``offset`` and store them on the behavior."""
        defaults = (("mode", "rand"), ("scale", 1), ("offset", 0))
        for key, fallback in defaults:
            setattr(self, key, self.parameter(key, fallback))

    def forward(self, neurons):
        """Add freshly generated noise plus the offset to ``neurons.v``."""
        noise = neurons.vector(mode=self.mode, scale=self.scale)
        neurons.v += noise + self.offset
class Fire(Behavior):
    """
    Asks the neurons to fire.

    The actual firing logic is not implemented here: each forward step hands
    the population to the ``Fire`` method of ``neurons.spiking_neuron``.
    """

    def forward(self, neurons):
        """Call ``Fire`` on the population's ``spiking_neuron`` object."""
        spiking_model = neurons.spiking_neuron
        spiking_model.Fire(neurons)
class KWTA(Behavior):
    """
    K-winners-take-all behavior for spiking neurons.

    Neurons whose voltage is at or above ``neurons.threshold`` are candidates
    to spike. When some group along the selected dimension has more than ``k``
    candidates, the ``k`` highest voltages along that dimension are kept and
    every other candidate has its voltage set to ``neurons.v_reset``.

    Note:
        The population should be built by NeuronDimension, and a firing
        behavior should be added too.

    Args:
        k (int): Number of winners.
        dimension (int, optional): Dimension along which K-WTA is applied.
            When None, the voltage is treated as a flat vector and the
            competition runs over dimension 0. Defaults to None.
    """

    def __init__(self, k, *args, dimension=None, **kwargs):
        super().__init__(*args, k=k, dimension=dimension, **kwargs)

    def initialize(self, neurons):
        """Read ``k`` and ``dimension`` and record the shape used for reshaping."""
        self.k = self.parameter("k", None, required=True)
        self.dimension = self.parameter("dimension", None)

        # Populations exposing `depth` are viewed as (depth, height, width);
        # any other population is treated as a column of `size` neurons.
        if hasattr(neurons, "depth"):
            self.shape = (neurons.depth, neurons.height, neurons.width)
        else:
            self.shape = (neurons.size, 1, 1)

    def forward(self, neurons):
        """Reset the voltage of above-threshold neurons that are not among the winners."""
        above_threshold = neurons.v >= neurons.threshold
        potentials = neurons.v

        if self.dimension is None:
            axis = 0
        else:
            # Reshape both tensors so the competition runs along the
            # requested dimension.
            axis = self.dimension
            potentials = potentials.view(self.shape)
            above_threshold = above_threshold.view(self.shape)

        # Nothing to inhibit when no group exceeds k candidates.
        candidates_per_group = above_threshold.sum(dim=axis)
        if (candidates_per_group <= self.k).all():
            return

        winners = torch.topk(potentials, self.k, dim=axis, sorted=False).indices

        # Clear the winners from the candidate mask; what remains marks the
        # candidates that lost the competition.
        losers = above_threshold
        losers.scatter_(axis, winners, False)
        neurons.v[losers.view((-1,))] = neurons.v_reset