본문 바로가기

Graph(Graph Neural Network)

GCN - Image Classification

Image Classification하면 대표적으로 생각나는 모델이 CNN(Convolution Neural Network)이다. 이미지 분류를 GCN(Graph Convolution Neural Network)을 통해서도 가능하며 이번에 가장 기본적으로 사용되는 MNIST Dataset을 GCN을 통해 이미지 분류를 하였다. pytorch에서 제공해주는 mnist dataset과 GCNConv layer를 이용하였다. 먼저, pytorch에서 제공해주는 GCNConv 알고리즘에 대해 공부한 내용을 정리해 보고자 한다.

Graph로 구조화된 mnist dataset과 GCNConv layer를 사용하기 위해서는 다음 package을 설치해줘야 한다.

!pip install --verbose --no-cache-dir torch-scatter
!pip install --verbose --no-cache-dir torch-sparse
!pip install --verbose --no-cache-dir torch-cluster
!pip install torch-geometric
!pip install tensorboardX

 

Graph Convolution Neural Network

pytorch에서 GCN Layer를 "Semi-supervised Classification with Graph Convolutional Networkx" 논문을 토대로 형성하였다고 한다. 이 논문과 관련하여 공부한 내용을 토대로 정리한 블로그이다.
https://iy322.tistory.com/12

 

Semi-Supervised Classification With Graph Convolutional Neworks

이번엔, GCN을 준지도학습에 적용한 논문에 대해 공부해 보았다. 논문 내용을 정리해 보고자 한다. semi-supervised learning에 대한 개념은 아래 블로그를 통해 확인할 수 있다. https://iy322.tistory.com/11 Sem

iy322.tistory.com

GCN Layer를 형성하는데 필요한 library는 다음과 같다.

from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, fill_diag, matmul, mul
from torch_sparse import sum as sparsesum

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes

GCN Model

관련 기호 정리는 다음과 같다.

  • D : Degree Matrix
  • A : Adjacency Matrix
  • H : Hidden Layer
  • W : Weight

Update 과정.

Adjacency Matrix 생성 함수 : gcn_norm

GCN의 가장 큰 특징인 adjacency matrix를 형성해야 한다.

def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, flow="source_to_target", dtype=None):

    fill_value = 2.if improvedelse 1.

if isinstance(edge_index, SparseTensor):
assert flow in ["source_to_target"]
        adj_t = edge_index
if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        deg = sparsesum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
return adj_t

else:
assert flow in ["source_to_target", "target_to_source"]
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                     device=edge_index.device)

if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        idx = colif flow == "source_to_target"else row
        deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

핵심적인 부분만 따로 설명을 하면 다음과 같다.

 

GCNConv Class

  • Args
    • in_channels(int)
    • out_channels(int)
    • improved
    • cached
    • add_self_loops(default=True)
    • normalize(default=True) -> symmetric normalization adjacency matrix
    • bias(default=True)

  • shapes
    • input: node_features(n_nodes*n_input_features) & edge_indices(2*n_edge_index) & edge_weight(n_edge_index)
    • output : node_feature(n*nodes*n_output_features)

<객체변수 생성>

<Key Point>


위의 설명을 토대로 Graph Convolution Neural Network의 class는 다음과 같다.

class GCNConv(MessagePassing):

    def __init__(self, in_channels: int, out_channels: int,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, **kwargs):

        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

        self.lin = Linear(in_channels, out_channels, bias=False,
                          weight_initializer='glorot')

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

   def reset_parameters(self):
        self.lin.reset_parameters()
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None


   def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) 

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = self.lin(x)

        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=None)

        if self.bias is not None:
            out = out + self.bias

        return out


    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

 

Image Classification

이제, pytorch geometric에서 제공해주는 MNISTSuperpixels dataset을 토대로 GCN 모델을 이용해 이미지 분류 학습을 해보고자 한다. 이 데이터셋은 "Geometric Deep Learning on Graphs and Manifolds Using Mixture Model CNNs" 논문을 토대로 만들었다고 한다. 이 논문은 아직 읽어보지 않아서 시간 내서 읽고 공부해 볼 예정이다.

  • 필요한 라이브러리
import torch
import torch.nn.functional as F


from torch.nn import Linear, ReLU
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.data import DataLoader

import numpy as np

Dataset 불러오기

data_tr = MNISTSuperpixels(root='./', train=True) ## train set
data_te = MNISTSuperpixels(root='./', train=False) ## test set

MNISTSuperpixels의 데이터 구조는 다음과 같다.

총 이미지의 개수는 70,000개로 각 이미지마다 75개의 node와 최대 1,393개의 edge를 통해 graph를 형성하였다.
즉, 총 70,000개의 graph가 존재한다. 모든 graph는 10개의 class 중 하나로 labeling 되어있다.

print("train set", data_tr)
print("test set", data_te)
print("Number of Feature: ", data_tr.num_features)
print("Number of target: ", data_tr.num_classes)

### 결과
# train set MNISTSuperpixels(60000)
# test set MNISTSuperpixels(10000)
# Number of Feature:  1
# Number of target:  10
# 10번째 train set 확인
print("sample", data_tr[10])
print("nodes", data_tr[10].num_nodes)
print("edges", data_tr[10].num_edges)

## 결과
# sample Data(x=[75, 1], edge_index=[2, 1337], y=[1], pos=[75, 2])
# nodes 75
# edges 1337

