Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MouseNet into an importable module #1

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md

This file was deleted.

68 changes: 68 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
=========
Mouse_CNN
=========
This is the repo for the ongoing project of CNN MouseNet -- a convolutional neural network constrained by the architecture of the mouse visual cortex.

.. contents:: Table of Contents
:depth: 2

Folder Structure
================

::

Mouse_CNN/
├── mousenet/
│ │
│ │
│ ├── cmouse/ - Code related to constructing the PyTorch model
│ │ └── __init__.py
│ │
│ ├── example/ - Example code and resources
│ │
│ ├── mouse_cnn/ - Code related to deriving architecture from data
│ │ └── __init__.py
│ │
│ ├── retinotopics/ - Code related to calculating visual subfields
│ │ └── __init__.py
│ │
│ └── /
│ ├── loader.py - load function for loading a mousenet model with a particular initialization
│ └── __init__.py - manages pathing for saving models + logs
├── environment.yml - Conda environment and dependancies
├── setup.cfg - Development configurations and linting
├── setup.py - package definitions
└── tests/ - tests folder


Usage
=====

Installation:
Change directory to cloned folder

.. code-block::
$ pip install -e .



To load a mousenet model

.. code-block::

$ import mousenet
$ model = mousenet.load(architecture="stock", pretraining=None)

Architecture can be one of "stock" or "retinotopic" for visual subfields. Pretraining can be on of: None, "kaiming" for kaiming initialization, or "Imagenet" for imagenet pretraining.


To test the code

.. code-block::

$ pytest
1 change: 0 additions & 1 deletion cmouse/__init__.py

This file was deleted.

34 changes: 34 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: mousenet

channels:
- pytorch
- conda-forge

dependencies:
- python=3.8
- shapely
- numpy
- scipy
- matplotlib
- pandas
- pytest
- flake8
- networkx
- network
- rope
- autopep8
- black
- click
- jupyter
- pyyaml
- pytorch>=1.1.0
- torchvision
- cudatoolkit=10.0
- tensorboard>=1.14
- absl-py
- future
- tqdm
- pip
- pip:
- -e . # install package in development mode
- git+https://github.com/AllenInstitute/mouse_connectivity_models.git
46 changes: 46 additions & 0 deletions experiments/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Mnist_LeNet
save_dir: saved/
seed: 1234
target_devices: [0]

arch:
type: MnistModel
args: {}

augmentation:
type: MNISTTransforms
args: {}

data_loader:
type: MnistDataLoader
args:
batch_size: 128
data_dir: data/
nworkers: 2
shuffle: true
validation_split: 0.1

loss: nll_loss

lr_scheduler:
type: StepLR
args:
gamma: 0.1
step_size: 50

metrics:
- top_1_acc
- top_3_acc

optimizer:
type: Adam
args:
lr: 0.001
weight_decay: 0

training:
early_stop: 10
epochs: 100
monitor: min val_loss
save_period: 1
tensorboard: true
43 changes: 43 additions & 0 deletions logging.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
version: 1
disable_existing_loggers: False
formatters:
simple:
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

handlers:
console:
class: logging.StreamHandler
level: DEBUG
formatter: simple
stream: ext://sys.stdout

debug_file_handler:
class: logging.handlers.RotatingFileHandler
level: DEBUG
formatter: simple
filename: debug.log
maxBytes: 10485760 # 10MB
backupCount: 10
encoding: utf8

info_file_handler:
class: logging.handlers.RotatingFileHandler
level: INFO
formatter: simple
filename: info.log
maxBytes: 10485760 # 10MB
backupCount: 10
encoding: utf8

error_file_handler:
class: logging.handlers.RotatingFileHandler
level: ERROR
formatter: simple
filename: error.log
maxBytes: 10485760 # 10MB
backupCount: 10
encoding: utf8

root:
level: INFO
handlers: [console, debug_file_handler, info_file_handler, error_file_handler]
1 change: 1 addition & 0 deletions mousenet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .loader import load
File renamed without changes.
5 changes: 1 addition & 4 deletions cmouse/anatomy.py → mousenet/cmouse/anatomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
import networkx as nx
import matplotlib.pyplot as plt
import sys
sys.path.append('../')
sys.path.append('../../')
sys.path.append('../../../')
from mouse_cnn.architecture import *
from ..mouse_cnn.architecture import *
# from config import get_output_shrinkage

class AnatomicalLayer:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion cmouse/main.py → mousenet/cmouse/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def main_worker(gpu, ngpus_per_node, args):

if NET == 1:
net_name = 'network_complete_updated_number(%s,%s,%s)'%(INPUT_SIZE[0],INPUT_SIZE[1],INPUT_SIZE[2])
architecture = Architecture(data_folder=DATA_DIR)
architecture = Architecture()
net = gen_network(net_name, architecture)
if FIXMASK != 0:
np.random.seed(FIXMASK)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
from copyreg import pickle
import torch
from torch import nn
import networkx as nx
import numpy as np
from config import INPUT_SIZE, EDGE_Z, OUTPUT_AREAS, HIDDEN_LINEAR, NUM_CLASSES
import pathlib, os
import pickle
from .exps.imagenet.config import INPUT_SIZE, EDGE_Z, OUTPUT_AREAS, HIDDEN_LINEAR, NUM_CLASSES
import pdb

def get_retinotopic_mask(layer, retinomap):
region_name = ''.join(x for x in layer.lower() if x.isalpha())
mask = torch.zeros(32, 32)
if layer == "input":
return
if region_name == "visp":
return 1

