Source code for conex.helpers.transforms.encoders

import torch


class SimplePoisson(torch.nn.Module):
    """
    Simple Poisson encoding.

    Input values should be between 0 and 1. Spike rate is increased linearly
    with regard to the values: at every time step an element spikes when
    ``value * ratio`` is greater than or equal to a fresh uniform sample
    drawn with `torch.rand`.

    Args:
        time_window (int): The interval of the coding, i.e. the size of the
            leading (time) dimension of the output.
        ratio (float): A scale factor for probability of spiking.
    """

    def __init__(self, time_window, ratio):
        # Initialise the nn.Module machinery; without it repr(), .to(),
        # state_dict() etc. fail on the half-constructed module.
        super().__init__()
        self.time_window = time_window
        self.ratio = ratio

    def __call__(self, img):
        """
        Encode ``img`` into a boolean spike train.

        Args:
            img (torch.Tensor or tuple): Input intensities, or a tuple of
                inputs that are encoded independently.

        Returns:
            torch.Tensor or tuple: Boolean tensor of shape
            ``(time_window, *img.shape)``; a tuple of such tensors when a
            tuple was given.
        """
        if isinstance(img, tuple):
            return tuple(self(sub_inp) for sub_inp in img)

        # Sample on the input's device so the comparison below does not mix
        # devices when the input does not live on the default one.
        random_probability = torch.rand(
            size=(self.time_window, *img.shape), device=img.device
        )
        # The leading singleton dimension broadcasts over the time axis.
        spike_probability = img.unsqueeze(dim=0) * self.ratio
        return spike_probability >= random_probability
class Poisson(torch.nn.Module):
    """
    Poisson encoding.

    Input values should be between 0 and 1. The intervals between two spikes
    are picked using Poisson Distribution whose rate is the reciprocal of the
    scaled input value, so larger values produce shorter intervals. Elements
    equal to zero never spike.

    Args:
        time_window (int): The interval of the coding, i.e. the size of the
            leading (time) dimension of the output.
        ratio (float): A scale factor for probability of spiking.
    """

    def __init__(self, time_window, ratio):
        # Initialise the nn.Module machinery; without it repr(), .to(),
        # state_dict() etc. fail on the half-constructed module.
        super().__init__()
        self.time_window = time_window
        self.ratio = ratio

    def __call__(self, img):
        """
        Encode ``img`` into a boolean spike train.

        Args:
            img (torch.Tensor or tuple): Input intensities, or a tuple of
                inputs that are encoded independently.

        Returns:
            torch.Tensor or tuple: Boolean tensor of shape
            ``(time_window, *img.shape)``; a tuple of such tensors when a
            tuple was given.
        """
        if isinstance(img, tuple):
            return tuple(self(sub_inp) for sub_inp in img)

        # https://github.com/BindsNET/bindsnet/blob/master/bindsnet/encoding/encodings.py
        original_shape, original_size = img.shape, img.numel()

        # reshape (unlike view) also accepts non-contiguous inputs. The
        # multiplication yields a new tensor, so the in-place edits below
        # never touch the caller's data.
        flat_img = img.reshape(-1) * self.ratio
        non_zero_mask = flat_img != 0
        # Rate of the interval distribution is the inverse of the intensity;
        # zero entries keep rate 0.
        flat_img[non_zero_mask] = 1 / flat_img[non_zero_mask]

        dist = torch.distributions.Poisson(rate=flat_img, validate_args=False)
        intervals = dist.sample(sample_shape=torch.Size([self.time_window]))
        # A sampled interval of 0 would stack two spikes on one step; bump
        # those to 1 for the elements that are allowed to spike.
        intervals[:, non_zero_mask] += (intervals[:, non_zero_mask] == 0).float()

        # Spike times are the running sum of the intervals. Times that fall
        # outside the window are redirected to row 0, which is discarded.
        times = torch.cumsum(intervals, dim=0).long()
        times[times >= self.time_window + 1] = 0

        spikes = torch.zeros(
            self.time_window + 1, original_size, device=img.device, dtype=torch.bool
        )
        spikes[times, torch.arange(original_size, device=img.device)] = True
        spikes = spikes[1:]  # drop the dump row

        return spikes.view(self.time_window, *original_shape)
