Brain-based image filtering
MedARC 3/29/2024
What is “image filtering”?
It is hard to overstate the importance of good data in training good image models (including generative models like Stable Diffusion and contrastive models like CLIP)
There is a vast literature on dataset filtering: how to take a massive dataset and remove bad samples (e.g., poor-quality or duplicated samples)
A model trained on fewer, higher-quality samples will often outperform a model trained on more samples of more variable quality
The aim of this project is to investigate a novel approach to image filtering that makes use of fMRI recorded while humans viewed images in the MRI scanner
How do ML researchers typically do image filtering?
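For context, a typical web-scale recipe combines near-duplicate removal with a CLIP image–text similarity threshold, roughly in the spirit of LAION-style dataset curation. The sketch below is a hedged illustration, not taken from these notes: the function name `heuristic_filter`, the `clip_score` callable, and the 0.28 threshold are all assumptions for the example.

```python
# Illustrative heuristic filter: drop exact duplicates and image/caption pairs
# whose CLIP similarity falls below a threshold. `clip_score` is assumed to be
# a user-supplied callable returning the CLIP image-text cosine similarity.
import hashlib

def heuristic_filter(samples, clip_score, sim_threshold=0.28):
    """samples: iterable of (image_bytes, caption) pairs."""
    seen, kept = set(), []
    for image_bytes, caption in samples:
        digest = hashlib.md5(image_bytes).hexdigest()
        if digest in seen:                                    # remove duplicated samples
            continue
        seen.add(digest)
        if clip_score(image_bytes, caption) < sim_threshold:  # remove low-quality pairs
            continue
        kept.append((image_bytes, caption))
    return kept
```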
Why would the brain be at all relevant?
How would this actually work?
Imagine we are training a CLIP model on a full, unfiltered dataset. We also have a separate train/val set of image/brain paired samples (from Allen et al., 2021).
We train the CLIP model in batches of images. After every batch, we freeze the CLIP model and fit a simple linear-regression encoding model that maps image → CLIP latent → brain. The error of this brain prediction on the held-out val set serves as a separate “fMRI loss” metric. We can then look back at the gradients to determine which images in the current batch were most responsible for improving or worsening the fMRI loss, remove the k most harmful images from the entire unfiltered dataset, and continue training until we have pruned the dataset down to a target number of samples.
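To make the loop above more concrete, here is a minimal sketch of the per-batch scoring step, assuming a PyTorch CLIP-style model with encode_image/encode_text methods. The helper names (fit_encoding_model, fmri_loss, score_batch), the closed-form ridge fit standing in for "simple linear regression", the cosine-distance stand-in for the contrastive loss, and holding the encoding weights fixed when differentiating the fMRI loss are all illustrative choices, not taken from the project repo.

```python
# Sketch of brain-based scoring for one training batch (assumptions noted above).
import torch
import torch.nn.functional as F


@torch.no_grad()
def fit_encoding_model(clip_model, brain_images, brain_voxels, lam=1.0):
    """Ridge regression from frozen CLIP image latents to fMRI voxel responses."""
    z = clip_model.encode_image(brain_images).float()
    z = torch.cat([z, torch.ones(len(z), 1, device=z.device)], dim=1)  # add bias column
    d = z.shape[1]
    # Closed-form ridge solution: W = (Z^T Z + lam*I)^(-1) Z^T Y
    return torch.linalg.solve(z.T @ z + lam * torch.eye(d, device=z.device),
                              z.T @ brain_voxels)


def fmri_loss(clip_model, w, val_images, val_voxels):
    """Held-out prediction error of the encoding model -- the "fMRI loss"."""
    z = clip_model.encode_image(val_images).float()
    z = torch.cat([z, torch.ones(len(z), 1, device=z.device)], dim=1)
    return F.mse_loss(z @ w, val_voxels)


def score_batch(clip_model, images, texts, val_images, val_voxels, w):
    """Score each image in the current batch by whether a gradient step on it
    would be expected to improve or worsen the fMRI loss (gradient dot product)."""
    params = [p for p in clip_model.parameters() if p.requires_grad]

    def grads(loss):
        g = torch.autograd.grad(loss, params, allow_unused=True)
        return [gi if gi is not None else torch.zeros_like(p) for gi, p in zip(g, params)]

    # Gradient of the fMRI loss w.r.t. the CLIP weights. The encoding weights w
    # are held fixed here -- a simplification of re-fitting them every batch.
    g_fmri = grads(fmri_loss(clip_model, w, val_images, val_voxels))

    scores = []
    for img, txt in zip(images, texts):
        # Per-sample gradient of the CLIP training loss (a cosine-distance
        # stand-in for the per-image contribution to the contrastive loss).
        zi = F.normalize(clip_model.encode_image(img[None]), dim=-1)
        zt = F.normalize(clip_model.encode_text(txt[None]), dim=-1)
        g_i = grads(1.0 - (zi * zt).sum())
        # A negative dot product means a step on this image is expected to
        # increase the fMRI loss; flip the sign so high score = prune candidate.
        dot = sum((a * b).sum() for a, b in zip(g_i, g_fmri))
        scores.append((-dot).item())
    return scores

# After scoring the batch, the k highest-scoring (most harmful) images would be
# removed from the unfiltered dataset, and training continues on what remains.
```

The ridge fit is chosen because it has a cheap closed-form solution, which matters if the encoding model really is refit after every batch; in practice the per-sample gradients would also be computed in a vectorized way rather than with a Python loop.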
Practicalities
Collaborating over Discord and GitHub (our repo: https://github.com/MedARC-AI/brain-image-filtering)