Skip to content

Commit

Permalink
Add save_hyperparam to pytorch sequential to enable save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
zj-zhang committed Jan 23, 2023
1 parent 01cf1d4 commit c6641a8
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion amber/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .wrapper import Amber
from . import architect, modeler, utils, plots

__version__ = "0.1.4"
__version__ = "0.1.5"

__all__ = [
'Amber',
Expand Down
4 changes: 3 additions & 1 deletion amber/backend/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self, *args, **kwargs):
self.valid_metrics = {}
self.trainer = None
self.task = None
self.save_hyperparameters()

def compile(self, loss, optimizer, metrics=None, *args, **kwargs):
if callable(loss):
Expand Down Expand Up @@ -258,9 +259,10 @@ def test_epoch_end(self, outputs):

class Sequential(Model):
def __init__(self, layers=None):
layers = layers or []
super().__init__()
layers = layers or []
self.layers = torch.nn.ModuleList(layers)
self.save_hyperparameters()

def add(self, layer):
self.layers.append(layer)
Expand Down
9 changes: 5 additions & 4 deletions amber/modeler/sequential/pytorch_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, inputs_op, output_op, model_compile_dict, model_space, custom
assert self.session is None or isinstance(self.session, F.SessionType)

def build(self, model_states):
model = F.Sequential()
layers = []
curr_shape = self.input_node.Layer_attributes['shape']
x = torch.empty(*curr_shape, device='cpu')
# add a batch dim
Expand All @@ -38,7 +38,7 @@ def build(self, model_states):
if len(curr_shape) > 1:
dims = [0, len(curr_shape)] + np.arange(1, len(curr_shape)).tolist()
layer = F.get_layer(op=F.Operation('permute', dims=dims))
model.add(layer)
layers.append(layer)
x = layer(x)
for i, state in enumerate(model_states):
if issubclass(type(state), int) or np.issubclass_(type(state), np.integer):
Expand All @@ -51,9 +51,10 @@ def build(self, model_states):
)
layer = F.get_layer(torch.squeeze(x, dim=0), op, custom_objects=self.custom_objects)
x = layer(x)
model.add(layer)
layers.append(layer)
out = F.get_layer(torch.squeeze(x, dim=0), op=self.output_node, custom_objects=self.custom_objects)
model.add(out)
layers.append(out)
model = F.Sequential(layers=layers)
return model

def __call__(self, model_states):
Expand Down
5 changes: 4 additions & 1 deletion amber/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ def read_history_set(fn_list):


def read_history(fn_list,
metric_name_dict={'acc': 0, 'knowledge': 1, 'loss': 2}):
metric_name_dict=None):
if not isinstance(fn_list, (list, tuple)):
fn_list = [fn_list]
d = read_history_set(fn_list)
d.columns = ['ID', 'metrics', 'reward'] + ['L%i' % i for i in range(1, d.shape[1] - 3)] + ['dir']
metric_name_dict = {} if metric_name_dict is None else metric_name_dict
metrics = {x: [] for x in metric_name_dict}
for i in range(d.shape[0]):
tmp = d.iloc[i, 1].split(',')
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
config = {
'description': 'Automated Modelling in Biological Evidence-based Research',
'download_url': 'https://github.com/zj-zhang/AMBER',
'version': '0.1.4',
'version': '0.1.5',
'packages': find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]),
'include_package_data': True,
'setup_requires': [],
Expand Down

0 comments on commit c6641a8

Please sign in to comment.