class Intensity2Latency(torch.nn.Module):
    """
    Intensity to latency encoding. Stronger values spikes sooner.

    Every element that is not below the threshold spikes exactly once; the
    time step of that spike is earlier for larger values.

    Args:
        time_window (int): The interval of the coding, i.e. the size of the
            leading (time) dimension of the output.
        threshold (float): If not None, values lower than threshold will not
            spike. Defaults to ``min_val`` when None.
        sparsity (float): If not None, defines a threshold for each input
            based on sparsity: the ``1 - sparsity`` quantile of the input is
            used as the threshold, overriding ``threshold``.
        min_val (float): Minimum possible value of input. The default is 0.0.
        max_val (float): Maximum possible value of input. The default is 1.0.
        lower_trim (bool): If True, spikes are transformed in order to have
            the last spike on the end of the interval. The default is True.
        higher_trim (bool): If True, spikes are transformed in order to have
            the first spike on the first of the interval. The default is True.
    """

    def __init__(
        self,
        time_window,
        threshold=None,
        sparsity=None,
        min_val=0.0,
        max_val=1.0,
        lower_trim=True,
        higher_trim=True,
    ):
        # Initialise the nn.Module machinery; without it repr(), .to(),
        # state_dict() etc. fail on the half-constructed module.
        super().__init__()
        self.time_window = time_window
        self.threshold = threshold if threshold is not None else min_val
        self.sparsity = sparsity
        self.interval = (min_val, max_val)
        self.higher_trim = higher_trim
        self.lower_trim = lower_trim

    def __call__(self, img):
        """
        Encode ``img`` into a boolean spike train.

        The input tensor itself is left unmodified.

        Args:
            img (torch.Tensor or tuple): Input intensities, or a tuple of
                inputs that are encoded independently.

        Returns:
            torch.Tensor or tuple: Boolean tensor of shape
            ``(time_window, *img.shape)``; a tuple of such tensors when a
            tuple was given.
        """
        if isinstance(img, tuple):
            return tuple(self(sub_inp) for sub_inp in img)

        # With sparsity set, the threshold is recomputed from every input
        # (and stored on the instance, as before).
        if self.sparsity is not None:
            self.threshold = img.quantile(1 - self.sparsity)
        below_index = img < self.threshold

        # Out-of-place shift: an in-place ``-=`` would silently modify the
        # tensor owned by the caller.
        img = img - self.interval[0]
        max_value = self.interval[1] - self.interval[0]
        max_factor = 1 / max_value

        if self.lower_trim and not below_index.all():
            # Shift so that the weakest spiking value maps to the last step.
            img_min = img[~below_index].min()
            img = img - img_min
            span = max_value - img_min
            # A zero span would give an infinite factor and NaN indices
            # (0 * inf); keep the previous finite factor in that case.
            if span != 0:
                max_factor = 1 / span
        if self.higher_trim and img.max() != 0:
            # Stretch so that the strongest value maps to the first step.
            max_factor = 1 / img.max()

        index = img * max_factor * (self.time_window - 1)
        index = index.ceil().long()
        # Slot 0 is reserved as a dump row for elements that must not spike.
        index += 1
        index[below_index] = 0
        # Keep indices inside the (time_window + 1)-row buffer even for
        # inputs outside [min_val, max_val] or tiny rounding overshoots.
        index = index.clamp(0, self.time_window)
        index = index.unsqueeze(0)

        spikes = torch.zeros(
            self.time_window + 1, *img.shape, dtype=torch.bool, device=img.device
        )
        spikes.scatter_(0, index, True)
        spikes = spikes[1:]  # drop the dump row

        # Large indices belong to strong values; flipping the time axis
        # makes them spike first.
        return spikes.flip(0)