MrZoiya/MIAM

MIAM: Modality Imbalance-Aware Masking 😋

Code for MIAM: Modality Imbalance-Aware Masking for Multimodal Ecological Applications (ICLR 2026 🇧🇷).

👥 Authors: Robin Zbinden*, Wesley Monteith-Finas*, Gencer Sumbul, Nina van Tiel, Chiara Vanalli, Devis Tuia
🌐 Website: https://zbirobin.github.io/publications/miam/
📄 Paper: https://openreview.net/forum?id=oljjAkgZN4
🤗 Model weights: https://huggingface.co/zbirobin/MIAM

Overview 🌍

Ecological modeling is central to conservation, climate adaptation, and environmental management. Modern ecological datasets are inherently multimodal, combining heterogeneous signals such as environmental tabular variables 📊, climate time series 📈, bioacoustics 🔊, natural images 📷, and satellite imagery 🛰️.

Learning and inference in this setting are challenging because:

  • inputs are often incomplete, with missing data occurring either at the modality level or within modalities
  • modalities contribute unequally, leading to modality imbalance during optimization
  • interpretability is essential, both across and within modalities

MIAM addresses these challenges with a dynamic, score-driven masking strategy that adapts during training to modality dominance and learning dynamics. A model trained with MIAM can then be evaluated under any combination of input tokens, supporting fine-grained contribution analysis across and within modalities, while improving robustness to missing data in ecological applications.

Check out the project webpage for additional details and interactive visualizations: https://zbirobin.github.io/publications/miam/.

Figure: MIAM adaptively increases masking on dominant modalities to mitigate modality imbalance and improve multimodal learning.

Why MIAM 🤿

Most masking strategies in multimodal learning are fixed and do not sufficiently explore the space of possible input subsets. This limits robustness to missing data and does not explicitly tackle modality imbalance, where dominant modalities hinder learning of complementary ones. We define masking strategies as distributions over the unit hypercube, where each dimension corresponds to a modality. Existing strategies can be expressed in this framework (e.g., constant, uniform, Dirichlet, modality dropout, opm):

Figure: MIAM's dynamic masking strategy contrasts with fixed strategies that do not adapt to modality imbalance and underexplore modality combinations.

MIAM provides a principled alternative built on three key properties:

  • Full support: Masking probabilities are sampled over the entire hypercube, allowing any combination of masked and unmasked tokens to occur during training.
  • Corner prioritization: Rather than sampling uniformly, MIAM uses a corner-anchored mixture of product Beta distributions. Each mixture component concentrates probability mass near one of the 2^M hypercube corners, promoting training on informative subsets (e.g., single-modality or near-complete inputs).
  • Imbalance awareness: MIAM dynamically adjusts the sharpness of these beta distributions based on modality-specific learning dynamics. Modalities with high and stable unimodal performance are masked more frequently, encouraging the model to better optimize slower-learning or underutilized modalities.
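To make corner prioritization concrete, here is a minimal sketch of sampling per-modality masking probabilities from a corner-anchored mixture of product Beta distributions. This is an illustration, not the repository's implementation; the function name `sample_corner_beta_mixture` and the `sharpness` parameterization are assumptions.

```python
import numpy as np

def sample_corner_beta_mixture(corner_weights, sharpness=5.0, rng=None):
    """Sample per-modality masking probabilities from a corner-anchored
    mixture of product Beta distributions over the unit hypercube.

    corner_weights: array of shape (2**M,), one weight per hypercube corner.
    sharpness: concentration of each Beta component toward its corner.
    Returns: array of shape (M,) of masking probabilities in [0, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    n_corners = len(corner_weights)
    M = int(np.log2(n_corners))
    # Pick one hypercube corner according to the mixture weights.
    idx = rng.choice(n_corners, p=corner_weights / np.sum(corner_weights))
    corner = np.array([(idx >> m) & 1 for m in range(M)], dtype=float)
    # Concentrate mass near the chosen corner: Beta(sharpness, 1) pulls
    # toward 1, Beta(1, sharpness) pulls toward 0.
    a = np.where(corner == 1.0, sharpness, 1.0)
    b = np.where(corner == 1.0, 1.0, sharpness)
    return rng.beta(a, b)
```

With uniform corner weights every subset of modalities is visited (full support), while larger `sharpness` pushes samples closer to the corners.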

Benchmarks In This Repository 🧪

Use the benchmark-specific READMEs for full reproduction instructions:

| Benchmark | Task | Modalities | Directory | README |
|---|---|---|---|---|
| GeoPlant 🌿 | Species distribution modeling | 🛰️📊📈 | maskSDM/ | maskSDM/README.md |
| TaxaBench-8k 🐾 | Multimodal species classification | 🛰️📊🔊📷📍 | taxabench/ | taxabench/README.md |
| SatBird 🐦 | Bird species distribution modeling | 🛰️📊 | satbird/ | satbird/README.md |

Repository Structure 🗂️

├── README.md
├── maskSDM/              # GeoPlant benchmark + core MIAM training/evaluation
│   ├── training/         # Training loop and masking logic
│   ├── modules/          # Model architectures and components
│   └── evaluation/       # Evaluation scripts and experiments
├── taxabench/            # TaxaBench standalone benchmark pipeline
├── satbird/              # SatBird standalone benchmark pipeline
├── models/               # Saved checkpoints (generated)
└── figures_tables/       # Scripts generating the paper's figures and tables

Quick Setup ⚙️

python -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
pip install -e .

Python 3.10+ is recommended (Python 3.13 was used). For GPU support, ensure you have the appropriate CUDA toolkit installed and a compatible PyTorch version.

Start Working with MIAM 🚀

The core MIAM code is implemented in maskSDM/ and can be adapted to new datasets and tasks. The main training loop is in maskSDM/training/trainer.py, which computes masking probabilities with maskSDM/training/masking.py and tracks modality-specific performance. The evaluation scripts in maskSDM/evaluation/ can be adapted to new tasks and metrics. The taxabench/ and satbird/ directories provide standalone pipelines for their respective benchmarks, which can also be used as templates for new datasets.

MIAM is driven by two concrete mechanisms in the code:

  1. Dynamic mask sampling (get_mask_prob)
  • Entry points: maskSDM/training/masking.py, satbird/masking.py, taxabench/masking.py.
  • Each training step computes per-modality masking probabilities mask_prob and samples visibility with mask = Bernoulli(1 - mask_prob).
  • For method="miam", mask_prob is sampled from a mixture of product Beta distributions over hypercube corners (sample_from_beta_mixture + get_corner_weights).
  • Per-modality mixture weights are computed by modality_weights as a function of rho_s, rho_d, and lambda, adapting masking to modality imbalance.
  2. Per-modality tracking and feedback (rho_s, rho_d)
  • Entry points: maskSDM/training/trainer.py, satbird/trainer.py, taxabench/trainer.py.
  • After each epoch, trainers evaluate modality-specific validation metrics (e.g., AUROC in GeoPlant) and compute:
    • rho_s: normalized modality scores using a geometric-mean baseline (modality performance).
    • rho_d: normalized changes in modality score in the last epoch (modality learning speed).
  • These values are fed back into get_mask_prob in subsequent epochs, increasing masking on dominant modalities and helping under-optimized ones.
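The tracking step can be sketched as follows. This is illustrative only: `update_rho` is a hypothetical helper name, and the exact normalization used in trainer.py may differ.

```python
import numpy as np

def update_rho(scores, prev_scores):
    """Compute rho_s / rho_d from per-modality validation metrics.

    scores, prev_scores: per-modality metrics (e.g., AUROC) for the
    current and previous epoch; all values assumed > 0.
    """
    scores = np.asarray(scores, dtype=float)
    prev = np.asarray(prev_scores, dtype=float)
    # rho_s: scores normalized by their geometric mean
    # (rho_s > 1 marks a dominant modality).
    rho_s = scores / np.exp(np.mean(np.log(scores)))
    # rho_d: relative change since the last epoch (learning speed).
    rho_d = (scores - prev) / prev
    return rho_s, rho_d
```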

Minimal loop to follow in code:

  • compute mask_prob from current rho_s, rho_d
  • sample/apply modality mask in forward pass
  • run validation per modality
  • update rho_s, rho_d

This repository builds on ideas and implementation from the MaskSDM codebase.

Main Contributions ✨

  • Dynamic, imbalance-aware masking (MIAM) that mitigates modality imbalance, promotes complementary multimodal learning, and improves robustness to missing data in ecological applications ⚖️
  • Consistent improvements on ecological multimodal benchmarks (GeoPlant and TaxaBench), especially for under-optimized modalities 📊
  • Fine-grained interpretability across and within modalities (variables, time segments, image patches) 🔎

More details on the method, experiments, and results can be found in the paper and the website.

Citation 📄

If you use this code for your research, please cite the paper:

@inproceedings{
    zbinden2026miam,
    title={{MIAM}: Modality Imbalance-Aware Masking for Multimodal Ecological Applications},
    author={Robin Zbinden and Wesley Monteith-Finas and Gencer Sumbul and Nina van Tiel and Chiara Vanalli and Devis Tuia},
    booktitle={International Conference on Learning Representations (ICLR)},
    year={2026},
    url={https://openreview.net/forum?id=oljjAkgZN4}
}
