Image Classification하면 대표적으로 생각나는 모델이 CNN(Convolution Neural Network)이다. 이미지 분류를 GCN(Graph Convolution Neural Network)을 통해서도 가능하며 이번에 가장 기본적으로 사용되는 MNIST Dataset을 GCN을 통해 이미지 분류를 하였다. pytorch에서 제공해주는 mnist dataset과 GCNConv layer를 이용하였다. 먼저, pytorch에서 제공해주는 GCNConv 알고리즘에 대해 공부한 내용을 정리해 보고자 한다.
Graph로 구조화된 mnist dataset과 GCNConv layer를 사용하기 위해서는 다음 package을 설치해줘야 한다.
!pip install --verbose --no-cache-dir torch-scatter
!pip install --verbose --no-cache-dir torch-sparse
!pip install --verbose --no-cache-dir torch-cluster
!pip install torch-geometric
!pip install tensorboardX
Graph Convolution Neural Network
pytorch에서 GCN Layer를 "Semi-supervised Classification with Graph Convolutional Networkx" 논문을 토대로 형성하였다고 한다. 이 논문과 관련하여 공부한 내용을 토대로 정리한 블로그이다.
https://iy322.tistory.com/12
GCN Layer를 형성하는데 필요한 library는 다음과 같다.
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, fill_diag, matmul, mul
from torch_sparse import sum as sparsesum
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes
GCN Model
관련 기호 정리는 다음과 같다.
- D : Degree Matrix
- A : Adjacency Matrix
- H : Hidden Layer
- W : Weight
Update 과정.
Adjacency Matrix 생성 함수 : gcn_norm
GCN의 가장 큰 특징인 adjacency matrix를 형성해야 한다.
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
add_self_loops=True, flow="source_to_target", dtype=None):
fill_value = 2.if improvedelse 1.
if isinstance(edge_index, SparseTensor):
assert flow in ["source_to_target"]
adj_t = edge_index
if not adj_t.has_value():
adj_t = adj_t.fill_value(1., dtype=dtype)
if add_self_loops:
adj_t = fill_diag(adj_t, fill_value)
deg = sparsesum(adj_t, dim=1)
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
return adj_t
else:
assert flow in ["source_to_target", "target_to_source"]
num_nodes = maybe_num_nodes(edge_index, num_nodes)
if edge_weight is None:
edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
device=edge_index.device)
if add_self_loops:
edge_index, tmp_edge_weight = add_remaining_self_loops(
edge_index, edge_weight, fill_value, num_nodes)
assert tmp_edge_weight is not None
edge_weight = tmp_edge_weight
row, col = edge_index[0], edge_index[1]
idx = colif flow == "source_to_target"else row
deg = scatter_add(edge_weight, idx, dim=0, dim_size=num_nodes)
deg_inv_sqrt = deg.pow_(-0.5)
deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
핵심적인 부분만 따로 설명을 하면 다음과 같다.
GCNConv Class
- Args
- in_channels(int)
- out_channels(int)
- improved
- cached
- add_self_loops(default=True)
- normalize(default=True) -> symmetric normalization adjacency matrix
- bias(default=True)
- shapes
- input: node_features(n_nodes*n_input_features) & edge_indices(2*n_edge_index) & edge_weight(n_edge_index)
- output : node_feature(n*nodes*n_output_features)
<객체변수 생성>
<Key Point>
위의 설명을 토대로 Graph Convolution Neural Network의 class는 다음과 같다.
class GCNConv(MessagePassing):
def __init__(self, in_channels: int, out_channels: int,
improved: bool = False, cached: bool = False,
add_self_loops: bool = True, normalize: bool = True,
bias: bool = True, **kwargs):
kwargs.setdefault('aggr', 'add')
super().__init__(**kwargs)
self.in_channels = in_channels
self.out_channels = out_channels
self.improved = improved
self.cached = cached
self.add_self_loops = add_self_loops
self.normalize = normalize
self._cached_edge_index = None
self._cached_adj_t = None
self.lin = Linear(in_channels, out_channels, bias=False,
weight_initializer='glorot')
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
zeros(self.bias)
self._cached_edge_index = None
self._cached_adj_t = None
def forward(self, x: Tensor, edge_index: Adj,
edge_weight: OptTensor = None)
if self.normalize:
if isinstance(edge_index, Tensor):
cache = self._cached_edge_index
if cache is None:
edge_index, edge_weight = gcn_norm( # yapf: disable
edge_index, edge_weight, x.size(self.node_dim),
self.improved, self.add_self_loops, self.flow, x.dtype)
if self.cached:
self._cached_edge_index = (edge_index, edge_weight)
else:
edge_index, edge_weight = cache[0], cache[1]
elif isinstance(edge_index, SparseTensor):
cache = self._cached_adj_t
if cache is None:
edge_index = gcn_norm( # yapf: disable
edge_index, edge_weight, x.size(self.node_dim),
self.improved, self.add_self_loops, self.flow, x.dtype)
if self.cached:
self._cached_adj_t = edge_index
else:
edge_index = cache
x = self.lin(x)
out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
size=None)
if self.bias is not None:
out = out + self.bias
return out
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
return matmul(adj_t, x, reduce=self.aggr)
Image Classification
이제, pytorch geometric에서 제공해주는 MNISTSuperpixels dataset을 토대로 GCN 모델을 이용해 이미지 분류 학습을 해보고자 한다. 이 데이터셋은 "Geometric Deep Learning on Graphs and Manifolds Using Mixture Model CNNs" 논문을 토대로 만들었다고 한다. 이 논문은 아직 읽어보지 않아서 시간 내서 읽고 공부해 볼 예정이다.
- 필요한 라이브러리
import torch
import torch.nn.functional as F
from torch.nn import Linear, ReLU
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.data import DataLoader
import numpy as np
Dataset 불러오기
data_tr = MNISTSuperpixels(root='./', train=True) ## train set
data_te = MNISTSuperpixels(root='./', train=False) ## test set
MNISTSuperpixels의 데이터 구조는 다음과 같다.
총 이미지의 개수는 70,000개로 각 이미지마다 75개의 node와 최대 1,393개의 edge를 통해 graph를 형성하였다.
즉, 총 70,000개의 graph가 존재한다. 모든 graph는 10개의 class 중 하나로 labeling 되어있다.
print("train set", data_tr)
print("test set", data_te)
print("Number of Feature: ", data_tr.num_features)
print("Number of target: ", data_tr.num_classes)
### 결과
# train set MNISTSuperpixels(60000)
# test set MNISTSuperpixels(10000)
# Number of Feature: 1
# Number of target: 10
# 10번째 train set 확인
print("sample", data_tr[10])
print("nodes", data_tr[10].num_nodes)
print("edges", data_tr[10].num_edges)
## 결과
# sample Data(x=[75, 1], edge_index=[2, 1337], y=[1], pos=[75, 2])
# nodes 75
# edges 1337
## 각 이미지(그래프)는 객체변수 x, edge_index, y, pos 존재.
GCN Model Structure
GCN Structure은 다음과 같다.
GCN에 대한 더 자세한 내용은 아래 블로그를 통해 확인하면 된다.
https://iy322.tistory.com/7
channel_size = 64
class GCN_MODEL(torch.nn.Module):
## build model
def __init__(self):
super(GCN_MODEL, self).__init__()
## input channels, output channels, (add_self_loop = True, normalize = True) -> symmetric normalzied adjacency matrix.
self.conv1 = GCNConv(data_tr.num_features, channel_size)
self.conv2 = GCNConv(channel_size, channel_size*2)
self.conv3 = GCNConv(channel_size*2, channel_size)
self.conv4 = GCNConv(channel_size, channel_size)
self.fc = Linear(channel_size*2, data_tr.num_classes)
## model
def forward(self, x, edge_index, batch_index):
## input data : node features, edge indices(adjacency matrix)
## output : node features
input = self.conv1(x, edge_index)
input = F.relu(input)
hidden1 = self.conv2(input, edge_index)
hidden1 = F.relu(hidden1)
hidden2 = self.conv3(hidden1, edge_index)
hidden2 = F.relu(hidden2)
hidden3 = self.conv4(hidden2, edge_index)
hidden3 = F.relu(hidden3)
## pooling
hidden3 = torch.cat([gmp(hidden3, batch_index),
gap(hidden3, batch_index)], dim=1)
output = self.fc(hidden3)
return output, hidden3
model = GCN_MODEL()
print(model)
print("Number of parameters: ", sum(p.numel() for p in model.parameters()))
### 결과
# GCN_MODEL(
# (conv1): GCNConv(1, 64)
# (conv2): GCNConv(64, 64)
# (conv3): GCNConv(64, 64)
# (conv4): GCNConv(64, 64)
# (fc): Linear(in_features=128, out_features=10, bias=True)
# )
# Number of parameters: 13898
Training
data_size = len(data_tr)
## Cross EntropyLoss
## Optimizer : Adam
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0007)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
batch_size = 64
train_loader = DataLoader(data_tr,
batch_size=batch_size, shuffle=True)
test_loader = DataLoader(data_te,
batch_size=batch_size, shuffle=True)
def train(data):
for batch in train_loader:
batch.to(device)
optimizer.zero_grad()
pred, embedding = model(batch.x.float(), batch.edge_index, batch.batch)
loss = torch.sqrt(loss_fn(pred, batch.y))
loss.backward()
optimizer.step()
return loss, embedding
losses = []
for epoch in range(500):
loss, h = train(data_tr)
losses.append(loss)
print(f"Epoch {epoch} | Train Loss {loss}")
# Visualize learning (training loss)
import seaborn as sns
value = [float(loss.cpu().detach().numpy()) for loss in losses]
index = [i for i,l in enumerate(value)]
vis = sns.lineplot(index, value)
vis
Train Loss 값을 시각화해서 확인하면 다음과 같다.
Test set -> Predict
import pandas as pd
test_batch = next(iter(test_loader))
with torch.no_grad():
num_total_data = 0
correct = 0
test_batch.to(device)
pred, embed = model(test_batch.x.float(), test_batch.edge_index, test_batch.batch)
pred = torch.nn.functional.softmax(pred, dim=1)
pred=torch.argmax(pred,dim=1)
첫번째 batch에서 정확도(accuracy)를 확인해 본 결과, CNN만큼 예측력이 좋은 거 같지는 않다. 코드를 더 보완해야 할 필요성이 있어 계속해서 공부해 가면서 보완해 볼 예정이다.
Node와 Edge를 통한 Data Visualization
마지막으로, node와 edge를 통해 graph 형성 후, 이미지 데이터를 시각화 해보았다.
시각화하는데는 spektral에서 제공해주는 package를 이용하였다.
node와 edge를 통해 graph를 형성한 후, node 100개만 뽑아서 확인해 본 결과 다음과 같다. 노드 간의 연결 여부는 자기 자신 포함해서 가장 가까운 거리에 있는 노드 8개와 연결되도록 만들어줬다.
52번 노드만 따로 빼서 확인해 보면, (23번, 24번, 25번, 51번, 53번, 79번, 80번, 81번) node와 연결된 것을 확인할 수 있다.
전체 node를 이용해 숫자 이미지를 시각화해 보았다.
숫자 3의 실제 이미지와 graph로 만든 이미지를 확인해 보면 다음과 같다.
전체 숫자 이미지를 다 보면 다음과 같다.
pytorch geometric에서 제공해주는 데이터셋은 mnist 말고도 더 많다. 틈틈히 하나하나 코드 짜보고 연습해 보면서 GCN에 대해 더 공부해 보고자 한다. dataset을 직접 graph로 구조화 시키고 gcn model로 돌릴 때까지 화이팅!
[자료] pytorch geometric : https://pytorch-geometric.readthedocs.io/en/latest/index.html