"""
Dendrite structure and computation variants.
"""
from pymonntorch import Behavior
import torch
class SimpleDendriteStructure(Behavior):
    """
    Define the structure of the dendrite and gather currents for the computation behavior.

    For each dendrite type (proximal, distal, apical) a delay buffer is kept on
    the neuron group. On every step the buffers are rolled, the currents of the
    afferent synapses stored under the keys ``"Proximal"``, ``"Distal"`` and
    ``"Apical"`` are added at the buffer row given by each synapse's
    ``dst_delay``, and row 0 of each buffer is exposed as ``neurons.I_proximal``,
    ``neurons.I_distal`` and ``neurons.I_apical``.

    Args:
        Proximal_max_delay (int): Maximum delay of proximal dendrites. The default is 1. Set this to 0 to discard Proximal dendrite.
        Distal_max_delay (int): Maximum delay of distal dendrites. The default is 1. Set this to 0 to discard Distal dendrite.
        Apical_max_delay (int): Maximum delay of apical dendrites. The default is `Distal_max_delay + 1`. Set this to 0 to discard Apical dendrite.
        proximal_min_delay (int): Minimum delay of proximal dendrites. The default is 0.
        distal_min_delay (int): Minimum delay of distal dendrites. The default is 0.
        apical_min_delay (int): Minimum delay of apical dendrites. The default is `distal_min_delay + 1`.
    """

    def __init__(
        self,
        *args,
        Proximal_max_delay=1,
        Distal_max_delay=1,
        Apical_max_delay=None,
        proximal_min_delay=0,
        distal_min_delay=0,
        apical_min_delay=None,
        **kwargs,
    ):
        super().__init__(
            *args,
            Proximal_max_delay=Proximal_max_delay,
            Distal_max_delay=Distal_max_delay,
            Apical_max_delay=Apical_max_delay,
            proximal_min_delay=proximal_min_delay,
            distal_min_delay=distal_min_delay,
            apical_min_delay=apical_min_delay,
            **kwargs,
        )

    def _parameter_or(self, key, fallback):
        """Read parameter ``key``, substituting ``fallback`` when it resolves to None."""
        # __init__ forwards None explicitly for the apical keys; whether
        # `parameter` substitutes its default in that case is not visible
        # here, so the fallback is applied explicitly as well.
        value = self.parameter(key, fallback)
        return fallback if value is None else value

    def initialize(self, neurons):
        """
        Read and validate the delay parameters and create the input buffers.

        Args:
            neurons: The neuron group this behavior is attached to. Receives
                the attributes ``apical_input``, ``distal_input`` and
                ``proximal_input``.

        Raises:
            ValueError: If, for a dendrite that is not discarded
                (max delay > 0), the minimum delay is not smaller than the
                maximum delay.
        """
        self.proximal_max_delay = self.parameter("Proximal_max_delay", 1)
        self.distal_max_delay = self.parameter("Distal_max_delay", 1)
        self.apical_max_delay = self._parameter_or(
            "Apical_max_delay", self.distal_max_delay + 1
        )

        self.proximal_min_delay = self.parameter("proximal_min_delay", 0)
        # The `max_delay > 0` guard mirrors the distal/apical checks: a max
        # delay of 0 discards the dendrite, so its min delay is irrelevant.
        if (
            self.proximal_min_delay >= self.proximal_max_delay
            and self.proximal_max_delay > 0
        ):
            raise ValueError(
                "proximal_min_delay should be smaller than proximal_max_delay"
            )

        self.distal_min_delay = self.parameter("distal_min_delay", 0)
        if self.distal_min_delay >= self.distal_max_delay and self.distal_max_delay > 0:
            raise ValueError("distal_min_delay should be smaller than distal_max_delay")

        self.apical_min_delay = self._parameter_or(
            "apical_min_delay", self.distal_min_delay + 1
        )
        if self.apical_min_delay >= self.apical_max_delay and self.apical_max_delay > 0:
            raise ValueError("apical_min_delay should be smaller than apical_max_delay")

        # A discarded dendrite keeps the `[0]` placeholder so that `[...][0]`
        # in `forward` still yields a zero current.
        neurons.apical_input = [0]
        if self.apical_max_delay:
            neurons.apical_input = neurons.vector_buffer(self.apical_max_delay)

        neurons.distal_input = [0]
        if self.distal_max_delay:
            neurons.distal_input = neurons.vector_buffer(self.distal_max_delay)

        neurons.proximal_input = [0]
        if self.proximal_max_delay:
            neurons.proximal_input = neurons.vector_buffer(self.proximal_max_delay)

    def update_min_delay(self, neurons):
        """
        Recompute the minimum delays from the afferent synapses.

        For each dendrite type that has at least one afferent synapse, the
        minimum delay becomes the smallest entry over the concatenated
        ``dst_delay`` of those synapses (stored as the result of
        ``Tensor.min()``). Dendrite types without synapses keep their value.

        Args:
            neurons: The neuron group whose ``afferent_synapses`` are inspected.
        """
        if proximal_synapses := neurons.afferent_synapses.get("Proximal", []):
            self.proximal_min_delay = torch.cat(
                [synapse.dst_delay for synapse in proximal_synapses]
            ).min()

        if distal_synapses := neurons.afferent_synapses.get("Distal", []):
            self.distal_min_delay = torch.cat(
                [synapse.dst_delay for synapse in distal_synapses]
            ).min()

        if apical_synapses := neurons.afferent_synapses.get("Apical", []):
            self.apical_min_delay = torch.cat(
                [synapse.dst_delay for synapse in apical_synapses]
            ).min()

    def _add_proximal(self, neurons, synapse):
        """Add ``synapse.I`` into the proximal buffer at the rows given by ``synapse.dst_delay``."""
        # assumes dst_delay and I are 1-D with one entry per neuron - TODO confirm
        neurons.proximal_input.scatter_add_(
            0, synapse.dst_delay.unsqueeze(0), synapse.I.unsqueeze(0)
        )

    def _add_apical(self, neurons, synapse):
        """Add ``synapse.I`` into the apical buffer at the rows given by ``synapse.dst_delay``."""
        neurons.apical_input.scatter_add_(
            0, synapse.dst_delay.unsqueeze(0), synapse.I.unsqueeze(0)
        )

    def _add_distal(self, neurons, synapse):
        """Add ``synapse.I`` into the distal buffer at the rows given by ``synapse.dst_delay``."""
        neurons.distal_input.scatter_add_(
            0, synapse.dst_delay.unsqueeze(0), synapse.I.unsqueeze(0)
        )

    def forward(self, neurons):
        """
        Roll the delay buffers, accumulate synaptic currents and expose row 0.

        Sets ``neurons.I_proximal``, ``neurons.I_apical`` and
        ``neurons.I_distal`` to the first row of the corresponding buffer
        (or to the placeholder 0 for a discarded dendrite).

        Args:
            neurons: The neuron group this behavior is attached to.
        """
        if self.apical_max_delay:
            neurons.apical_input = neurons.buffer_roll(
                mat=neurons.apical_input, new=0, counter=True
            )
        if self.distal_max_delay:
            neurons.distal_input = neurons.buffer_roll(
                mat=neurons.distal_input, new=0, counter=True
            )
        if self.proximal_max_delay:
            neurons.proximal_input = neurons.buffer_roll(
                mat=neurons.proximal_input, new=0, counter=True
            )

        # A discarded dendrite only holds the `[0]` list placeholder, which
        # has no `scatter_add_`; its synapses are ignored instead of crashing.
        if self.proximal_max_delay:
            for synapse in neurons.afferent_synapses.get("Proximal", []):
                self._add_proximal(neurons, synapse)

        if self.distal_max_delay:
            for synapse in neurons.afferent_synapses.get("Distal", []):
                self._add_distal(neurons, synapse)

        if self.apical_max_delay:
            for synapse in neurons.afferent_synapses.get("Apical", []):
                self._add_apical(neurons, synapse)

        neurons.I_proximal = neurons.proximal_input[0]
        neurons.I_apical = neurons.apical_input[0]
        neurons.I_distal = neurons.distal_input[0]
