{ "cells": [ { "cell_type": "markdown", "id": "3599d0f8", "metadata": {}, "source": [ "# Tutorial16: DIFFPOOL" ] }, { "cell_type": "code", "execution_count": null, "id": "4d1eca5c", "metadata": {}, "outputs": [], "source": [ "# %pip (rather than !pip) installs into the kernel's own environment\n", "%pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n", "%pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu113.html\n", "%pip install -q git+https://github.com/pyg-team/pytorch_geometric.git" ] }, { "cell_type": "code", "execution_count": null, "id": "02d41129", "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "markdown", "id": "2b0f2059", "metadata": {}, "source": [ "Below we show the computations that produce the node features matrix and the adjacency matrix for the first hierarchical step. \n", "\n", "Initial graph: \n", "```x_0 = 50 x 32\n", "adj_0 = 50 x 50```" ] }, { "cell_type": "code", "execution_count": null, "id": "4c7bdbf9", "metadata": {}, "outputs": [], "source": [ "# Node features matrix: 50 nodes with 32 features each\n", "x_0 = torch.rand(50, 32)\n", "# Random binary adjacency matrix with self-loops added; clamp keeps the\n", "# entries in {0, 1} when a random self-edge already sits on the diagonal\n", "# (without it those diagonal entries would become 2).\n", "adj_0 = torch.rand(50, 50).round()\n", "identity = torch.eye(50)\n", "adj_0 = (adj_0 + identity).clamp(max=1)" ] }, { "cell_type": "markdown", "id": "753e9e5c", "metadata": {}, "source": [ "Set the number of clusters we want to obtain at step 1" ] }, { "cell_type": "code", "execution_count": null, "id": "5b8647c0", "metadata": {}, "outputs": [], "source": [ "n_clusters_0 = 50  # nodes (trivial clusters) at step 0\n", "n_clusters_1 = 5   # clusters after the first pooling step" ] }, { "cell_type": "markdown", "id": "c7c607e4", "metadata": {}, "source": [ "Initialize the weights of GNN_emb and GNN_pool; we use just one conv layer" ] }, { "cell_type": "code", "execution_count": null, "id": "1c420b7c", "metadata": {}, "outputs": [], "source": [ "w_gnn_emb = torch.rand(32, 16)\n", "w_gnn_pool = torch.rand(32, n_clusters_1)" ] }, { "cell_type": "markdown", "id": "3f410b12", "metadata": {}, "source": [ "\n", "" ] }, { "cell_type": "code", "execution_count": null, "id": "9264b7d7", "metadata": {}, "outputs": [], 
"source": [ "# One message-passing step for each GNN:\n", "#   Z = ReLU(A X W_emb)            -- node embeddings\n", "#   S = softmax(ReLU(A X W_pool))  -- soft cluster assignments (rows sum to 1)\n", "z_0 = torch.relu(adj_0 @ x_0 @ w_gnn_emb)\n", "s_0 = torch.softmax(torch.relu(adj_0 @ x_0 @ w_gnn_pool), dim=1)" ] }, { "cell_type": "markdown", "id": "aa5c0c75", "metadata": {}, "source": [ "\n", "" ] }, { "cell_type": "code", "execution_count": null, "id": "92465df5", "metadata": {}, "outputs": [], "source": [ "# Coarsen the graph: X_1 = S^T Z  and  A_1 = S^T A S\n", "x_1 = s_0.t() @ z_0\n", "adj_1 = s_0.t() @ adj_0 @ s_0" ] }, { "cell_type": "code", "execution_count": null, "id": "6a8a7596", "metadata": {}, "outputs": [], "source": [ "print(x_1.shape)\n", "print(adj_1.shape)" ] }, { "cell_type": "code", "execution_count": null, "id": "9c695ca4", "metadata": {}, "outputs": [], "source": [ "import os.path as osp\n", "from math import ceil\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch_geometric.datasets import TUDataset\n", "import torch_geometric.transforms as T\n", "# DenseDataLoader lives in torch_geometric.loader since PyG 2.0\n", "from torch_geometric.loader import DenseDataLoader\n", "from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool\n" ] }, { "cell_type": "code", "execution_count": null, "id": "3eb52795", "metadata": {}, "outputs": [], "source": [ "max_nodes = 150\n", "\n", "\n", "class MyFilter(object):\n", "    \"\"\"Keep only graphs that fit the dense max_nodes layout.\"\"\"\n", "\n", "    def __call__(self, data):\n", "        return data.num_nodes <= max_nodes\n", "\n", "\n", "dataset = TUDataset('data', name='PROTEINS', transform=T.ToDense(max_nodes),\n", "                    pre_filter=MyFilter())\n", "dataset = dataset.shuffle()\n", "# Split: ~10% test, ~10% validation, rest for training\n", "n = (len(dataset) + 9) // 10\n", "test_dataset = dataset[:n]\n", "val_dataset = dataset[n:2 * n]\n", "train_dataset = dataset[2 * n:]\n", "test_loader = DenseDataLoader(test_dataset, batch_size=32)\n", "val_loader = DenseDataLoader(val_dataset, batch_size=32)\n", "train_loader = DenseDataLoader(train_dataset, batch_size=32)" ] }, { "cell_type": "code", "execution_count": null, "id": "186a211b", "metadata": {}, "outputs": [], "source": [ "class GNN(torch.nn.Module):\n", "    \"\"\"Three DenseGCNConv layers, each followed by ReLU and batch norm.\n", "\n", "    `lin` is accepted for interface compatibility with the reference PyG\n", "    example but is unused here (no final linear layer is created).\n", "    \"\"\"\n", "\n", "    def __init__(self, in_channels, hidden_channels, out_channels,\n", "                 normalize=False, lin=True):\n", "        super(GNN, self).__init__()\n", "\n", "        self.convs = torch.nn.ModuleList()\n", "        self.bns = torch.nn.ModuleList()\n", "\n", "        self.convs.append(GCNConv(in_channels, hidden_channels, normalize))\n", "        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))\n", "\n", "        self.convs.append(GCNConv(hidden_channels, hidden_channels, normalize))\n", "        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))\n", "\n", "        self.convs.append(GCNConv(hidden_channels, out_channels, normalize))\n", "        self.bns.append(torch.nn.BatchNorm1d(out_channels))\n", "\n", "    def forward(self, x, adj, mask=None):\n", "        # x: (batch_size, num_nodes, channels); adj: (batch_size, N, N)\n", "        for step in range(len(self.convs)):\n", "            x = F.relu(self.convs[step](x, adj, mask))\n", "            # BatchNorm1d normalizes over dim 1, so flatten the\n", "            # (batch, node) dimensions before applying it and restore the\n", "            # dense shape afterwards (applying it to the 3-D tensor would\n", "            # treat num_nodes as the channel dimension and fail).\n", "            batch_size, num_nodes, num_channels = x.size()\n", "            x = self.bns[step](x.view(-1, num_channels)).view(\n", "                batch_size, num_nodes, num_channels)\n", "        return x\n", "\n", "\n", "class DiffPool(torch.nn.Module):\n", "    \"\"\"Two-level DiffPool graph classifier for PROTEINS.\"\"\"\n", "\n", "    def __init__(self):\n", "        super(DiffPool, self).__init__()\n", "\n", "        num_nodes = ceil(0.25 * max_nodes)\n", "        self.gnn1_pool = GNN(dataset.num_features, 64, num_nodes)\n", "        self.gnn1_embed = GNN(dataset.num_features, 64, 64)\n", "\n", "        num_nodes = ceil(0.25 * num_nodes)\n", "        self.gnn2_pool = GNN(64, 64, num_nodes)\n", "        self.gnn2_embed = GNN(64, 64, 64, lin=False)\n", "\n", "        self.gnn3_embed = GNN(64, 64, 64, lin=False)\n", "\n", "        self.lin1 = torch.nn.Linear(64, 64)\n", "        self.lin2 = torch.nn.Linear(64, dataset.num_classes)\n", "\n", "    def forward(self, x, adj, mask=None):\n", "        s = self.gnn1_pool(x, adj, mask)\n", "        x = self.gnn1_embed(x, adj, mask)\n", "\n", "        # Dense DiffPool step: x <- S^T Z, adj <- S^T A S (exactly the toy\n", "        # computation above) plus auxiliary link and entropy losses.\n", "        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)\n", "\n", "        s = self.gnn2_pool(x, adj)\n", "        x = self.gnn2_embed(x, adj)\n", "\n", "        x, adj, l2, e2 = dense_diff_pool(x, adj, s)\n", "\n", "        x = self.gnn3_embed(x, adj)\n", "\n", "        # Global mean pool over the remaining clusters, then classify.\n", "        x = x.mean(dim=1)\n", "        x = F.relu(self.lin1(x))\n", "        x = self.lin2(x)\n", "        return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2\n" ] }, { "cell_type": "code", 
"execution_count": null, "id": "a0e89716", "metadata": {}, "outputs": [], "source": [ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "model = DiffPool().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "\n", "\n", "def train(epoch):\n", "    # One pass over the training set; returns the per-graph average NLL loss.\n", "    # NOTE(review): the auxiliary link/entropy losses returned by the model\n", "    # are deliberately discarded here, matching the reference PyG example.\n", "    model.train()\n", "    loss_all = 0\n", "\n", "    for data in train_loader:\n", "        data = data.to(device)\n", "        optimizer.zero_grad()\n", "        output, _, _ = model(data.x, data.adj, data.mask)\n", "        loss = F.nll_loss(output, data.y.view(-1))\n", "        loss.backward()\n", "        # Weight by batch size so the sum averages correctly over graphs.\n", "        loss_all += data.y.size(0) * loss.item()\n", "        optimizer.step()\n", "    return loss_all / len(train_dataset)\n", "\n", "\n", "@torch.no_grad()\n", "def test(loader):\n", "    # Classification accuracy of the current model over `loader`.\n", "    model.eval()\n", "    correct = 0\n", "\n", "    for data in loader:\n", "        data = data.to(device)\n", "        pred = model(data.x, data.adj, data.mask)[0].max(dim=1)[1]\n", "        correct += pred.eq(data.y.view(-1)).sum().item()\n", "    return correct / len(loader.dataset)\n", "\n", "\n", "# Model selection: report the test accuracy from the epoch that achieved\n", "# the best validation accuracy so far.\n", "best_val_acc = test_acc = 0\n", "for epoch in range(1, 151):\n", "    train_loss = train(epoch)\n", "    val_acc = test(val_loader)\n", "    if val_acc > best_val_acc:\n", "        test_acc = test(test_loader)\n", "        best_val_acc = val_acc\n", "    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '\n", "          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')" ] }, { "cell_type": "code", "execution_count": null, "id": "749420ff", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "df48fe17", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }