Synthesizing Deep Net Model Metamers#

plenoptic is compatible with any model written in pytorch, including deep neural networks from the model zoos timm and torchvision. In this exercise, we’ll grab ResNet50 from torchvision and show how to generate metamers for several of its intermediate representations, as done in Feather et al. 2023.

Attention

It is recommended that you first work through the Minimal metamer synthesis example exercise before this one! The optimization procedure here is a bit more complex and takes longer.

# needed for the plotting/animating:
import matplotlib.pyplot as plt
import numpy as np
import plenoptic as po
import torch
import torchvision

plt.rcParams["animation.html"] = "html5"
# use single-threaded ffmpeg for animation writer
plt.rcParams["animation.writer"] = "ffmpeg"
plt.rcParams["animation.ffmpeg_args"] = ["-threads", "1"]
# so that relative sizes of axes created by po.plot.imshow and others look right
plt.rcParams["figure.dpi"] = 72
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

When synthesizing model metamers for convolutional neural networks, researchers often pick a specific layer whose output they want to match. If we look at Feather et al. 2023 Figure 2e, we can see an interesting progression in layers 2 through 4: the layer 2 metamer looks almost identical to the target image, the layer 3 metamer starts to add RGB noise, and the layer 4 metamer is almost completely unidentifiable, resembling random RGB noise. We’ll pick layer 3 for now, and you’re encouraged to try the other layers!

Warning

The contents of this exercise use features from plenoptic that have not yet been released. The names of the object and its methods may change between now and the release in version 2.1.0, later this July.

If you followed the setup instructions, you will have no problems, but if you install plenoptic directly with pip, you will get an AttributeError.

Use a model from torchvision#

First, let’s download the model weights for ResNet50 trained on ImageNet-1K and initialize the torchvision model.

weights = torchvision.models.ResNet50_Weights.IMAGENET1K_V1
deepnet = torchvision.models.resnet50(weights=weights)

Next, we ensure that our model is in evaluation mode. Many models, including ResNet50, behave differently when in training and evaluation mode. In plenoptic, models are fixed and so we want the evaluation behavior:

deepnet.eval();
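
To make the difference between the two modes concrete, here is a minimal sketch using a single batch normalization layer (the kind of layer that makes ResNet50 mode-dependent). The toy layer and the random input are assumptions made for illustration; this is not ResNet50 itself.

```python
import torch

torch.manual_seed(0)
# Toy stand-in: one batch-norm layer and a random batch (not ResNet50).
bn = torch.nn.BatchNorm2d(3)
x = torch.randn(4, 3, 8, 8) * 5 + 2

bn.train()
train_out = bn(x)    # normalized with this batch's statistics; running stats updated

bn.eval()
eval_out_1 = bn(x)   # normalized with the stored running statistics
eval_out_2 = bn(x)   # identical: evaluation mode does not update any state

print(torch.equal(eval_out_1, eval_out_2))
print(torch.allclose(train_out, eval_out_1))
```

In evaluation mode the layer is a fixed function of its input, which is what metamer synthesis needs.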

Next, we grab the preprocessing transform from the model. As the torchvision docs explain it (quoting version 0.27):

Before using the pre-trained models, one must preprocess the image (resize with right resolution/interpolation, apply inference transforms, rescale the values etc). There is no standard way to do this as it depends on how a given model was trained. It can vary across model families, variants or even weight versions. Using the correct preprocessing method is critical and failing to do so may lead to decreased accuracy or incorrect outputs.

For models trained on ImageNet, this preprocessing consists of two steps: resizing to a height and width of 224 pixels and normalizing the color channels (subtracting means and dividing by standard deviations). Following Feather et al. 2023 we recommend including the normalization step in the model for metamer synthesis, but handling the image resizing externally. We demonstrate how to do so below.

transform = weights.transforms()
print(transform)
norm = torchvision.transforms.Normalize(transform.mean, transform.std)
ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)
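
As a small worked example of the normalization step, the sketch below applies the per-channel formula (image - mean) / std by hand, using the means and standard deviations printed above. The helper name `normalize` is hypothetical and only meant to show the arithmetic; in the exercise itself we use torchvision's Normalize.

```python
import torch

# ImageNet channel statistics, as printed by the transform above.
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def normalize(image):
    """Per-channel (image - mean) / std for a (batch, 3, height, width) tensor."""
    return (image - mean) / std


# An image whose pixels equal the channel means maps to all zeros.
mean_image = mean.expand(1, 3, 4, 4)
normalized = normalize(mean_image)

# A white image (all ones) maps to (1 - mean) / std in each channel.
white = normalize(torch.ones(1, 3, 4, 4))
print(white[0, :, 0, 0])
```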

Finally, we’ll pass our neural network, target layer, and preprocessing transform to plenoptic’s FeatureExtractorModel, moving it to our specified device and removing the gradient from all learnable parameters (as models in plenoptic are fixed):

model = po.models.FeatureExtractorModel(deepnet, "layer3", norm)
model.to(DEVICE)
po.remove_grad(model)
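
Conceptually, matching a target layer means reading out an intermediate activation rather than the network's final output. The sketch below shows one common way to do this in pytorch, with a forward hook on a toy network. The toy network and its layer names are hypothetical, and this is not necessarily how FeatureExtractorModel is implemented.

```python
import torch

# Toy network standing in for ResNet50; the layer names are hypothetical.
toy_net = torch.nn.Sequential()
toy_net.add_module("layer1", torch.nn.Conv2d(3, 8, kernel_size=3, padding=1))
toy_net.add_module("layer2", torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1))
toy_net.add_module("head", torch.nn.Flatten())

captured = {}


def save_output(module, inputs, output):
    # Called by pytorch right after the hooked module's forward pass.
    captured["layer2"] = output


target = dict(toy_net.named_modules())["layer2"]
handle = target.register_forward_hook(save_output)
with torch.no_grad():
    toy_net(torch.rand(1, 3, 32, 32))
handle.remove()  # detach the hook once we have what we need

features = captured["layer2"]
print(features.shape)
```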

Preparing the image#

Now, let’s prepare the image. The input image needs to be an RGB image with a height and width of 224 pixels. Ideally, it should also resemble the images found in ImageNet: a single object in the center of the frame that belongs to one of the image classes. We’ll use one of the famous monkey selfies, and resize it appropriately:

img = po.data.macaque().to(DEVICE)
# here we blur and downsample the original image twice (2 scales, so a factor of 4
# overall) and then lop off the bottom. that way, when we take the central 224
# pixels in the following block, we end up with a decent image.
img = po.process.blur_downsample(img, 2)[..., :-59, :]

As discussed above, models trained on ImageNet should be passed an image of size 224 by 224. We’ll use plenoptic.process.center_crop() to do so, grabbing the required size directly from the model’s associated transform:

img = po.process.center_crop(img, transform.crop_size[0])
po.plot.imshow(img, as_rgb=True);
[Figure: the cropped 224 by 224 macaque image]
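
The arithmetic behind a center crop is simple enough to sketch by hand. The function `center_crop_sketch` below is a hypothetical illustration on a dummy tensor; plenoptic's own function may treat edge cases (such as odd remainders) differently.

```python
import torch


def center_crop_sketch(image, size):
    """Keep the central size x size window of a (..., height, width) tensor."""
    height, width = image.shape[-2:]
    top = (height - size) // 2
    left = (width - size) // 2
    return image[..., top:top + size, left:left + size]


# Dummy 256 x 300 "image" whose values encode their own position.
image = torch.arange(256 * 300, dtype=torch.float32).reshape(1, 1, 256, 300)
cropped = center_crop_sketch(image, 224)
print(cropped.shape)
```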

Synthesize model metamers!#

Now we can use the above like any other model we’ve used so far, though note that we need to tweak some of the optimization hyperparameters. Here, like in the Feather et al. paper, we find better results if we gradually decrease the learning rate over synthesis (using StepLR to halve the learning rate every 3000 steps). We also use the standard version of Adam, rather than the AMSGrad variant (which is the default for plenoptic).

met = po.Metamer(img, model)
scheduler = torch.optim.lr_scheduler.StepLR
scheduler_kwargs = {"step_size": 3000, "gamma": 0.5}
met.setup(
    optimizer_kwargs={"amsgrad": False},
    scheduler=scheduler,
    scheduler_kwargs=scheduler_kwargs
)
# by setting stop_iters_to_check=max_iter, we ensure it keeps going through
# all iterations
met.synthesize(max_iter=6000, stop_iters_to_check=6000, store_progress=120)
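
To see what this schedule does to the learning rate, the sketch below steps a StepLR scheduler with the same step_size and gamma through 6000 iterations on a dummy parameter. The starting learning rate of 0.01 is an assumption made for illustration, and no actual optimization takes place.

```python
import torch

# Dummy parameter: we only want to watch the learning rate, not optimize anything.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.01, amsgrad=False)  # lr is an assumed value
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3000, gamma=0.5)

learning_rates = []
for _ in range(6000):
    learning_rates.append(scheduler.get_last_lr()[0])
    optimizer.step()   # no gradients here, so the parameter is left unchanged
    scheduler.step()

print(learning_rates[0], learning_rates[2999], learning_rates[3000])
```

With these settings, the first 3000 iterations use the initial learning rate and the remaining 3000 use half of it.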

And look at the output:

po.plot.synthesis_status(met, figsize=(15, 4.5));
[Figure: synthesis status, showing the metamer (left), the loss over synthesis iterations (middle), and the representation error (right)]

In the above plots, we can see the metamer in the leftmost subplot, the loss over synthesis iterations in the middle, and the representation error on the right:

  • Our metamer matches the results discussed earlier in this notebook: as a layer 3 metamer, it looks like the original image with some RGB noise added.

  • We can see that the optimization performed reasonably well: the loss decreased gradually over synthesis. If you were using these stimuli in an experiment, it may be worth continuing a bit more to get the loss even lower, but these demonstrate the point.

  • The representation error plot has two subplots. The first averages across channels to show the error as a function of space, while the second averages across space to show the error for each channel. We can see that, while there’s some variation across both channels and space, there’s no obvious outlier whose error we have been unable to constrain.
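
The two averages in the representation error plot can be sketched as follows. The error tensor here is random noise standing in for the difference between the metamer's and the target's representations, and the exact computation inside plenoptic's plotting code may differ.

```python
import torch

torch.manual_seed(0)
# Stand-in for (metamer representation - target representation); random here.
# Shape is (batch, channels, height, width); 1024 x 14 x 14 is what ResNet50's
# layer3 returns for a 224 x 224 input.
rep_error = torch.randn(1, 1024, 14, 14)

spatial_map = rep_error.mean(dim=1)          # average across channels -> (1, 14, 14)
per_channel = rep_error.mean(dim=(-2, -1))   # average across space    -> (1, 1024)

print(spatial_map.shape, per_channel.shape)
```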

As before, we can animate to see this process over time:

po.plot.synthesis_animate(met, figsize=(15, 4.5))

Understand the output#

The authors of Feather et al., 2023 used two additional checks to verify that metamer synthesis had succeeded (quotes from “Results > Metamer optimization” section, pdf page 5):

  • “the metamer had to result in the same classification decision by the model as the reference stimulus”

  • “measures of the match between the activations for the natural reference stimulus and its model metamer at the matched stage had to be much higher than would be expected by chance, as quantified with a null distribution”. The authors used three measures here: Pearson and Spearman correlations and signal-to-noise ratio. We could compute those measures, but without the null distribution, they’re difficult to interpret. So we just note that one should do something similar in order to verify synthesis has succeeded.
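
For reference, here is a minimal numpy sketch of the three measures and of comparing one of them against a null distribution. Everything in it is an assumption made for illustration: the activations are random vectors, the null pairs are independent random vectors rather than the pairs used in the paper, and the signal-to-noise ratio is taken to be 10 log10 of signal power over error power.

```python
import numpy as np


def pearson(x, y):
    x = x - x.mean()
    y = y - y.mean()
    return float((x * y).sum() / np.sqrt((x ** 2).sum() * (y ** 2).sum()))


def spearman(x, y):
    # Pearson correlation of the ranks (argsort twice gives ranks when there are no ties).
    def rank(v):
        return np.argsort(np.argsort(v)).astype(float)
    return pearson(rank(x), rank(y))


def snr_db(reference, estimate):
    # Assumed definition: signal power over error power, in decibels.
    noise = reference - estimate
    return float(10 * np.log10((reference ** 2).sum() / (noise ** 2).sum()))


rng = np.random.default_rng(0)
reference = rng.standard_normal(1000)                       # stand-in activations
metamer_like = reference + 0.1 * rng.standard_normal(1000)  # a close match

# Null distribution: the same measure computed on unrelated pairs.
null_pearson = np.array(
    [pearson(reference, rng.standard_normal(1000)) for _ in range(200)]
)
match_pearson = pearson(reference, metamer_like)
print(match_pearson > null_pearson.max())
```

A successful metamer should sit far outside the null distribution on each measure, as the simulated match does here.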

The following cell shows how to compute the image categories:

imagenet_categories = np.asarray(weights.meta["categories"])
def get_category(image):
    image_cat = po.to_numpy(
        torch.nn.functional.softmax(deepnet(norm(image)), dim=1).squeeze()
    )
    return imagenet_categories[image_cat.argmax()]

print(f"Original image category: {get_category(met.image)}")
print(f"Model metamer category: {get_category(met.metamer)}")
Original image category: guenon
Model metamer category: guenon

Different layer#

Look at figure 2e in Feather et al. 2023 and pick another layer to target. The hyperparameters we picked should work reasonably well for layers 2 and 4, but others have not been tested. Look at the output of synthesis_status() and tweak the hyperparameters as necessary to get the loss as low as possible!

target_layer = # WRITE SOMETHING HERE
model = po.models.FeatureExtractorModel(deepnet, target_layer, norm)
model.to(DEVICE)

met = po.Metamer(img, model)
met.setup(
    optimizer_kwargs={"amsgrad": False},
    scheduler=scheduler,
    scheduler_kwargs=scheduler_kwargs
)
# by setting stop_iters_to_check=max_iter, we ensure it keeps going through
# all iterations
met.synthesize(max_iter=6000, stop_iters_to_check=6000, store_progress=120)
po.plot.synthesis_status(met, figsize=(15, 4.5));

You can also specify multiple layers (as a list of strings, e.g., ["layer2", "layer3"]) to match multiple layers at once!

Different target image#

Try using a different target image than the macaque above and running metamer synthesis until completion:

Loading other images

Try one of the other included images or use plenoptic.load_images() to load one from disk.

img = # WRITE SOMETHING NEW HERE
img = img.to(DEVICE)
met = po.Metamer(img, model)
met.setup(
    optimizer_kwargs={"amsgrad": False},
    scheduler=scheduler,
    scheduler_kwargs=scheduler_kwargs
)
# by setting stop_iters_to_check=max_iter, we ensure it keeps going through
# all iterations
met.synthesize(max_iter=6000, stop_iters_to_check=6000, store_progress=120)
po.plot.synthesis_status(met, figsize=(15, 4.5));

And maybe animate to see what synthesis looks like?

po.plot.synthesis_animate(met, figsize=(15, 4.5))

Different initial image#

While the original paper initialized from a patch of white noise, it can be interesting to start from a different image as well. Using one of the same tools as above for loading another image, initialize metamer synthesis from another starting point and run it to completion:

met = po.Metamer(img, model)
met.setup(
    initial_image=, # WRITE SOMETHING HERE
    optimizer_kwargs={"amsgrad": False},
    scheduler=scheduler,
    scheduler_kwargs=scheduler_kwargs
)
met.synthesize(max_iter=6000, stop_iters_to_check=6000, store_progress=120)
po.plot.synthesis_status(met, figsize=(15, 4.5));

And maybe animate to see what synthesis looks like?

po.plot.synthesis_animate(met, figsize=(15, 4.5))

Different DeepNet#

This setup works for any torchvision model! Pick another model with pre-trained weights and a layer to target, and synthesize some model metamers!

weights = torchvision.models. # WRITE SOMETHING HERE!
deepnet = torchvision.models. # WRITE SOMETHING HERE!
deepnet.eval()

Don’t forget to grab its transform! How exactly this looks will depend on the model, but remember that we recommend any image-resizing transforms be handled externally and everything else be part of the metamer model.

# This should work for any ImageNet-trained model, but you'll have to do something else
# for other models
transform = weights.transforms()
print(transform)
norm = torchvision.transforms.Normalize(transform.mean, transform.std)

Now, specify the intermediate layer you want to match and initialize the plenoptic model!

target_layer = # WRITE SOMETHING HERE
model = po.models.FeatureExtractorModel(deepnet, target_layer, norm)
model.to(DEVICE)
# the new network's parameters still require gradients, so remove them
po.remove_grad(model)
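
If you are unsure which layer names are available, pytorch modules can list their own submodules. The sketch below does this for a toy network, which is a hypothetical stand-in; the same two calls work on any torch.nn.Module, including torchvision models.

```python
import torch


class ToyNet(torch.nn.Module):
    """Hypothetical stand-in for a torchvision model."""

    def __init__(self):
        super().__init__()
        self.stem = torch.nn.Conv2d(3, 4, kernel_size=3)
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(4, 8, kernel_size=3), torch.nn.ReLU()
        )
        self.head = torch.nn.Linear(8, 2)

    def forward(self, x):
        x = self.block(self.stem(x))
        return self.head(x.mean(dim=(-2, -1)))


toy = ToyNet()
# Top-level layers only:
top_level = [name for name, _ in toy.named_children()]
# Every nested submodule, with dotted names (skipping the root, whose name is ""):
all_names = [name for name, _ in toy.named_modules() if name]
print(top_level)
print(all_names)
```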

And finally, instantiate the metamer and run synthesis. Note that the arguments to setup() will almost certainly need to be changed, but we’ve repeated the ones used for ResNet50 as a starting point. You may also need to change the loss function: see Texture synthesis for an example of doing so, and the Metamer documentation for more details.

met = po.Metamer(img, model)
scheduler = torch.optim.lr_scheduler.StepLR
scheduler_kwargs = {"step_size": 3000, "gamma": 0.5}
met.setup(
    optimizer_kwargs={"amsgrad": False},
    scheduler=scheduler,
    scheduler_kwargs=scheduler_kwargs
)
met.synthesize(max_iter=6000, stop_iters_to_check=6000, store_progress=120)
po.plot.synthesis_status(met, figsize=(15, 4.5));