Skip to content

Commit

Permalink
fix: vit network did not conform to expected state dict
Browse files Browse the repository at this point in the history
  • Loading branch information
tbung committed Aug 9, 2023
1 parent 7545b32 commit 3500f91
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions fd_shifts/models/networks/vit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import timm
import torch
import torch.nn as nn
Expand All @@ -10,8 +12,16 @@ class ViT(Network):
def __init__(self, cf: configs.Config):
super().__init__()

self.encoder = Encoder(cf)
self.classifier = Classifier(self.encoder.model.head)
self._encoder = Encoder(cf)
self._classifier = Classifier(self.encoder.model.head)

@property
def encoder(self) -> Encoder:
return self._encoder

@property
def classifier(self) -> Classifier:
return self._classifier

def forward(self, x):
out = self.encoder(x)
Expand Down

0 comments on commit 3500f91

Please sign in to comment.