# Source code for conex.nn.utils.replication

import copy
from pymonntorch import NeuronGroup, SynapseGroup, NetworkObject, Behavior, Network
from typing import Dict, List, Tuple, Callable
from conex.nn.structure.port import Port
from conex.nn.structure.io_layer import InputLayer, OutputLayer
import json

# Structure classes that get_all_required_structures treats as leaves: it does
# not call required_helper() on objects whose exact class is listed here
# (the check is on ``x.__class__``, so subclasses are not matched).
ELEMENTAL_STRUCTURE = [NeuronGroup, SynapseGroup, InputLayer, OutputLayer]

def object_hook(json_dict):
    """Resolve serialized callables inside a decoded JSON object.

    Meant to be passed as ``object_hook`` to ``json.load``/``json.loads``.
    Every value of the form ``["python_callable", module, name]`` is replaced
    in place by the attribute ``name`` looked up on the imported ``module``.
    Nested dictionaries are processed recursively.

    The lookup uses ``importlib`` and ``getattr`` rather than ``exec``/``eval``
    so that no text from the file is executed as code. Importing a module
    still runs that module's top-level code, so only load trusted files.

    Args:
        json_dict: A dictionary produced by the JSON decoder.

    Returns:
        The same dictionary object, with callable markers replaced.

    Raises:
        ImportError: If a referenced module cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    import importlib

    def load_python_callables(raw_dict):
        for k, v in raw_dict.items():
            if isinstance(v, dict):
                load_python_callables(v)
            elif isinstance(v, list) and len(v) >= 3 and v[0] == "python_callable":
                target = importlib.import_module(v[1])
                # Walk dotted names so "Outer.inner" style references resolve.
                for part in v[2].split("."):
                    target = getattr(target, part)
                # Re-assigning an existing key does not resize the dict, so
                # this is safe while iterating over items().
                raw_dict[k] = target

    load_python_callables(json_dict)
    return json_dict
class ExtraCallableJSONEncoder(json.JSONEncoder):
    """JSON encoder that stores callables as importable references.

    A callable (function or class) is emitted as the triple
    ``("python_callable", module, name)``, which JSON writes as a list.
    The reference can only be resolved again if ``name`` is reachable as an
    attribute of ``module``.
    """

    def default(self, object):
        """Return a JSON-serializable stand-in for ``object``.

        Args:
            object: A value the standard encoder could not serialize.

        Returns:
            ``("python_callable", object.__module__, object.__name__)`` when
            ``object`` is callable; otherwise whatever the parent encoder's
            ``default`` returns.

        Raises:
            TypeError: Raised by ``json.JSONEncoder.default`` for values that
                are not callable and not otherwise serializable.
        """
        if isinstance(object, Callable):
            return ("python_callable", object.__module__, object.__name__)
        # Propagate the parent's result so cooperative subclasses/mixins that
        # handle additional types keep working.
        return super().default(object)
def get_all_required_structures(struc: NetworkObject) -> List[NetworkObject]:
    """Collect every structure that ``struc`` requires, directly or indirectly.

    Starts from ``struc.required_helper()`` and keeps expanding each returned
    object whose exact class is not in ``ELEMENTAL_STRUCTURE`` by calling its
    own ``required_helper()``, level by level, until a level is empty.

    Args:
        struc: The structure whose dependencies are gathered. It must expose
            ``network`` and ``required_helper()``.

    Returns:
        All gathered structures, ordered by their position in
        ``struc.network.Structures``. A message is printed (no exception is
        raised) if the same structure was gathered more than once.

    Raises:
        ValueError: If a gathered structure is not in ``net.Structures``.
    """
    net = struc.network
    all_required_sub_structures = [struc.required_helper()]
    while all_required_sub_structures[-1]:
        current_struc = []
        for x in all_required_sub_structures[-1]:
            if x.__class__ not in ELEMENTAL_STRUCTURE:
                # Expand the sub-structure itself (not the accumulator list).
                current_struc.extend(x.required_helper())
        all_required_sub_structures.append(current_struc)

    flatten = [x for xs in all_required_sub_structures for x in xs]
    if len(flatten) != len(set(flatten)):
        print("ERROR: duplicate elements returned by required_helpers")
    return sorted(flatten, key=lambda x: net.Structures.index(x))
def save_ports(
    ports: Dict[str, Tuple[dict, List[Port]]],
    save_behavior_tag: bool,
    save_behavior_priority: bool,
) -> dict:
    """Convert a port mapping into a serializable form.

    For each entry the first element is kept as-is. Every port in the second
    element becomes a 3-tuple of its first two items and its third item run
    through ``behaviors_to_list``.
    (Presumably the third Port field is a behavior dictionary - confirm
    against the ``Port`` definition.)

    Args:
        ports: Mapping from a name to ``(data, list_of_ports)``.
        save_behavior_tag: Forwarded to ``behaviors_to_list``.
        save_behavior_priority: Forwarded to ``behaviors_to_list``.

    Returns:
        A new dictionary with the same keys and the converted values.
    """
    serialized = {}
    for name, entry in ports.items():
        converted_ports = []
        for port in entry[1]:
            behavior_list = behaviors_to_list(
                port[2],
                save_behavior_priority=save_behavior_priority,
                save_behavior_tag=save_behavior_tag,
            )
            converted_ports.append((port[0], port[1], behavior_list))
        serialized[name] = (entry[0], converted_ports)
    return serialized
def build_ports(ports: Dict[str, Tuple[dict, List[Port]]]) -> Dict[str, tuple]:
    """Rebuild ``Port`` objects from a serialized port mapping.

    For each entry the first element is kept as-is and every item of the
    second element is turned into ``Port(item[0], item[1], behaviors)``, where
    ``behaviors`` is ``build_behavior_dict(item[2])``.

    Args:
        ports: Mapping from a name to ``(data, list_of_port_descriptions)``.

    Returns:
        A new dictionary with the same keys and rebuilt port lists.
    """
    return {
        name: (
            entry[0],
            [
                Port(item[0], item[1], build_behavior_dict(item[2]))
                for item in entry[1]
            ],
        )
        for name, entry in ports.items()
    }
def save_structure(
    struc: NetworkObject,
    save_device: bool = False,
    built_structures: dict = None,
    save_structure_tag: bool = False,
    save_behavior_tag: bool = True,
    save_behavior_priority: bool = True,
    all_structures_required: list = None,
) -> dict:
    """Serialize ``struc`` into a plain dictionary.

    Args:
        struc: The structure to serialize.
        save_device: If true, store ``struc.device`` under ``"device"``.
        built_structures: Mapping from an index in ``all_structures_required``
            to the already serialized sub-structure. Ignored (replaced by a
            fresh dict) when ``all_structures_required`` is None.
        save_structure_tag: If true, store the comma-joined ``struc.tags``
            under ``"tag"``; otherwise ``"tag"`` is None.
        save_behavior_tag: Forwarded to ``behaviors_to_list``.
        save_behavior_priority: Forwarded to ``behaviors_to_list``.
        all_structures_required: Ordered list of every structure the
            top-level object needs. When None this call is treated as the
            top-level call: the list is computed with
            ``get_all_required_structures`` and the result gets a
            ``"built_structures"`` entry.

    Returns:
        A dictionary with ``"tag"``, ``"class"`` and ``"behavior"``, plus
        ``"size"`` for a ``NeuronGroup``, ``"src"``/``"dst"`` indices for a
        ``SynapseGroup``, or the output of ``struc.save_helper`` for any other
        non input/output structure.
    """
    result = {"tag": None}
    result["class"] = struc.__class__
    result["behavior"] = behaviors_to_list(
        struc.behavior,
        save_behavior_tag=save_behavior_tag,
        save_behavior_priority=save_behavior_priority,
    )

    if save_device:
        result["device"] = struc.device

    if save_structure_tag:
        result["tag"] = ",".join(struc.tags)

    if all_structures_required is None:
        # Top-level call: sub-structures are collected inside the result.
        result["built_structures"] = {}
        built_structures = result["built_structures"]
        all_structures_required = get_all_required_structures(struc)

    if isinstance(struc, NeuronGroup):
        result["size"] = struc.size
    elif isinstance(struc, SynapseGroup):
        # Endpoints are stored as positions in all_structures_required.
        result["src"] = (
            all_structures_required.index(struc.src)
            if struc.src is not None
            else None
        )
        # Guard on dst itself; a synapse may have a src but no dst.
        result["dst"] = (
            all_structures_required.index(struc.dst)
            if struc.dst is not None
            else None
        )
    elif isinstance(struc, InputLayer) or isinstance(struc, OutputLayer):
        print("ERROR: InputLayer and OutputLayer can't be saved.")
    else:
        # NOTE(review): a non-elemental sub-structure listed in
        # all_structures_required reaches this loop while its own index is
        # still absent from built_structures - confirm this cannot recurse
        # without end for nested containers.
        for sub_struc in all_structures_required:
            sub_index = all_structures_required.index(sub_struc)
            if sub_index not in built_structures:
                built_sub_struc = save_structure(
                    sub_struc,
                    all_structures_required=all_structures_required,
                    built_structures=built_structures,
                    save_structure_tag=save_structure_tag,
                    save_behavior_tag=save_behavior_tag,
                    save_behavior_priority=save_behavior_priority,
                    save_device=save_device,
                )
                built_structures[sub_index] = built_sub_struc

        helper_dictionary = struc.save_helper(all_structures_required)
        result.update(helper_dictionary)

    return result
def behaviors_to_list(
    behavior: Dict[int, Behavior],
    save_behavior_tag: bool = True,
    save_behavior_priority: bool = True,
) -> List[dict]:
    """Describe each behavior as a dictionary of class and constructor inputs.

    Entries of ``init_kwargs`` whose name starts with ``"arg_"`` are treated
    as positional arguments and ordered by the integer after the prefix; all
    other entries are keyword arguments.

    Args:
        behavior: Mapping from a key to a behavior object exposing
            ``init_kwargs``.
        save_behavior_tag: If false, the ``"tag"`` keyword argument is left
            out of the description.
        save_behavior_priority: If true, the mapping key is stored under
            ``"key"``.

    Returns:
        One dictionary per behavior, in iteration order, with ``"class"``,
        ``"parameters_args"``, ``"parameters_kwargs"`` and optionally
        ``"key"``.
    """
    descriptions = []
    for priority, instance in behavior.items():
        entry = {"key": priority} if save_behavior_priority else {}
        entry["class"] = instance.__class__

        positional = {}
        keyword = {}
        for name, val in instance.init_kwargs.items():
            if name.startswith("arg_"):
                positional[int(name[4:])] = val
            else:
                keyword[name] = val

        entry["parameters_args"] = [positional[i] for i in sorted(positional)]
        # ``keyword`` is a fresh dict, so dropping the tag leaves
        # ``init_kwargs`` untouched.
        if not save_behavior_tag:
            keyword.pop("tag", "")
        entry["parameters_kwargs"] = keyword

        descriptions.append(entry)
    return descriptions
def build_behavior_dict(behavior_list: List[dict]) -> Dict[int, Behavior]:
    """Instantiate behaviors from their dictionary descriptions.

    Each description supplies ``"class"``, ``"parameters_args"`` and
    ``"parameters_kwargs"``; the class is called with those arguments.

    Args:
        behavior_list: Descriptions as produced by ``behaviors_to_list``.

    Returns:
        Mapping from each description's ``"key"`` (or its position in the
        list when ``"key"`` is absent) to the created object. A later entry
        with the same key replaces an earlier one.
    """
    return {
        spec.get("key", position): spec["class"](
            *spec["parameters_args"], **spec["parameters_kwargs"]
        )
        for position, spec in enumerate(behavior_list)
    }
def create_structure_from_dict(
    net: Network, structure_dict: dict, built_structures: dict = None
) -> NetworkObject:
    """Build a structure inside ``net`` from its dictionary description.

    The description is deep-copied first, so ``structure_dict`` itself is not
    modified.

    Args:
        net: Network the new structure is attached to.
        structure_dict: Description as produced by ``save_structure``.
        built_structures: Mapping from index to already built structure, used
            to resolve ``"src"``/``"dst"`` of a ``SynapseGroup`` and filled
            with newly built sub-structures.

    Returns:
        The created structure, or None (after printing a message) when the
        description is for an ``InputLayer`` or ``OutputLayer``.
    """
    spec = copy.deepcopy(structure_dict)
    kind = spec["class"]

    if kind == NeuronGroup:
        return NeuronGroup(
            spec["size"],
            net=net,
            behavior=build_behavior_dict(spec["behavior"]),
            tag=spec["tag"],
        )

    if kind == SynapseGroup:
        # Endpoints are stored as indices into built_structures (or None).
        src_index = spec["src"]
        source = None if src_index is None else built_structures[src_index]
        dst_index = spec["dst"]
        target = None if dst_index is None else built_structures[dst_index]
        return SynapseGroup(
            net=net,
            src=source,
            dst=target,
            behavior=build_behavior_dict(spec["behavior"]),
            tag=spec["tag"],
        )

    if kind == InputLayer:
        print("ERROR: recived InputLayer, Which cannot be saved.")
        return None

    if kind == OutputLayer:
        print("ERROR: recived OutputLayer, Which cannot be saved.")
        return None

    if built_structures is None:
        built_structures = {}

    for index, sub_spec in spec.pop("built_structures", {}).items():
        if index in built_structures:
            print("ERROR: found duplicate in required structure.")
            continue
        built_structures[index] = create_structure_from_dict(
            net, sub_spec, built_structures
        )

    class_obj = spec.pop("class")
    class_obj.build_helper(spec, built_structures)
    # build_helper might intentionally change the behavior, so the behavior
    # objects are only instantiated after it has run.
    spec["behavior"] = build_behavior_dict(spec["behavior"])
    return class_obj(net=net, **spec)
def replicate(obj: NetworkObject, net: Network) -> Tuple[NetworkObject, dict]:
    """Create a copy of ``obj`` inside ``net``.

    The object is serialized with ``save_structure`` (device, structure tags,
    behavior tags and behavior priorities all included) and rebuilt from a
    deep copy of that description with ``create_structure_from_dict``.

    Args:
        obj: The structure to copy.
        net: The network that receives the copy.

    Returns:
        A ``(structure, save_dict)`` pair: the newly built structure and the
        dictionary description it was built from.
    """
    save_dict = save_structure(
        obj,
        save_device=True,
        built_structures=None,
        save_structure_tag=True,
        save_behavior_tag=True,
        save_behavior_priority=True,
        all_structures_required=None,
    )
    # Build from a copy so the returned description stays untouched.
    new_dict = copy.deepcopy(save_dict)
    struc = create_structure_from_dict(
        net=net, structure_dict=new_dict, built_structures=None
    )
    return struc, save_dict
def save_structure_dict_to_json(struc_dict: dict, filename: str):
    """Write a structure description to ``filename`` as indented JSON.

    Callables (including classes) are encoded by
    ``ExtraCallableJSONEncoder``. The file is overwritten: appending a second
    document would leave a file that ``json.load`` cannot parse.

    Args:
        struc_dict: Description as produced by ``save_structure``.
        filename: Path of the file to write.

    Raises:
        TypeError: If the description contains a value that is neither JSON
            serializable nor callable.
    """
    with open(filename, "w") as output_file:
        json.dump(struc_dict, output_file, cls=ExtraCallableJSONEncoder, indent=2)
def load_structure_dict_from_json(filename: str):
    """Read a structure description from a JSON file.

    Callable references are resolved by ``object_hook`` while decoding. JSON
    object keys are always strings, so the keys of ``"built_structures"`` are
    converted back to integers.

    Args:
        filename: Path of the file to read.

    Returns:
        The decoded description dictionary.

    Raises:
        KeyError: If the top-level object has no ``"built_structures"`` entry.
    """
    with open(filename, "r") as input_file:
        loaded = json.load(input_file, object_hook=object_hook)

    indexed = {}
    for key, sub_description in loaded["built_structures"].items():
        indexed[int(key)] = sub_description
    loaded["built_structures"] = indexed
    return loaded