for area in retinomap:
area_name = area[0].lower()
if area_name == region_name:
normalized_polygon = area[1]
x, y = normalized_polygon.exterior.coords.xy
x, y = list(x), list(y)
xshift= yshift = int(0)
if area_name != "visp":
xshift = int((max(x) - min(x))/4)
yshift = int((max(y) - min(y))/4)
x1, x2 = int(max(min(x)+xshift, 0)), int(min(max(x) - xshift, 32))
y1, y2 = int(max(min(y) + yshift, 0)), int(min(max(y) - yshift, 32))
mask[x1:x2, y1:y2] = 1
mask_sum = mask.sum()
project_root = pathlib.Path(__file__).parent.parent.resolve()
file = os.path.join(project_root, "retinotopics", "mask_areas", f"{area_name}.pkl")
pickle.dump(mask_sum, open(file,"wb"))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mask.to(device)
return mask

# raise ValueError(f"Could not find area for layer {layer} in retinomap")


class Conv2dMask(nn.Conv2d):
"""
Expand Down Expand Up @@ -66,15 +102,25 @@ class MouseNetCompletePool(nn.Module):
"""
torch model constructed by parameters provided in network.
"""
def __init__(self, network, mask=3):
def __init__(self, network, mask=3, retinomap=None):
super(MouseNetCompletePool, self).__init__()
self.Convs = nn.ModuleDict()
self.BNs = nn.ModuleDict()
self.network = network
# self.layer_masks = dict()
self.retinomap = retinomap

G, _ = network.make_graph()
self.top_sort = list(nx.topological_sort(G))


# if self.retinomap is not None:
# for layer in self.top_sort:
# self.layer_masks[layer] = get_retinotopic_mask(layer, self.retinomap)
# else:
# for layer in self.top_sort:
# self.layer_masks[layer] = torch.ones(32, 32)

for layer in network.layers:
params = layer.params
self.Convs[layer.source_name + layer.target_name] = Conv2dMask(params.in_channels, params.out_channels, params.kernel_size,
Expand Down Expand Up @@ -103,18 +149,18 @@ def __init__(self, network, mask=3):
# layer = network.find_conv_source_target('%s2/3'%area[:-1],'%s'%area)
# total_size += int(layer.out_size*layer.out_size*layer.params.out_channels)

self.classifier = nn.Sequential(
nn.Linear(int(total_size), NUM_CLASSES),
# self.classifier = nn.Sequential(
# nn.Linear(int(total_size), NUM_CLASSES),
# nn.Linear(int(total_size), HIDDEN_LINEAR),
# nn.ReLU(True),
# nn.Dropout(),
# nn.Linear(HIDDEN_LINEAR, HIDDEN_LINEAR),
# nn.ReLU(True),
# nn.Dropout(),
# nn.Linear(HIDDEN_LINEAR, NUM_CLASSES),
)
# )

def get_img_feature(self, x, area_list, flatten=True):
def get_img_feature(self, x, area_list, flatten=False):
"""
function for get activations from a list of layers for input x
:param x: input image set Tensor with size (num_img, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2])
Expand All @@ -131,17 +177,40 @@ def get_img_feature(self, x, area_list, flatten=True):
if area == 'LGNd' or area == 'LGNv':
layer = self.network.find_conv_source_target('input', area)
layer_name = layer.source_name + layer.target_name
calc_graph[area] = nn.ReLU(inplace=True)(self.BNs[area](self.Convs[layer_name](x)))
calc_graph[area] = nn.ReLU(inplace=True)(
self.BNs[area](
self.Convs[layer_name](x)
)
)
continue

for layer in self.network.layers:
if layer.target_name == area:
# mask = None
# if layer.source_name in self.layer_masks:
# mask = self.layer_masks[layer.source_name]
# if mask is None:
# mask = 1
layer_name = layer.source_name + layer.target_name
# if isinstance(mask, int):
# print(area, mask)
# else:
# print(area, mask.shape)
if area not in calc_graph:
calc_graph[area] = self.Convs[layer_name](calc_graph[layer.source_name])
calc_graph[area] = self.Convs[layer_name](
calc_graph[layer.source_name]
)
else:
calc_graph[area] = calc_graph[area] + self.Convs[layer_name](calc_graph[layer.source_name])
calc_graph[area] = nn.ReLU(inplace=True)(self.BNs[area](calc_graph[area]))

calc_graph2 = calc_graph.copy()
calc_graph[area] = nn.ReLU(inplace=True)(
self.BNs[area](
calc_graph[area]
)
)
if calc_graph[area].sum() == 0:
pdb.set_trace()

if len(area_list) == 1:
if flatten:
Expand All @@ -153,14 +222,12 @@ def get_img_feature(self, x, area_list, flatten=True):
re = None
for area in area_list:
if re is None:
re = torch.flatten(torch.nn.AdaptiveAvgPool2d(4) (calc_graph[area]), 1)
re = torch.nn.AdaptiveAvgPool2d(4) (calc_graph[area])
# re = torch.flatten(
# nn.ReLU(inplace=True)(self.BNs['%s_downsample'%area](self.Convs['%s_downsample'%area](calc_graph[area]))),
# 1)
else:
re=torch.cat([torch.flatten(
torch.nn.AdaptiveAvgPool2d(4) (calc_graph[area]),
1), re], axis=1)
re=torch.cat([torch.nn.AdaptiveAvgPool2d(4) (calc_graph[area]), re], axis=1)
# re=torch.cat([
# torch.flatten(
# nn.ReLU(inplace=True)(self.BNs['%s_downsample'%area](self.Convs['%s_downsample'%area](calc_graph[area]))),
Expand All @@ -176,6 +243,6 @@ def get_img_feature(self, x, area_list, flatten=True):
return re

def forward(self, x):
x = self.get_img_feature(x, OUTPUT_AREAS)
x = self.classifier(x)
x = self.get_img_feature(x, OUTPUT_AREAS, flatten=False)
# x = self.classifier(x)
return x
Loading