Computing Task-Aware / Task-Agnostic Matrices + terminology clarification #889
-
I'm having trouble understanding the evaluation protocol. Given SplitCIFAR100, I want to train a model on experience 0 and test it on all the other experiences, then train on experience 1 (updating the corresponding head) and compute the accuracy on every other experience, and so on up to the last experience. I want to update the correct head at training time, but at evaluation time I need two evaluation methods:
I think I'm not understanding the difference between stream, task, and experience. The end goal is to create a task-aware accuracy matrix and a task-agnostic accuracy matrix, where the diagonal holds the performance of the model on the current experience, the lower triangle the performance on old tasks, and the upper triangle the performance on (never seen) future tasks. Something like this:
So far I have this code, but I struggle to understand the evaluation plugin:

```python
import torch
import timm
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
import torchvision

import avalanche
from avalanche.models import MultiHeadClassifier
from avalanche.training.strategies import Naive
from avalanche.benchmarks.classic import SplitCIFAR100
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.evaluation import metrics


class MTResnet18(avalanche.models.MultiTaskModule):
    def __init__(self, pretrained=False):
        super().__init__()
        self.resnet = timm.create_model('resnet18', pretrained=pretrained, num_classes=0)
        self.classifier = MultiHeadClassifier(512)

    def forward(self, x, task_labels):
        out = self.resnet(x)
        out = out.view(out.size(0), -1)
        return self.classifier(out, task_labels)


SEED = 0
N_EXPERIENCES = 10
PRETRAINED = False
EPOCHS = 2
MINI_BATCH = 128
DEVICE = 0

device = torch.device(
    f"cuda:{DEVICE}" if torch.cuda.is_available() and DEVICE >= 0 else "cpu"
)

# Scenario
scenario = SplitCIFAR100(n_experiences=N_EXPERIENCES, seed=SEED, return_task_id=True)

# Model
model = MTResnet18(pretrained=PRETRAINED)

# Metrics and Logging
eval_plugin = EvaluationPlugin(
    metrics.accuracy_metrics(epoch=False, experience=True, stream=False),
    loggers=[InteractiveLogger()], benchmark=scenario)

# Strategy
optimizer = SGD(model.parameters(), lr=0.01)
criterion = CrossEntropyLoss()
strategy = Naive(
    model=model, optimizer=optimizer, criterion=criterion,
    train_mb_size=MINI_BATCH, train_epochs=EPOCHS, eval_mb_size=MINI_BATCH, device=device,
    evaluator=eval_plugin)

results = []
for exp in scenario.train_stream:
    print(f"Start of experience: {exp.current_experience}\nCurrent Classes: {exp.classes_in_this_experience}")
    # Adds Head
    model.adaptation(exp.dataset)
    # Train
    strategy.train(exp)
    # Test
    results.append(strategy.eval(scenario.test_stream))

print(results)
```
Replies: 2 comments 1 reply
-
Hi, thanks for reaching out! As for the concepts related to streams, experiences, and tasks, does this short introduction help you? Otherwise, let us know if you have specific doubts about the terminology.
In your code, you are testing evaluation setup 1 (task-aware), so the performance will be reported for a multi-headed model. The evaluation plugin simply computes the accuracy on each test experience and reports the result back to you through InteractiveLogger. Each experience has an increasing task label associated with it, in addition to the experience ID (in this case, the two will be the same).
If you want to test evaluation setup 2 (task-agnostic), you can set … Note also that you don't have to manually call …
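To turn the per-experience accuracies into the matrix described in the question, something along these lines may help. It is only a sketch in plain Python, not part of Avalanche: `accuracy_matrix` and `region_means` are hypothetical helper names, and the code assumes the metric keys start with `Top1_Acc_Exp` and end with `Exp<NNN>`. Please check the keys your Avalanche version actually prints before relying on it.

```python
import re
from typing import Dict, List, Optional

# ASSUMPTION: per-experience accuracy keys start with "Top1_Acc_Exp" and end
# with "Exp<NNN>". Verify this against the keys in your own `results` list.
_EXP_ID = re.compile(r"Exp(\d+)$")


def accuracy_matrix(results: List[Dict[str, float]],
                    prefix: str = "Top1_Acc_Exp") -> List[List[Optional[float]]]:
    """Row i = after training on experience i; column j = test experience j."""
    rows: List[Dict[int, float]] = []
    n_cols = 0
    for res in results:
        row: Dict[int, float] = {}
        for key, value in res.items():
            match = _EXP_ID.search(key)
            if key.startswith(prefix) and match:
                row[int(match.group(1))] = value
        rows.append(row)
        if row:
            n_cols = max(n_cols, max(row) + 1)
    # Missing entries are kept as None instead of being guessed.
    return [[row.get(j) for j in range(n_cols)] for row in rows]


def region_means(matrix: List[List[Optional[float]]]) -> Dict[str, Optional[float]]:
    """Mean accuracy on the diagonal (current), lower triangle (old tasks)
    and upper triangle (future tasks)."""
    buckets: Dict[str, List[float]] = {"current": [], "old": [], "future": []}
    for i, row in enumerate(matrix):
        for j, acc in enumerate(row):
            if acc is None:
                continue
            name = "current" if i == j else ("old" if j < i else "future")
            buckets[name].append(acc)
    return {k: (sum(v) / len(v) if v else None) for k, v in buckets.items()}
```

With the training loop from the question, `accuracy_matrix(results)` would be called once after the loop. Running the whole script twice, once per evaluation setup, would give one matrix for each.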
-
First of all, let's clarify the definitions:
Unfortunately, the MultiHeadClassifier does not support your use case, but you can easily modify it to do so. Consider that in Avalanche you may have different tasks that reuse t…
Now, an example of what your multi-head may look like:
You don't need to call …
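The example from this reply is not preserved here. As a rough illustration of how the two evaluation setups differ at prediction time with a multi-head model, here is a minimal sketch in plain Python. It does not use Avalanche or PyTorch, and every name in it (`HeadOutputs`, `predict_task_aware`, `predict_task_agnostic`) is hypothetical. It also assumes one common reading of task-agnostic evaluation, namely pooling the outputs of all heads and taking the overall argmax. The original reply may have intended something different.

```python
from typing import Dict, List, Tuple

# Hypothetical container: for ONE input sample, each task id maps to
# (global class ids covered by that head, logits produced by that head).
HeadOutputs = Dict[int, Tuple[List[int], List[float]]]


def _argmax(values: List[float]) -> int:
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def predict_task_aware(heads: HeadOutputs, task_id: int) -> int:
    """Task label is known at test time: only that task's head is consulted."""
    classes, logits = heads[task_id]
    return classes[_argmax(logits)]


def predict_task_agnostic(heads: HeadOutputs) -> int:
    """Task label is unknown: pool the logits of every head, take the argmax."""
    all_classes: List[int] = []
    all_logits: List[float] = []
    for task_id in sorted(heads):
        classes, logits = heads[task_id]
        all_classes.extend(classes)
        all_logits.extend(logits)
    return all_classes[_argmax(all_logits)]


# Toy sample: head 0 covers classes 0-1, head 1 covers classes 2-3.
sample = {
    0: ([0, 1], [0.2, 1.5]),
    1: ([2, 3], [3.0, -1.0]),
}
print(predict_task_aware(sample, task_id=0))  # restricted to head 0
print(predict_task_agnostic(sample))          # free to pick any head
```

In the toy sample the two rules disagree: the task-aware prediction stays inside head 0, while the task-agnostic one picks the class with the largest logit across both heads. This is why the two accuracy matrices can differ even for the same trained model.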