Commit

Fix bugs in pytorch backend get_layer, amber-cli, model ckpt; Change default arg vals for GeneralController and Env
zj-zhang committed Jan 18, 2023
1 parent aa57417 commit 01cf1d4
Showing 15 changed files with 101 additions and 87 deletions.
2 changes: 0 additions & 2 deletions README.md
@@ -17,8 +17,6 @@
<a id='sec1'></a>
AMBER is a toolkit for designing high-performance neural network models automatically in
Genomics and Bioinformatics.
AMBER is a toolkit for designing high-performance neral network models automatically in
Genomics and Bioinformatics.

🧐**AMBER can be used to automatically build:**
- 🟢 Convolution neural networks
41 changes: 24 additions & 17 deletions amber/architect/optim/controller/controllerTrainEnv.py
@@ -114,15 +114,16 @@ def __init__(self,
controller,
manager,
max_episode=100,
max_step_per_ep=2,
max_step_per_ep=3,
logger=None,
resume_prev_run=False,
should_plot=True,
initial_buffering_queue=15,
working_dir='.', entropy_converge_epsilon=0.01,
initial_buffering_queue=10,
working_dir='.',
early_stop_patience=None,
squeezed_action=True,
with_input_blocks=False,
with_skip_connection=True,
with_skip_connection=False,
save_controller=True,
continuous_run=False,
verbose=0,
@@ -136,7 +137,7 @@ def __init__(self,
self.working_dir = working_dir
self.total_reward = 0
self.entropy_record = []
self.entropy_converge_epsilon = entropy_converge_epsilon
self.early_stop_patience = int(early_stop_patience) if early_stop_patience is not None else np.inf
self.squeezed_action = squeezed_action
self.with_input_blocks = with_input_blocks
self.with_skip_connection = with_skip_connection
@@ -146,6 +147,7 @@ def __init__(self,
self.verbose = verbose
self.logger = logger if logger else setup_logger(working_dir)

self.logger.info(f"working directory: {os.path.realpath(self.working_dir)}")
self.time_budget = kwargs.pop('time_budget', "72:00:00")
if self.time_budget is None:
pass
@@ -178,7 +180,7 @@ def __init__(self,
self.clean()

def __str__(self):
s = 'ControllerTrainEnv for %i max steps, %i child mod. each step' % (self.max_episode, self.max_step_per_ep)
s = 'ControllerTrainEnv for %i max steps, %i child model each step' % (self.max_episode, self.max_step_per_ep)
return s

def restore(self):
@@ -248,17 +250,18 @@ def train(self):
f = open(os.path.join(self.working_dir, 'train_history.csv'), mode='w')
writer = csv.writer(f)
starttime = datetime.datetime.now()
best_reward = - np.inf
patience_cnt = 0
for ep in range(self.start_ep, self.max_episode):
try:
# reset env
state = self.reset()
ep_reward = 0
loss_and_metrics_ep = {'knowledge': 0, 'acc': 0, 'loss': 0}
if 'metrics' in self.manager.model_compile_dict:
loss_and_metrics_ep.update({x: 0 for x in self.manager.model_compile_dict['metrics']})

ep_reward = 0
ep_probs = []

for step in range(self.max_step_per_ep):
# value = self.controller.get_value(state)
actions, probs = self.controller.get_action(state) # get an action for the previous state
@@ -329,22 +332,26 @@ def train(self):
self.controller.save_weights(
os.path.join(self.working_dir, "controller_weights.h5"))

# TODO: add early-stopping
# check the entropy record and stop training if no progress was made
# (less than entropy_converge_epsilon)
# if ep >= self.max_episode//3 and \
# np.std(self.entropy_record[-(self.max_step_per_ep):])<self.entropy_converge_epsilon:
# LOGGER.info("Controller converged at episode %i"%ep)
# break
except KeyboardInterrupt:
LOGGER.info("User disrupted training")
break

# add early-stopping
# check the best reward patience and stop training if no progress was made more than early_stop_patience
if ep_reward > best_reward:
best_reward = ep_reward
patience_cnt = 0
else:
patience_cnt += 1
consumed_time = (datetime.datetime.now() - starttime).total_seconds()
LOGGER.info("used time: %.2f %%" % (consumed_time / self.time_budget * 100))
LOGGER.info("Iter %i, this reward: %.3f, best reward: %.3f, used time: %.2f %%" % (ep, ep_reward, best_reward, consumed_time / self.time_budget * 100))

if consumed_time >= self.time_budget:
LOGGER.info("training ceased because run out of time budget")
break
if ep >= self.initial_buffering_queue+self.early_stop_patience and patience_cnt > self.early_stop_patience:
LOGGER.info("Controller search early-stopped at episode %i"%ep)
break

LOGGER.debug("Total Reward : %s" % self.total_reward)

@@ -364,7 +371,7 @@ def train(self):
**save_kwargs)
save_stats(loss_and_metrics_list, self.working_dir, is_resume_run=self.resume_prev_run)

if self.should_plot:
if self.should_plot is True:
plot_action_weights(self.working_dir)
plot_wiring_weights(self.working_dir, self.with_input_blocks, self.with_skip_connection)
plot_stats2(self.working_dir)
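
This commit swaps the old entropy-convergence check for a reward-patience rule: keep the best episode reward seen so far, count episodes without improvement, and stop once that count exceeds early_stop_patience (after the initial buffering episodes). A minimal, self-contained sketch of that bookkeeping, with the controller episode stubbed out as a callable; run_search and get_episode_reward are illustrative names, not AMBER API:

import numpy as np

def run_search(get_episode_reward, max_episode=100,
               early_stop_patience=None, initial_buffering_queue=10):
    # mirrors the diff: None means "never early-stop"
    early_stop_patience = int(early_stop_patience) if early_stop_patience is not None else np.inf
    best_reward, patience_cnt = -np.inf, 0
    for ep in range(max_episode):
        ep_reward = get_episode_reward(ep)  # stand-in for one controller episode
        if ep_reward > best_reward:
            best_reward, patience_cnt = ep_reward, 0
        else:
            patience_cnt += 1
        if ep >= initial_buffering_queue + early_stop_patience and patience_cnt > early_stop_patience:
            print("early-stopped at episode %i" % ep)
            break
    return best_reward

# e.g. a flat reward stream triggers the stop once past the buffering + patience window
run_search(lambda ep: 1.0, early_stop_patience=5)
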
6 changes: 3 additions & 3 deletions amber/architect/optim/controller/pytorch/generalController.py
@@ -113,7 +113,7 @@ class GeneralController(BaseController):
"""

def __init__(self, model_space, buffer_type='ordinal', with_skip_connection=True, share_embedding=None,
def __init__(self, model_space, buffer_type='ordinal', with_skip_connection=False, share_embedding=None,
use_ppo_loss=False, kl_threshold=0.05, skip_connection_unique_connection=False, buffer_size=15,
batch_size=5, session=None, train_pi_iter=20, lstm_size=32, lstm_num_layers=2, lstm_keep_prob=1.0,
tanh_constant=None, temperature=None, optim_algo="adam", skip_target=0.8, skip_weight=None,
@@ -215,7 +215,7 @@ def _build_train_op(self, input_arc, advantage, old_probs):
old_probs = [F.cast(p, F.float32) for p in old_probs]
self._build_trainer(input_arc=input_arc)
normalize = F.cast(self.num_layers * (self.num_layers - 1) / 2, F.float32)
self.skip_rate = F.cast(self.skip_count, F.float32) / normalize
self.skip_rate = F.cast(self.skip_count, F.float32) / normalize if self.with_skip_connection else None
loss = 0
if self.with_skip_connection is True and self.skip_weight is not None:
loss += self.skip_weight * F.reduce_mean(self.onehot_skip_penaltys)
@@ -312,4 +312,4 @@ def save_weights(self, filepath):
# hf.create_dataset(name=self.weights[i].name, data=d)

def load_weights(self, *args):
pass
pass
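
With the default now with_skip_connection=False, the skip statistics above are only computed when skip connections are actually in play; otherwise skip_rate is left as None. A rough stand-alone illustration of that guard; skip_stats is a made-up helper, not part of the controller:

import torch

def skip_stats(skip_count, num_layers, with_skip_connection=False):
    if not with_skip_connection:
        return None  # nothing to normalize when skip connections are disabled
    normalize = num_layers * (num_layers - 1) / 2.0
    return skip_count.float() / normalize

print(skip_stats(torch.tensor(6), num_layers=4, with_skip_connection=True))  # tensor(1.)
print(skip_stats(torch.tensor(6), num_layers=4))                             # None
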
2 changes: 1 addition & 1 deletion amber/architect/optim/controller/tf1/generalController.py
@@ -114,7 +114,7 @@ class GeneralController(BaseController):
"""

def __init__(self, model_space, buffer_type='ordinal', with_skip_connection=True, share_embedding=None,
def __init__(self, model_space, buffer_type='ordinal', with_skip_connection=False, share_embedding=None,
use_ppo_loss=False, kl_threshold=0.05, skip_connection_unique_connection=False, buffer_size=15,
batch_size=5, session=None, train_pi_iter=20, lstm_size=32, lstm_num_layers=2, lstm_keep_prob=1.0,
tanh_constant=None, temperature=None, optim_algo="adam", skip_target=0.8, skip_weight=None,
4 changes: 2 additions & 2 deletions amber/architect/optim/controller/tf2/generalController.py
@@ -114,7 +114,7 @@ class GeneralController(BaseController):
"""

def __init__(self, model_space, buffer_type='ordinal', with_skip_connection=True, share_embedding=None,
def __init__(self, model_space, buffer_type='ordinal', with_skip_connection=False, share_embedding=None,
use_ppo_loss=False, kl_threshold=0.05, skip_connection_unique_connection=False, buffer_size=15,
batch_size=5, session=None, train_pi_iter=20, lstm_size=32, lstm_num_layers=2, lstm_keep_prob=1.0,
tanh_constant=None, temperature=None, optim_algo="adam", skip_target=0.8, skip_weight=None,
@@ -308,4 +308,4 @@ def save_weights(self, filepath):
# hf.create_dataset(name=self.weights[i].name, data=d)

def load_weights(self, *args):
pass
pass
9 changes: 5 additions & 4 deletions amber/backend/pytorch/layer.py
@@ -38,7 +38,7 @@ def forward(self, x):

class GlobalMaxPooling1DLayer(torch.nn.Module):
def forward(self, x):
return torch.max(x, dim=-1)
return torch.max(x, dim=-1).values


def get_layer(x=None, op=None, custom_objects=None, with_bn=False):
@@ -79,7 +79,7 @@ def get_layer(x=None, op=None, custom_objects=None, with_bn=False):
raise ValueError("unknown activation layer: %s" % actv_fn)

elif op.Layer_type == 'dense':
actv_fn = op.Layer_attributes.pop('activation', 'linear')
actv_fn = op.Layer_attributes.get('activation', 'linear')
curr_shape = np.array(x.shape) if isinstance(x, torch.Tensor) else x.out_features
assert len(curr_shape)==1, ValueError("dense layer must have 1-d prev layers")
_list = [torch.nn.Linear(in_features=curr_shape[0], out_features=op.Layer_attributes['units'])]
@@ -89,14 +89,15 @@ def get_layer(x=None, op=None, custom_objects=None, with_bn=False):

elif op.Layer_type == 'conv1d':
assert x is not None
actv_fn = op.Layer_attributes.pop('activation', 'linear')
actv_fn = op.Layer_attributes.get('activation', 'linear')
curr_shape = np.array(x.shape) if isinstance(x, torch.Tensor) else x.out_features
assert len(curr_shape)==2, ValueError("conv1d layer must have 2-d prev layers")
# XXX: pytorch assumes channel_first, unlike keras
_list = [torch.nn.Conv1d(in_channels=curr_shape[-2], out_channels=op.Layer_attributes['filters'],
kernel_size=op.Layer_attributes.get('kernel_size',1),
stride=op.Layer_attributes.get('strides', 1),
padding=op.Layer_attributes.get('padding', 0),
dilation=op.Layer_attributes.get('dilation_rate', 1),
)]
if with_bn: _list.append(get_layer(x=x, op=Operation("batchnorm")))
_list.append(get_layer(op=Operation("activation", activation=actv_fn)))
@@ -105,7 +106,7 @@ def get_layer(x=None, op=None, custom_objects=None, with_bn=False):
# elif op.Layer_type == 'conv2d':
# if with_bn is True:
# assert x is not None
# actv_fn = op.Layer_attributes.pop('activation', 'linear')
# actv_fn = op.Layer_attributes.get('activation', 'linear')
# x = tf.keras.layers.Conv2D(**op.Layer_attributes)(x)
# x = tf.keras.layers.BatchNormalization()(x)
# x = tf.keras.layers.Activation(actv_fn)(x)
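
Two of the get_layer fixes are easy to verify in isolation: torch.max(x, dim=-1) returns a (values, indices) named tuple, so global max pooling needs .values, and reading the activation with .get instead of .pop keeps Layer_attributes intact if the same Operation is used to build more than one layer. A small sanity check, with toy shapes and only standard PyTorch/dict behavior assumed:

import torch

x = torch.randn(2, 8, 100)                    # (batch, channels, length)
pooled = torch.max(x, dim=-1)                 # returns a (values, indices) namedtuple
assert pooled.values.shape == (2, 8)          # .values is the tensor GlobalMaxPooling1D should return

attrs = {'units': 16, 'activation': 'relu'}
actv = attrs.get('activation', 'linear')      # .get leaves the dict intact...
assert 'activation' in attrs
actv2 = attrs.pop('activation', 'linear')     # ...while .pop drops it, so a second call
assert 'activation' not in attrs              # would silently fall back to 'linear'
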
43 changes: 36 additions & 7 deletions amber/backend/pytorch/model.py
@@ -3,11 +3,17 @@
import torchmetrics
import os
import numpy as np
import warnings
from . import cache
from .layer import get_layer
from .utils import InMemoryLogger
from .tensor import TensorType

# disable lightning logging
import logging
logging.getLogger('lightning').setLevel(0)
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)


class Model(pl.LightningModule):
def __init__(self, *args, **kwargs):
@@ -84,9 +90,13 @@ def fit(self, x, y=None, validation_data=None, batch_size=32, epochs=1, nsteps=N
callbacks=callbacks,
enable_progress_bar=verbose,
logger=logger,
enable_model_summary=False,
# deterministic=True,
)
self.trainer.fit(self, train_data, validation_data)
# just too many UserWarnings from lightning
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.trainer.fit(self, train_data, validation_data)
return logger

def predict(self, x, y=None, batch_size=32, verbose=False):
@@ -108,12 +118,28 @@ def evaluate(self, x, y=None, batch_size=32, verbose=False):
accelerator="auto",
max_epochs=1,
enable_progress_bar=verbose,
enable_model_summary=False,
# deterministic=True,
)
res = trainer.test(self, data, verbose=verbose)[0]
res = {k.replace('test', 'val'):v for k,v in res.items()}
return res

def save_weights(self, filepath: str):
torch.save(self.state_dict(), filepath)

def load_weights(self, filepath: str):
self.load_state_dict(torch.load(filepath), strict=False)

def save(self, filepath: str):
self.trainer.save_checkpoint(filepath)

def load(self, filepath: str):
return self.load_from_checkpoint(filepath, strict=False)

def summary(self):
print(str(self))

def configure_optimizers(self):
"""Set up optimizers and schedulers.
Uses Adam, learning rate from `self.lr`, and no scheduler by default.
@@ -228,8 +254,6 @@ def test_epoch_end(self, outputs):
self.log(f"test_{name}", metric.compute(), prog_bar=True)
metric.reset()

def save_weights(self, *args):
pass


class Sequential(Model):
@@ -254,7 +278,7 @@ def get_metric(m):
return m
elif type(m) is str:
if m.lower() == 'kl_div':
return torch.nn.KLDivLoss()
return torch.nn.KLDivLoss(reduction='batchmean')
elif m.lower() in ('acc', 'accuracy'):
return torchmetrics.Accuracy
elif m.lower() in ('f1', 'f1_score'):
@@ -312,16 +336,21 @@ def get_callback(m):
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
return EarlyStopping
elif m == 'ModelCheckpoint':
# see more APIs in https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.ModelCheckpoint.html
from pytorch_lightning.callbacks import ModelCheckpoint
def ModelCheckpoint_(filename, monitor='val_loss', mode='min', save_best_only=True, verbose=False):
return ModelCheckpoint(
filename_prefix = '.'.join(filename.split('.')[:-1])
filename_suffix = filename.split('.')[-1]
model_ckpt = ModelCheckpoint(
dirpath=os.path.dirname(filename),
filename=os.path.basename(filename),
filename=os.path.basename(filename_prefix),
save_top_k=1 if save_best_only else None,
monitor=monitor,
mode=mode,
verbose=verbose
)
)
model_ckpt.FILE_EXTENSION = '.'+filename_suffix
return model_ckpt
return ModelCheckpoint_
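
The checkpoint wrapper now splits a Keras-style filename such as bestmodel.h5 into a stem and an extension, because Lightning's ModelCheckpoint appends its own FILE_EXTENSION (".ckpt" by default) to whatever filename it is given. A hypothetical helper showing just the splitting logic; split_ckpt_name is not part of the codebase:

import os

def split_ckpt_name(filename):
    stem = '.'.join(filename.split('.')[:-1])   # e.g. "outputs/bestmodel"
    ext = '.' + filename.split('.')[-1]         # e.g. ".h5", assigned to FILE_EXTENSION
    return os.path.dirname(filename), os.path.basename(stem), ext

print(split_ckpt_name("outputs/bestmodel.h5"))  # ('outputs', 'bestmodel', '.h5')
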


8 changes: 8 additions & 0 deletions amber/backend/pytorch/utils.py
@@ -101,6 +101,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Log metrics, associating with given `step`."""
metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
self.experiment.log_metrics(metrics, step)

@property
def history(self) -> pd.DataFrame:
df = pd.DataFrame(self.experiment.metrics)
if len(df) > 0:
df.set_index("epoch", inplace=True)
df = df.groupby("epoch").mean()
return df

def pandas(self):
"""Return recorded metrics in a Pandas dataframe.
1 change: 0 additions & 1 deletion amber/modeler/sequential/pytorch_sequential.py
@@ -37,7 +37,6 @@ def build(self, model_states):
# permute x if not a vector to match channel_last data format
if len(curr_shape) > 1:
dims = [0, len(curr_shape)] + np.arange(1, len(curr_shape)).tolist()
print(dims)
layer = F.get_layer(op=F.Operation('permute', dims=dims))
model.add(layer)
x = layer(x)
4 changes: 2 additions & 2 deletions amber/plots/plotsV1.py
@@ -222,14 +222,14 @@ def plot_stats2(working_dir):
d = np.stack(list(map(lambda x: sma(x), np.array(data))), axis=0)
avg = np.apply_along_axis(np.mean, 0, d)
ax = sns.lineplot(x=np.arange(1, len(avg) + 1), y=avg,
color='b', label='Loss', legend=False)
color='b', label='Loss/Reward', legend=False)
if d.shape[0] >= 6:
std = np.apply_along_axis(np.std, 0, d) / np.sqrt(d.shape[0])
min_, max_ = avg - 1.96 * std, avg + 1.96 * std
ax.fill_between(range(avg.shape[0]), min_, max_, alpha=0.2)

data = df['Knowledge']
if np.array(data).shape[1] > 0: # if have data
if np.array(data).shape[1] > 0 and np.std(data) > 0: # if have data
ax2 = ax.twinx()
d = np.stack(list(map(lambda x: sma(x), np.array(data))), axis=0)
avg = np.apply_along_axis(np.mean, 0, d)
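
For context on the plotting change: each run's reward curve is smoothed, the across-run mean is drawn (now labeled "Loss/Reward"), and when enough runs exist a 1.96 x SEM band is filled around it; the secondary Knowledge axis is only added when that column carries non-constant data. A rough sketch of the smoothing-and-band computation; sma here is a simple moving-average stand-in, and AMBER's own sma may differ:

import numpy as np

def sma(x, window=5):
    return np.convolve(x, np.ones(window) / window, mode='valid')

runs = np.random.rand(8, 50)                 # 8 runs x 50 episodes of toy rewards
d = np.stack([sma(r) for r in runs], axis=0)
avg = d.mean(axis=0)
sem = d.std(axis=0) / np.sqrt(d.shape[0])
lo, hi = avg - 1.96 * sem, avg + 1.96 * sem  # drawn with ax.fill_between as the band
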
2 changes: 1 addition & 1 deletion amber/utils/__init__.py
@@ -4,7 +4,7 @@
try:
import tensorflow as static_tf
if static_tf.__version__.startswith("2"):
print("detected tf2 - using compatibility mode")
#print("detected tf2 - using compatibility mode")
#static_tf.compat.v1.disable_eager_execution()
import tensorflow.compat.v1 as static_tf
except ImportError: