BerryWei/Material_mask_autoencoder


MMAE: Foundation model for composite microstructures

Overview

This repository contains the code and instructions for training and evaluating the Masked Material Autoencoder (MMAE) as described in the paper "Foundation model for composite microstructures: Reconstruction, stiffness, and nonlinear behavior prediction". The MMAE is used for self-supervised pre-training on composite material microstructures and then fine-tuned or used in linear probing for downstream tasks such as predicting homogenized stiffness components.

Citation
If you use this repository or the MMAE model in your work, please cite:

@article{wei2025foundation,
  title={Foundation model for composite microstructures: Reconstruction, stiffness, and nonlinear behavior prediction},
  author={Wei, Ting-Ju and Chen, Chuin-Shan},
  journal={Materials \& Design},
  volume={257},
  pages={114397},
  year={2025},
  publisher={Elsevier}
}

Architecture

Below is the architecture of the MMAE model:

Model Architecture

Reconstruction Example

Here is an example of the reconstructed microstructure from the MMAE model:

Reconstruction Result

Necessary Packages and Versions

To run the code, you'll need the following packages:

torch==2.0.1
torchvision==0.15.2
torchaudio==2.0.2
timm==0.4.5
scikit-learn==1.3.0
scipy==1.10.1
numpy==1.24.3
pandas==2.0.2
matplotlib==3.7.1
joblib==1.2.0
fsspec==2023.6.0

Note: Ensure compatibility between package versions, your Python version, and your CUDA version if using GPU acceleration. torch 2.0.1 supports Python 3.8–3.11, so a Python 3.10 environment is a safe choice.
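As a quick sanity check of the pins above, you can compare them against what is actually installed. This is a minimal sketch; the helper name and the subset of pins are illustrative, not part of the repository:

```python
from importlib import metadata

# Illustrative subset of the pinned versions above; extend as needed.
PINS = {"torch": "2.0.1", "timm": "0.4.5", "numpy": "1.24.3"}

def version_mismatches(pins):
    """Return {package: (pinned, installed)} for every pin that is
    missing or does not match the installed version."""
    bad = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            bad[name] = (pinned, installed)
    return bad

if __name__ == "__main__":
    for name, (pinned, installed) in version_mismatches(PINS).items():
        print(f"{name}: pinned {pinned}, installed {installed}")
```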

Preparation

  1. Install Required Packages

  2. Download the Datasets

    Download the datasets for pre-training and transfer learning from Zenodo:

Stage 1: MMAE Pre-Training

To pre-train the MMAE, run:

python main_pretrain.py \
    --data_path /path/to/inclusion_train_100k \
    --batch_size 100 \
    --epochs 400 \
    --mask_ratio 0.85 \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --num_workers 4

Parameters:

  • --data_path: Path to the pre-training dataset (inclusion_train_100k).
  • --batch_size: Batch size for training (e.g., 100).
  • --epochs: Number of training epochs (e.g., 400).
  • --mask_ratio: Masking ratio for the MAE (e.g., 0.85).
  • --output_dir: Directory to save training outputs and checkpoints.
  • --log_dir: Directory for logging training progress.
  • --num_workers: Number of worker threads for data loading.

Notes:

  • You can override the default configuration by passing arguments with the corresponding key names.
  • All stdout messages and checkpoints are stored in the specified --output_dir.
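A --mask_ratio of 0.85 means 85% of the image patches are hidden from the encoder during pre-training. Below is a minimal, NumPy-based sketch of this random patch masking (illustrative only; the actual masking in this repository follows the MAE-style implementation):

```python
import numpy as np

def random_patch_mask(num_patches, mask_ratio, rng):
    """Return a boolean array where True marks a masked (hidden) patch."""
    num_masked = int(num_patches * mask_ratio)
    perm = rng.permutation(num_patches)       # random patch order
    mask = np.zeros(num_patches, dtype=bool)
    mask[perm[:num_masked]] = True            # hide the first num_masked
    return mask

rng = np.random.default_rng(0)
mask = random_patch_mask(num_patches=196, mask_ratio=0.85, rng=rng)
print(mask.sum(), "of", mask.size, "patches masked")  # 166 of 196
```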

Stage 2: Transfer Learning

Linear Probing

To perform linear probing, run:

python main_linprobe.py \
    --data_path /path/to/downstream_dataset \
    --finetune /path/to/pretrained_mmae/checkpoint-399.pth \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --cls_token \
    --numDataset NUM_DATASET \
    --target_col_name TARGET_COLUMN

Parameters:

  • --data_path: Path to the downstream dataset (e.g., downstream_short_fiber/train).
  • --finetune: Path to the pre-trained MMAE checkpoint.
  • --output_dir: Directory to save outputs and logs.
  • --log_dir: Directory for logging.
  • --cls_token: Use the [CLS] token for embedding.
  • --numDataset: Number of data samples to use.
  • --target_col_name: The target column name in the dataset (e.g., C1111, C2222, C1212).

Example:

python main_linprobe.py \
    --data_path ./downstream_short_fiber/train \
    --finetune ./output_dir/pretrained_mmae/checkpoint-399.pth \
    --output_dir ./output_dir/linear_probing \
    --log_dir ./logs/linear_probing \
    --cls_token \
    --numDataset 5000 \
    --target_col_name C1111
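Conceptually, linear probing keeps the pre-trained encoder frozen and trains only a linear head on its [CLS] embeddings. The following is a hypothetical sketch using scikit-learn on synthetic stand-in features — the actual training loop lives in main_linprobe.py, and the feature dimension and sample counts here are illustrative:

```python
import numpy as np
from sklearn.linear_model import Ridge

# Synthetic stand-ins: one fixed-length [CLS] embedding per microstructure,
# and a scalar stiffness target (e.g. C1111) that is linear in the features.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 768))           # frozen embeddings (not trained)
w = rng.normal(size=768)
y = X @ w + 0.01 * rng.normal(size=100)   # stand-in target values

# Only this linear head is fit; the encoder producing X stays untouched.
probe = Ridge(alpha=1.0).fit(X, y)
score = probe.score(X, y)                 # training R^2; near 1 on this data
```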

End-to-End Fine-Tuning

To perform end-to-end fine-tuning, run:

python main_finetune.py \
    --data_path /path/to/downstream_dataset \
    --finetune /path/to/pretrained_mmae/checkpoint-399.pth \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --cls_token \
    --numDataset NUM_DATASET \
    --target_col_name TARGET_COLUMN

Partial Fine-Tuning

Partial fine-tuning also uses main_finetune.py, with the --partial_fine_tuning flag enabled. Run:

python main_finetune.py \
    --data_path /path/to/downstream_dataset \
    --finetune /path/to/pretrained_mmae/checkpoint-399.pth \
    --output_dir /path/to/output_dir \
    --log_dir /path/to/log_dir \
    --cls_token \
    --partial_fine_tuning \
    --num_tail_blocks NUM_TAIL_BLOCKS \
    --numDataset NUM_DATASET \
    --num_workers 2 \
    --target_col_name TARGET_COLUMN

Parameters:

  • --partial_fine_tuning: Indicates partial fine-tuning is enabled.
  • --num_tail_blocks: Number of transformer blocks at the end of the encoder to fine-tune (e.g., 2).
  • --num_workers: Number of worker threads for data loading.
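In PyTorch terms, partial fine-tuning amounts to setting requires_grad = False on every encoder block except the last --num_tail_blocks. A toy sketch of that freezing logic (the block count, layer shapes, and helper name are illustrative, not the repository's actual model):

```python
import torch.nn as nn

# Toy stand-in for the MMAE encoder: 12 "transformer blocks" plus a
# regression head for one stiffness component (shapes are illustrative).
blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(12)])
head = nn.Linear(16, 1)

def set_partial_finetune(blocks, num_tail_blocks):
    """Freeze every encoder block except the last num_tail_blocks."""
    cutoff = len(blocks) - num_tail_blocks
    for i, block in enumerate(blocks):
        for p in block.parameters():
            p.requires_grad = i >= cutoff

set_partial_finetune(blocks, num_tail_blocks=2)

trainable = [i for i, b in enumerate(blocks)
             if all(p.requires_grad for p in b.parameters())]
print(trainable)  # only the last two blocks remain trainable
```

The regression head is left out of the loop on purpose: it is newly initialized for the downstream task and always stays trainable.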

Example:

python main_finetune.py \
    --data_path ./downstream_short_fiber/train \
    --finetune ./output_dir/pretrained_mmae/checkpoint-399.pth \
    --output_dir ./output_dir/partial_finetune \
    --log_dir ./logs/partial_finetune \
    --cls_token \
    --partial_fine_tuning \
    --num_tail_blocks 2 \
    --numDataset 5000 \
    --num_workers 2 \
    --target_col_name C1111
