This repository contains the code and instructions for training and evaluating the Masked Material Autoencoder (MMAE) as described in the paper "Foundation model for composite microstructures: Reconstruction, stiffness, and nonlinear behavior prediction". The MMAE is used for self-supervised pre-training on composite material microstructures and then fine-tuned or used in linear probing for downstream tasks such as predicting homogenized stiffness components.
## Citation

If you use this repository or the MMAE model in your work, please cite:

```bibtex
@article{wei2025foundation,
  title={Foundation model for composite microstructures: Reconstruction, stiffness, and nonlinear behavior prediction},
  author={Wei, Ting-Ju and Chen, Chuin-Shan},
  journal={Materials \& Design},
  volume={257},
  pages={114397},
  year={2025},
  publisher={Elsevier}
}
```

Below is the architecture of the MMAE model:
Here is an example of the reconstructed microstructure from the MMAE model:
To run the code, you'll need the following packages:

```
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
timm==0.4.5
scikit-learn==1.3.0
scipy==1.10.1
numpy==1.24.3
pandas==2.0.2
matplotlib==3.7.1
joblib==1.2.0
fsspec==2023.6.0
```
Note: Ensure compatibility between the package versions, your Python version (PyTorch 2.0.1 supports Python 3.8–3.11), and your CUDA version if using GPU acceleration.
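The pinned versions above can be captured in a requirements file and installed in one step. A minimal sketch (adjust for your own environment and package manager):

```shell
# Write the pinned versions listed above to requirements.txt.
cat > requirements.txt <<'EOF'
torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
timm==0.4.5
scikit-learn==1.3.0
scipy==1.10.1
numpy==1.24.3
pandas==2.0.2
matplotlib==3.7.1
joblib==1.2.0
fsspec==2023.6.0
EOF
# pip install -r requirements.txt   # uncomment to install into your environment
echo "pinned $(wc -l < requirements.txt) packages"
```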
Setup steps:

- **Install Required Packages**
- **Download the Datasets**: download the datasets for pre-training and transfer learning from Zenodo:
To pre-train the MMAE, run:
```shell
python main_pretrain.py \
    --data_path /path/to/inclusion_train_100k \
    --batch_size 100 \
    --epochs 400 \
    --mask_ratio 0.85 \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --num_workers 4
```

Parameters:

- `--data_path`: Path to the pre-training dataset (`inclusion_train_100k`).
- `--batch_size`: Batch size for training (e.g., 100).
- `--epochs`: Number of training epochs (e.g., 400).
- `--mask_ratio`: Masking ratio for the MAE (e.g., 0.85).
- `--output_dir`: Directory to save training outputs and checkpoints.
- `--log_dir`: Directory for logging training progress.
- `--num_workers`: Number of worker processes for data loading.
Notes:
- You can override the default configuration by passing arguments with the corresponding key names.
- All stdout messages and checkpoints are stored in the specified `--output_dir`.
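For intuition, `--mask_ratio` controls what fraction of the image patches is hidden from the encoder during pre-training. A minimal NumPy sketch of MAE-style random masking (the patch count is illustrative, and this is not the repository's implementation):

```python
import numpy as np

def random_masking(num_patches: int, mask_ratio: float, seed: int = 0):
    """MAE-style random masking: pick which patches the encoder sees.

    Returns the indices of the kept (visible) patches and a binary mask
    (0 = visible, 1 = masked). Conceptual sketch only.
    """
    rng = np.random.default_rng(seed)
    len_keep = int(num_patches * (1 - mask_ratio))
    noise = rng.random(num_patches)       # one random score per patch
    ids_shuffle = np.argsort(noise)       # patches ordered by score
    ids_keep = ids_shuffle[:len_keep]     # lowest-score patches stay visible
    mask = np.ones(num_patches, dtype=np.int64)
    mask[ids_keep] = 0
    return ids_keep, mask

# For a 14x14 patch grid (196 patches) at mask_ratio 0.85, only 29 patches
# are passed to the encoder; the decoder must reconstruct the rest.
ids_keep, mask = random_masking(196, 0.85)
print(len(ids_keep), int(mask.sum()))  # 29 167
```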
To perform linear probing, run:
```shell
python main_linprobe.py \
    --data_path /path/to/downstream_dataset \
    --finetune /path/to/pretrained_mmae/checkpoint-399.pth \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --cls_token \
    --numDataset NUM_DATASET \
    --target_col_name TARGET_COLUMN
```

Parameters:

- `--data_path`: Path to the downstream dataset (e.g., `downstream_short_fiber/train`).
- `--finetune`: Path to the pre-trained MMAE checkpoint.
- `--output_dir`: Directory to save outputs and logs.
- `--log_dir`: Directory for logging.
- `--cls_token`: Use the [CLS] token for the embedding.
- `--numDataset`: Number of data samples to use.
- `--target_col_name`: The target column name in the dataset (e.g., `C1111`, `C2222`, `C1212`).
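Conceptually, linear probing freezes the pre-trained encoder and fits only a linear head on its [CLS] embeddings. A minimal sketch with synthetic data standing in for real MMAE embeddings (the 768-dimensional size and the stiffness-like target are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, embed_dim = 1024, 768
features = rng.standard_normal((n_samples, embed_dim))  # stand-in for frozen [CLS] embeddings
true_w = rng.standard_normal(embed_dim)
targets = features @ true_w                             # synthetic C1111-like target

# Because no gradients flow into the encoder, fitting the linear head
# reduces to an ordinary least-squares problem.
w, *_ = np.linalg.lstsq(features, targets, rcond=None)
max_err = float(np.abs(features @ w - targets).max())
print(max_err < 1e-8)
```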
Example:
```shell
python main_linprobe.py \
    --data_path ./downstream_short_fiber/train \
    --finetune ./output_dir/pretrained_mmae/checkpoint-399.pth \
    --output_dir ./output_dir/linear_probing \
    --log_dir ./logs/linear_probing \
    --cls_token \
    --numDataset 5000 \
    --target_col_name C1111
```

To perform end-to-end fine-tuning, run:
```shell
python main_finetune.py \
    --data_path /path/to/downstream_dataset \
    --finetune /path/to/pretrained_mmae/checkpoint-399.pth \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --cls_token \
    --numDataset NUM_DATASET \
    --target_col_name TARGET_COLUMN
```

For partial fine-tuning, use the same script, `main_finetune.py`. Run:
```shell
python main_finetune.py \
    --data_path /path/to/downstream_dataset \
    --finetune /path/to/pretrained_mmae/checkpoint-399.pth \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --cls_token \
    --partial_fine_tuning \
    --num_tail_blocks NUM_TAIL_BLOCKS \
    --numDataset NUM_DATASET \
    --num_workers 2 \
    --target_col_name TARGET_COLUMN
```

Parameters:

- `--partial_fine_tuning`: Enable partial fine-tuning.
- `--num_tail_blocks`: Number of transformer blocks at the end of the encoder to fine-tune (e.g., 2).
- `--num_workers`: Number of worker processes for data loading.
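The effect of these flags can be sketched as a filter over parameter names: only the last `num_tail_blocks` encoder blocks and the prediction head receive gradient updates. The `blocks.<i>.` and `head.` name patterns below are illustrative (timm-style ViT naming), not taken from the repository:

```python
def partial_finetune_trainable(name: str, num_tail_blocks: int, depth: int = 12) -> bool:
    """Decide whether a parameter stays trainable under partial fine-tuning.

    Sketch of --partial_fine_tuning --num_tail_blocks N: the last N
    transformer blocks and the head train; everything else is frozen.
    """
    if name.startswith("head."):
        return True                          # regression head always trains
    if name.startswith("blocks."):
        block_idx = int(name.split(".")[1])  # which transformer block
        return block_idx >= depth - num_tail_blocks
    return False                             # patch embed, cls token, etc. frozen

# With a 12-block encoder and --num_tail_blocks 2, only blocks 10 and 11
# (plus the head) remain trainable.
names = [f"blocks.{i}.attn.qkv.weight" for i in range(12)]
names += ["head.weight", "patch_embed.proj.weight"]
trainable = [n for n in names if partial_finetune_trainable(n, num_tail_blocks=2)]
print(trainable)
```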
Example:
```shell
python main_finetune.py \
    --data_path ./downstream_short_fiber/train \
    --finetune ./output_dir/pretrained_mmae/checkpoint-399.pth \
    --output_dir ./output_dir/partial_finetune \
    --log_dir ./logs/partial_finetune \
    --cls_token \
    --partial_fine_tuning \
    --num_tail_blocks 2 \
    --numDataset 5000 \
    --num_workers 2 \
    --target_col_name C1111
```