## 각 이미지(그래프)는 객체변수 x, edge_index, y, pos 존재.

GCN Model Structure

GCN Structure은 다음과 같다.

GCN에 대한 더 자세한 내용은 아래 블로그를 통해 확인하면 된다.
https://iy322.tistory.com/7

 

Graph Convolution Network(GCN)

1. CNN 간략하게 설명 Convolution이란? Convolution Filter(kernel)를 이용하여 image를 순회하는 dot product를 계산하고, 기존의 데이터로부터 filter를 이용하여 새로운 tensor(=activation map)을 만드는 것. filter의

iy322.tistory.com

channel_size = 64

class GCN_MODEL(torch.nn.Module):
    
    ## build model
       def __init__(self):
            
            super(GCN_MODEL, self).__init__()
            
            ## input channels, output channels, (add_self_loop = True, normalize = True) ->  symmetric normalzied adjacency matrix.
            self.conv1 = GCNConv(data_tr.num_features, channel_size) 
            self.conv2 = GCNConv(channel_size, channel_size*2)
            self.conv3 = GCNConv(channel_size*2, channel_size)
            self.conv4 = GCNConv(channel_size, channel_size)
            
            self.fc = Linear(channel_size*2, data_tr.num_classes)
    
    ##  model 
       def forward(self, x, edge_index, batch_index):
            ## input data  : node features, edge indices(adjacency matrix)
            ## output : node features
            input = self.conv1(x, edge_index)
            input = F.relu(input)
            
            hidden1 = self.conv2(input, edge_index)
            hidden1 = F.relu(hidden1)
            
            hidden2 = self.conv3(hidden1, edge_index)
            hidden2 = F.relu(hidden2)
            
            hidden3 = self.conv4(hidden2, edge_index)
            hidden3 = F.relu(hidden3)
            
            ## pooling
            hidden3 = torch.cat([gmp(hidden3, batch_index),
                                gap(hidden3, batch_index)], dim=1)
            
            output = self.fc(hidden3)
            
            return output, hidden3
model = GCN_MODEL()
print(model)
print("Number of parameters: ", sum(p.numel() for p in model.parameters()))

### 결과
# GCN_MODEL(
#   (conv1): GCNConv(1, 64)
#   (conv2): GCNConv(64, 64)
#   (conv3): GCNConv(64, 64)
#   (conv4): GCNConv(64, 64)
#   (fc): Linear(in_features=128, out_features=10, bias=True)
# )
# Number of parameters:  13898

 

Training

data_size = len(data_tr)

## Cross EntropyLoss
## Optimizer : Adam
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)


batch_size = 64
train_loader = DataLoader(data_tr, 
                    batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data_te, 
                         batch_size=batch_size, shuffle=True)
def train(data):
    
    for batch in train_loader:
        
        batch.to(device)
        
        optimizer.zero_grad()
        
        pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
        
        loss = torch.sqrt(loss_fn(pred, batch.y))
        loss.backward()
        
        optimizer.step()
        
        return loss, embedding
losses = []
for epoch in range(500):
    loss, h = train(data_tr)
    losses.append(loss)
    print(f"Epoch {epoch} | Train Loss {loss}")



# Visualize learning (training loss)
import seaborn as sns
value = [float(loss.cpu().detach().numpy()) for loss in losses] 
index = [i for i,l in enumerate(value)] 
vis = sns.lineplot(index, value)
vis

Train Loss 값을 시각화해서 확인하면 다음과 같다.

Test set -> Predict

import pandas as pd 
test_batch = next(iter(test_loader))

with torch.no_grad():

    num_total_data = 0
    correct = 0
    
    test_batch.to(device)
    pred, embed = model(test_batch.x.float(), test_batch.edge_index, test_batch.batch) 
    pred = torch.nn.functional.softmax(pred, dim=1)
    pred=torch.argmax(pred,dim=1)

첫번째 batch에서 정확도(accuracy)를 확인해 본 결과, CNN만큼 예측력이 좋은 거 같지는 않다. 코드를 더 보완해야 할 필요성이 있어 계속해서 공부해 가면서 보완해 볼 예정이다.

Node와 Edge를 통한 Data Visualization

마지막으로, node와 edge를 통해 graph 형성 후, 이미지 데이터를 시각화 해보았다.
시각화하는데는 spektral에서 제공해주는 package를 이용하였다.

node와 edge를 통해 graph를 형성한 후, node 100개만 뽑아서 확인해 본 결과 다음과 같다. 노드 간의 연결 여부는 자기 자신 포함해서 가장 가까운 거리에 있는 노드 8개와 연결되도록 만들어줬다.


52번 노드만 따로 빼서 확인해 보면, (23번, 24번, 25번, 51번, 53번, 79번, 80번, 81번) node와 연결된 것을 확인할 수 있다.


전체 node를 이용해 숫자 이미지를 시각화해 보았다.
숫자 3의 실제 이미지와 graph로 만든 이미지를 확인해 보면 다음과 같다.


전체 숫자 이미지를 다 보면 다음과 같다.


pytorch geometric에서 제공해주는 데이터셋은 mnist 말고도 더 많다. 틈틈히 하나하나 코드 짜보고 연습해 보면서 GCN에 대해 더 공부해 보고자 한다. dataset을 직접 graph로 구조화 시키고 gcn model로 돌릴 때까지 화이팅!


[자료] pytorch geometric : https://pytorch-geometric.readthedocs.io/en/latest/index.html