class SimpleDendriteComputation(Behavior):
    """
    Sum the different kinds of dendritic input entering the neurons.

    The proximal current is added to ``neurons.I`` directly. Apical and distal
    currents contribute only when their provocativeness is set; each is passed
    through ``tanh`` and scaled by the non-negative distance between the
    membrane potential and a provocative limit (see ``_calc_ratio``).

    Args:
        apical_provocativeness (float): The strength of the apical dendrites. The default is None.
        distal_provocativeness (float): The strength of the distal dendrites. The default is None.
        I_tau (float): Decaying factor to current. If None, at each step, current falls to zero.
    """

    def __init__(
        self,
        *args,
        I_tau=None,
        apical_provocativeness=None,
        distal_provocativeness=None,
        **kwargs,
    ):
        super().__init__(
            *args,
            I_tau=I_tau,
            apical_provocativeness=apical_provocativeness,
            distal_provocativeness=distal_provocativeness,
            **kwargs,
        )

    def initialize(self, neurons):
        """
        Read the behavior parameters and create the current vector ``neurons.I``.

        Args:
            neurons: The neuron group this behavior is attached to.
        """
        self.apical_provocativeness = self.parameter("apical_provocativeness", None)
        self.distal_provocativeness = self.parameter("distal_provocativeness", None)
        self.I_tau = self.parameter("I_tau", None)

        neurons.I = neurons.vector()

    def _calc_ratio(self, neurons, provocativeness):
        """
        Return how far ``neurons.v`` lies below the provocative limit, floored at 0.

        The limit is ``v_rest + provocativeness * (threshold - v_rest)``.

        Args:
            neurons: The neuron group providing ``v``, ``v_rest`` and ``threshold``.
            provocativeness (float): Scale of the rest-to-threshold span.
        """
        rest_to_threshold = neurons.threshold - neurons.v_rest
        limit = neurons.v_rest + provocativeness * rest_to_threshold
        return torch.clamp(limit - neurons.v, min=0)

    def forward(self, neurons):
        """
        Decay (or reset) ``neurons.I`` and add this step's dendritic currents.

        Args:
            neurons: The neuron group this behavior is attached to. Reads
                ``I_proximal``, ``I_apical``, ``I_distal`` and, when present,
                ``tau`` and ``R`` (both default to 1 when absent).
        """
        if self.I_tau is None:
            # Without a decay factor the current does not carry over.
            neurons.I.fill_(0)
        else:
            neurons.I -= neurons.I / self.I_tau

        apical_term = 0
        if self.apical_provocativeness is not None:
            apical_term = torch.tanh(neurons.I_apical) * self._calc_ratio(
                neurons, self.apical_provocativeness
            )

        distal_term = 0
        if self.distal_provocativeness is not None:
            distal_term = torch.tanh(neurons.I_distal) * self._calc_ratio(
                neurons, self.distal_provocativeness
            )

        time_constant = getattr(neurons, "tau", 1)
        resistance = getattr(neurons, "R", 1)
        modulation = time_constant * (apical_term + distal_term) / resistance

        neurons.I += neurons.I_proximal + modulation