diff --git a/src/lib_model.py b/src/lib_model.py index bb38e4c..cd246a9 100644 --- a/src/lib_model.py +++ b/src/lib_model.py @@ -129,15 +129,15 @@ def build_classifier(config, num_classes, df_size, img_size): def find_unfreeze_points(model, mname, blocks_to_unfreeze): block_starts = [] print("\nModel Name =", mname) - if 'cn' in mname.lower(): + if 'cn' in mname[:2].lower(): for layer in model.layers: if 'stages' in layer.name and layer.name.endswith('_downsample_1_conv2d'): block_starts.append(layer.name) - elif 'en' in mname.lower(): + elif 'en' in mname[:2].lower(): for layer in model.layers: if 'block' in layer.name and layer.name.endswith('_0_conv_pw_conv2d'): block_starts.append(layer.name) - elif 'vit' in mname.lower(): + elif 'vt' in mname[:2].lower(): for layer in model.layers: if 'blocks' in layer.name and layer.name.endswith('_attn'): block_starts.append(layer.name)