Source code for conex.behaviors.neurons.axon
"""
Axon mechanisms for neurons.
"""
from pymonntorch import Behavior
import torch
[docs]class NeuronAxon(Behavior):
"""
Propagate the spikes and apply the delay mechanism.
Note: should be added after fire.
Args:
max_delay (int): Maximum delay of all dendrites connected to the neurons. This value determines the delay buffer size.
proximal_min_delay (int): Minimum delay of proximal dendrites. The default is 0.
distal_min_delay (int): Minimum delay of distal dendrites. The default is 0.
apical_min_delay (int): Minimum delay of apical dendrites. The default is 0.
"""
def __init__(
self,
*args,
max_delay=1,
proximal_min_delay=0,
distal_min_delay=0,
apical_min_delay=0,
**kwargs,
):
super().__init__(
*args,
max_delay=max_delay,
proximal_min_delay=proximal_min_delay,
distal_min_delay=distal_min_delay,
apical_min_delay=apical_min_delay,
**kwargs,
)
[docs] def initialize(self, neurons):
self.max_delay = self.parameter("max_delay", 1)
self.proximal_min_delay = self.parameter("proximal_min_delay", 0)
self.distal_min_delay = self.parameter("distal_min_delay", 0)
self.apical_min_delay = self.parameter("apical_min_delay", 0)
self.spike_history = neurons.vector_buffer(self.max_delay, dtype=torch.bool)
neurons.axon = self
[docs] def update_min_delay(self, neurons):
if proximal_synapses := neurons.efferent_synapses.get("Proximal", []):
self.proximal_min_delay = torch.cat(
[synapse.src_delay for synapse in proximal_synapses]
).min()
if distal_synapses := neurons.efferent_synapses.get("Distal", []):
self.distal_min_delay = torch.cat(
[synapse.src_delay for synapse in distal_synapses]
).min()
if apical_synapses := neurons.efferent_synapses.get("Apical", []):
self.apical_min_delay = torch.cat(
[synapse.src_delay for synapse in apical_synapses]
).min()
[docs] def get_spike(self, neurons, delay):
return self.spike_history.gather(0, delay.unsqueeze(0)).squeeze(0)
[docs] def forward(self, neurons):
self.spike_history = neurons.buffer_roll(
mat=self.spike_history, new=neurons.spikes
)