{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial 9: Recurrent GNNs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial we will implement an approximation of the Graph Neural Network Model (without enforcing contraction map) and analyze the GatedGraph Convolution of Pytorch Geometric." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.11.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/antonio/anaconda3/envs/geometric_new/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "import torch\n", "os.environ['TORCH'] = torch.__version__\n", "print(torch.__version__)\n", "\n", "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n", "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n", "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import os.path as osp\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch_geometric.transforms as T\n", "import torch_geometric\n", "from torch_geometric.datasets import Planetoid, TUDataset\n", "from torch_geometric.data import DataLoader\n", "from torch_geometric.nn.inits import uniform\n", "from torch.nn import Parameter as Param\n", "from torch import Tensor \n", "torch.manual_seed(42)\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "device = \"cpu\"\n", "from torch_geometric.nn.conv import MessagePassing" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "dataset = 'Cora'\n", "transform = T.Compose([\n", " T.RandomNodeSplit('train_rest', num_val=500, num_test=500),\n", " T.TargetIndegree(),\n", "])\n", "path = osp.join('data', dataset)\n", "dataset = Planetoid(path, dataset, transform=transform)\n", "data = dataset[0]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "dataset = 'Cora'\n", "path = osp.join('data', dataset)\n", "dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\n", "data = dataset[0]\n", "data = data.to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Graph Neural Network Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The MLP class is used to instantiate the transition and output functions as simple feed forard networks" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "class MLP(nn.Module):\n", " def __init__(self, input_dim, hid_dims, out_dim):\n", " super(MLP, self).__init__()\n", "\n", " self.mlp = nn.Sequential()\n", " dims = [input_dim] + hid_dims + [out_dim]\n", " for i in range(len(dims)-1):\n", " self.mlp.add_module('lay_{}'.format(i),nn.Linear(in_features=dims[i], out_features=dims[i+1]))\n", " if i+2 < len(dims):\n", " self.mlp.add_module('act_{}'.format(i), nn.Tanh())\n", " def reset_parameters(self):\n", " for i, l in enumerate(self.mlp):\n", " if type(l) == nn.Linear:\n", " nn.init.xavier_normal_(l.weight)\n", "\n", " def forward(self, x):\n", " return self.mlp(x)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The GNNM calss puts together the state propagations and the readout of the nodes' states." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class GNNM(MessagePassing):\n", " def __init__(self, n_nodes, out_channels, features_dim, hid_dims, num_layers = 50, eps=1e-3, aggr = 'add',\n", " bias = True, **kwargs):\n", " super(GNNM, self).__init__(aggr=aggr, **kwargs)\n", "\n", " self.node_states = Param(torch.zeros((n_nodes, features_dim)), requires_grad=False)\n", " self.out_channels = out_channels\n", " self.eps = eps\n", " self.num_layers = num_layers\n", " \n", " self.transition = MLP(features_dim, hid_dims, features_dim)\n", " self.readout = MLP(features_dim, hid_dims, out_channels)\n", " \n", " self.reset_parameters()\n", " print(self.transition)\n", " print(self.readout)\n", "\n", " def reset_parameters(self):\n", " self.transition.reset_parameters()\n", " self.readout.reset_parameters()\n", " \n", " def forward(self): \n", " edge_index = data.edge_index\n", " edge_weight = data.edge_attr\n", " node_states = self.node_states\n", " for i in range(self.num_layers):\n", " m = self.propagate(edge_index, x=node_states, edge_weight=edge_weight,\n", " size=None)\n", " new_states = self.transition(m)\n", " with torch.no_grad():\n", " distance = torch.norm(new_states - node_states, dim=1)\n", " convergence = distance < self.eps\n", " node_states = new_states\n", " if convergence.all():\n", " break\n", " \n", " out = self.readout(node_states)\n", " \n", " return F.log_softmax(out, dim=-1)\n", "\n", " def message(self, x_j, edge_weight):\n", " return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n", "\n", " def message_and_aggregate(self, adj_t, x) :\n", " return matmul(adj_t, x, reduce=self.aggr)\n", "\n", " def __repr__(self):\n", " return '{}({}, num_layers={})'.format(self.__class__.__name__,\n", " self.out_channels,\n", " self.num_layers)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLP(\n", " (mlp): Sequential(\n", " (lay_0): Linear(in_features=32, out_features=64, bias=True)\n", " (act_0): Tanh()\n", " (lay_1): Linear(in_features=64, out_features=64, bias=True)\n", " (act_1): Tanh()\n", " (lay_2): Linear(in_features=64, out_features=64, bias=True)\n", " (act_2): Tanh()\n", " (lay_3): Linear(in_features=64, out_features=64, bias=True)\n", " (act_3): Tanh()\n", " (lay_4): Linear(in_features=64, out_features=64, bias=True)\n", " (act_4): Tanh()\n", " (lay_5): Linear(in_features=64, out_features=32, bias=True)\n", " )\n", ")\n", "MLP(\n", " (mlp): Sequential(\n", " (lay_0): Linear(in_features=32, out_features=64, bias=True)\n", " (act_0): Tanh()\n", " (lay_1): Linear(in_features=64, out_features=64, bias=True)\n", " (act_1): Tanh()\n", " (lay_2): Linear(in_features=64, out_features=64, bias=True)\n", " (act_2): Tanh()\n", " (lay_3): Linear(in_features=64, out_features=64, bias=True)\n", " (act_3): Tanh()\n", " (lay_4): Linear(in_features=64, out_features=64, bias=True)\n", " (act_4): Tanh()\n", " (lay_5): Linear(in_features=64, out_features=7, bias=True)\n", " )\n", ")\n", "Epoch: 001, Train Acc: 0.12857, Val Acc: 0.06800, Test Acc: 0.08800\n", "Epoch: 002, Train Acc: 0.14286, Val Acc: 0.25200, Test Acc: 0.25100\n", "Epoch: 003, Train Acc: 0.12143, Val Acc: 0.24400, Test Acc: 0.26100\n", "Epoch: 004, Train Acc: 0.17143, Val Acc: 0.20200, Test Acc: 0.20200\n", "Epoch: 005, Train Acc: 0.16429, Val Acc: 0.23000, Test Acc: 0.23100\n", "Epoch: 006, Train Acc: 0.22857, Val Acc: 0.10000, Test Acc: 0.10500\n", "Epoch: 007, Train Acc: 0.14286, Val Acc: 0.11400, Test Acc: 0.10000\n", "Epoch: 008, Train Acc: 0.14286, Val Acc: 0.08400, Test Acc: 0.08000\n", "Epoch: 009, Train Acc: 0.14286, Val Acc: 0.06800, Test Acc: 0.06500\n", "Epoch: 010, Train Acc: 0.18571, Val Acc: 0.18600, Test Acc: 0.17000\n", "Epoch: 011, Train Acc: 0.08571, Val Acc: 0.07800, Test Acc: 0.07700\n", "Epoch: 012, Train Acc: 0.14286, Val Acc: 0.08000, Test Acc: 0.08500\n", "Epoch: 013, Train Acc: 0.11429, Val Acc: 0.06800, Test Acc: 0.07000\n", "Epoch: 014, Train Acc: 0.15714, Val Acc: 0.10200, Test Acc: 0.08600\n", "Epoch: 015, Train Acc: 0.16429, Val Acc: 0.20000, Test Acc: 0.17500\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m51\u001b[39m):\n\u001b[1;32m 29\u001b[0m train()\n\u001b[0;32m---> 30\u001b[0m accs \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 31\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m accs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 32\u001b[0m val_acc \u001b[38;5;241m=\u001b[39m accs[\u001b[38;5;241m1\u001b[39m]\n", "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36mtest\u001b[0;34m()\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtest\u001b[39m():\n\u001b[1;32m 19\u001b[0m model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m---> 20\u001b[0m logits, accs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, []\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _, mask \u001b[38;5;129;01min\u001b[39;00m data(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_mask\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_mask\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_mask\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m 22\u001b[0m pred \u001b[38;5;241m=\u001b[39m logits[mask]\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m1\u001b[39m]\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36mGNNM.forward\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers):\n\u001b[1;32m 27\u001b[0m m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpropagate(edge_index, x\u001b[38;5;241m=\u001b[39mnode_states, edge_weight\u001b[38;5;241m=\u001b[39medge_weight,\n\u001b[1;32m 28\u001b[0m size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m---> 29\u001b[0m new_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransition\u001b[49m\u001b[43m(\u001b[49m\u001b[43mm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m 31\u001b[0m distance \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnorm(new_states \u001b[38;5;241m-\u001b[39m node_states, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36mMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "model = GNNM(data.num_nodes, dataset.num_classes, 32, [64,64,64,64,64], eps=0.01).to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "\n", "test_dataset = dataset[:len(dataset) // 10]\n", "train_dataset = dataset[len(dataset) // 10:]\n", "test_loader = DataLoader(test_dataset)\n", "train_loader = DataLoader(train_dataset)\n", "\n", "def train():\n", " model.train()\n", " optimizer.zero_grad()\n", " loss_fn(model()[data.train_mask], data.y[data.train_mask]).backward()\n", " optimizer.step()\n", "\n", "\n", "def test():\n", " model.eval()\n", " logits, accs = model(), []\n", " for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n", " pred = logits[mask].max(1)[1]\n", " acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n", " accs.append(acc)\n", " return accs\n", "\n", "\n", "for epoch in range(1, 51):\n", " train()\n", " accs = test()\n", " train_acc = accs[0]\n", " val_acc = accs[1]\n", " test_acc = accs[2]\n", " print('Epoch: {:03d}, Train Acc: {:.5f}, '\n", " 'Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, train_acc,\n", " val_acc, test_acc))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gated Graph Neural Network" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class GatedGraphConv(MessagePassing):\n", " \n", " def __init__(self, out_channels, num_layers, aggr = 'add',\n", " bias = True, **kwargs):\n", " super(GatedGraphConv, self).__init__(aggr=aggr, **kwargs)\n", "\n", " self.out_channels = out_channels\n", " self.num_layers = num_layers\n", "\n", " self.weight = Param(Tensor(num_layers, out_channels, out_channels))\n", " self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias)\n", "\n", " self.reset_parameters()\n", "\n", " def reset_parameters(self):\n", " uniform(self.out_channels, self.weight)\n", " self.rnn.reset_parameters()\n", "\n", " def forward(self, data):\n", " \"\"\"\"\"\"\n", " x = data.x\n", " edge_index = data.edge_index\n", " edge_weight = data.edge_attr\n", " if x.size(-1) > self.out_channels:\n", " raise ValueError('The number of input channels is not allowed to '\n", " 'be larger than the number of output channels')\n", "\n", " if x.size(-1) < self.out_channels:\n", " zero = x.new_zeros(x.size(0), self.out_channels - x.size(-1))\n", " x = torch.cat([x, zero], dim=1)\n", "\n", " for i in range(self.num_layers):\n", " m = torch.matmul(x, self.weight[i])\n", " m = self.propagate(edge_index, x=m, edge_weight=edge_weight,\n", " size=None)\n", " x = self.rnn(m, x)\n", "\n", " return x\n", "\n", " def message(self, x_j, edge_weight):\n", " return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n", "\n", " def message_and_aggregate(self, adj_t, x):\n", " return matmul(adj_t, x, reduce=self.aggr)\n", "\n", " def __repr__(self):\n", " return '{}({}, num_layers={})'.format(self.__class__.__name__,\n", " self.out_channels,\n", " self.num_layers)\n", "\n", "class GGNN(torch.nn.Module):\n", " def __init__(self):\n", " super(GGNN, self).__init__()\n", " \n", " self.conv = GatedGraphConv(1433, 3)\n", " self.mlp = MLP(1433, [32,32,32], dataset.num_classes)\n", " \n", " def forward(self):\n", " x = self.conv(data)\n", " x = self.mlp(x)\n", " return F.log_softmax(x, dim=-1)\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001, Train Acc: 0.27143, Val Acc: 0.15800, Test Acc: 0.15400\n", "Epoch: 002, Train Acc: 0.35000, Val Acc: 0.22200, Test Acc: 0.22200\n", "Epoch: 003, Train Acc: 0.18571, Val Acc: 0.22400, Test Acc: 0.21000\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Input \u001b[0;32mIn [18]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m accs\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m51\u001b[39m):\n\u001b[0;32m---> 30\u001b[0m \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 31\u001b[0m accs \u001b[38;5;241m=\u001b[39m test()\n\u001b[1;32m 32\u001b[0m train_acc \u001b[38;5;241m=\u001b[39m accs[\u001b[38;5;241m0\u001b[39m]\n", "Input \u001b[0;32mIn [18]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m 14\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 15\u001b[0m loss_fn(\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m[data\u001b[38;5;241m.\u001b[39mtrain_mask], data\u001b[38;5;241m.\u001b[39my[data\u001b[38;5;241m.\u001b[39mtrain_mask])\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 16\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36mGGNN.forward\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m---> 59\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp(x)\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mlog_softmax(x, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36mGatedGraphConv.forward\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m 33\u001b[0m m \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmatmul(x, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mweight[i])\n\u001b[1;32m 34\u001b[0m m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpropagate(edge_index, x\u001b[38;5;241m=\u001b[39mm, edge_weight\u001b[38;5;241m=\u001b[39medge_weight,\n\u001b[1;32m 35\u001b[0m size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m---> 36\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrnn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1109\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n", "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/rnn.py:1267\u001b[0m, in \u001b[0;36mGRUCell.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 1264\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1265\u001b[0m hx \u001b[38;5;241m=\u001b[39m hx\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_batched \u001b[38;5;28;01melse\u001b[39;00m hx\n\u001b[0;32m-> 1267\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgru_cell\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1268\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_ih\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_hh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_ih\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_hh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1271\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1273\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_batched:\n\u001b[1;32m 1274\u001b[0m ret \u001b[38;5;241m=\u001b[39m ret\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m0\u001b[39m)\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "device = \"cpu\"\n", "model = GGNN().to(device)\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n", "loss_fn = nn.CrossEntropyLoss()\n", "\n", "\n", "test_dataset = dataset[:len(dataset) // 10]\n", "train_dataset = dataset[len(dataset) // 10:]\n", "test_loader = DataLoader(test_dataset)\n", "train_loader = DataLoader(train_dataset)\n", "\n", "def train():\n", " model.train()\n", " optimizer.zero_grad()\n", " loss_fn(model()[data.train_mask], data.y[data.train_mask]).backward()\n", " optimizer.step()\n", "\n", "\n", "def test():\n", " model.eval()\n", " logits, accs = model(), []\n", " for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n", " pred = logits[mask].max(1)[1]\n", " acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n", " accs.append(acc)\n", " return accs\n", "\n", "\n", "for epoch in range(1, 51):\n", " train()\n", " accs = test()\n", " train_acc = accs[0]\n", " val_acc = accs[1]\n", " test_acc = accs[2]\n", " print('Epoch: {:03d}, Train Acc: {:.5f}, '\n", " 'Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, train_acc,\n", " val_acc, test_acc))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 4 }