Post

CS224W - (4) Graph Neural Networks



CS224W - (4) Graph Neural Networks

들어가며

본 포스트에서는 그래프 뉴럴 네트워크(GNN) 을 기반에 둔 딥러닝 기법에 대해서 다룹니다. 지난번 포스트에서 다루었던 인코더로 GNN을 표현하자면 그래프 구조에 기반한 비선형 변환의 레이어를 겹겹이 쌓아둔 것이라고 할 수 있습니다. 이러한 구조를 이용해 노드 분류, 링크 예측, 커뮤니티 검출, 네트워크 유사도 등 다양한 다운스트림 태스크를 수행할 수 있습니다.

1장에서도 언급했듯 현대 딥러닝은 순서가 있는 데이터나 그리드가 있는, 즉 이미지나 텍스트를 처리하는 데에 최적화되어 있습니다. 하지만 네트워크처럼 임의의 크기를 갖고 있거나 복잡한 위상적 구조를 가진 경우 딥러닝으로 문제를 해결하기 어려웠습니다. 노드 순서가 고정되어 있지도 않고, 종종 동적인 멀티모달(multimodal) 피처를 사용하는 경우도 있으니까요.

그래프를 위한 딥러닝

몇 가지 설정을 하고 시작하겠습니다.

  • 그래프 $G$
  • 노드 집합 $V$
  • 인접 행렬 $A$는 이진 행렬로 구성
  • 노드 피처는 $|V| \times d$ 차원을 가진 행렬 $X \in \mathbb{R}^{|V| \times d}$
  • $v \in V$에 대해서 $v$의 이웃 집합을 $N(v)$로 정의

여기서 노드 피처는 도메인에 따라 다양하게 나타낼 수 있습니다. 만약 소셜 네트워크라면 유저 프로파일이나 유저 이미지가 될 수 있고, 생물학적 네트워크라면 유전자 표현형 프로파일, 유전 기능 정보 등이 될 수 있습니다.

A naïve approach

그래프를 입력값으로 하는 가장 단순한 접근법은 인접 행렬과 피처를 합친 다음 바로 DNN에 집어 넣는 것입니다. 하지만 당연히 여러 문제점이 있죠. 우선 파라미터 수가 $O(|V|)$가 되기 때문에 노드가 커지면 파라미터 수가 선형적으로 증가합니다. 그리고 서로 다른 크기의 그래프에는 적용할 수 없습니다. 그리고 노드 순서에도 예민한 네트워크가 되죠.

그래서 그래프를 일반적인 DNN이 아닌 CNN에 넣는 시도를 했습니다. 그럼에도 불구하고 그래프에 컨볼루션(convolution)을 적용하는 데에는 그래프에 지역성(locality)나 슬라이딩 윈도우(sliding window) 같은 개념이 없다는 문제가 있었습니다. 그리고 가장 큰 문제는 그래프는 순열 분별성(permutation invariant) 을 갖고 있다는 점이었습니다.

순열 불변성

Permutation invariance

그래프에서 노드에는 정해진 순서란 없습니다. 따라서 같은 구조의 그래프라면 그래프 표현의 관점에서 무조건 같은 결과가 나와야하죠. 위처럼 같은 구조에 노드의 이름만 바뀐 경우라면 그래프에 대한 표현과 노드에 대한 표현은 항상 같아야합니다. 조금 더 엄밀하게 말하자면 다음과 같습니다. 우리가 인접 행렬 $A$와 노드 피처 $X$에 대해 그래프 $G = (A, X)$를 $\mathbb{R}^d$로 매핑하는 함수 $f$를 학습한다고 하면 다음이 성립해야 합니다.

\[f(A_1, X_1) = f(A_2, X_2)\]

이때, 만약 모든 순서 $i$, $j$에 대해서 $f(A_i, X_i) = f(A_j, X_j)$를 만족한다면 함수 $f$는 순열 불변 함수(permutation invariant function) 입니다.

모든 그래프 함수 $f: \mathbb{R}^{|V| \times m} \times \mathbb{R}^{|V| \times |V|} \to \mathbb{R}^d$가 있을 때 어떤 순열 $P$에 대해서 $f(A, X) = f(PAP^T, PX)$를 만족하면 $f$는 순열 불변성(permutation-invariant) 을 갖는다.

순열 등변성

Permutation equivariant

그래프 표현의 관점에서 함수 $f$는 순열 불변 함수였지만 노드 표현의 관점에선 조금 다른 이야기를 하게 됩니다. 만약 노드 임베딩을 할 때 두 그래프에서 같은 노드의 위치가 다르다면 $f$의 결과로 나오는 벡터에서도 다른 위치에 등장하게 됩니다. 즉 같은 위치에 있는 노드는 같은 위치의 벡터에 등장하게 되는거죠.

노드 표현의 관점에서 그래프 $G = (A, X)$를 $\mathbb{R}^{m \times d}$로 매핑하는 함수 $f$를 학습한다고 가정하겠습니다. 그래프에서 같은 위치에 있는 노드의 아웃풋 벡터가 어떤 순서 변경에도 바뀌지 않은 채 유지된다면 우리는 그 함수 $f$를 순열 등변 함수 (permuatation equivariant function) 라고 부릅니다.

모든 노드 함수 $f: \mathbb{R}^{|V| \times m} \times \mathbb{R}^{|V| \times |V|} \to \mathbb{R}^{|V| \times m}$가 있을 때 어떤 순열 $P$에 대해서 $Pf(A, X) = f(PAP^T, PX)$를 만족하면 $f$는 순열 등변성(permutation-equivariant) 을 갖는다.

순열 불변성과 순열 등변성은 비슷하여 헷갈리는 경우가 많습니다. 예시를 통해서 이해를 해보도록 하겠습니다.

  • $f(A, X) = 1^T X$: 순열 불변
    • $f(PAP^T, PX) = 1^T PX = 1^T X = f(A, X)$
  • $f(A, X) = X$: 순열 등변
    • $f(PAP^T, PX) = PX = Pf(A, X)$
  • $f(A, X) = AX$: 순열 등변
    • $f(PAP^T, PX) = PAP^T PX = PAX = P f(A, X)$

Permutation invariant and Permutation equivariant

이 두 가지 특성은 다른 일반적인 뉴럴 네트워크 아키텍처와는 다르게 그래프 뉴럴 네트워크에 모두 포함되어 있습니다. 일반적인 뉴럴 네트워크는 인풋값의 순서에 변화를 주면 결과값이 무조건 바뀌게 됩니다. 이런 이유로 위에서 가장 단순한 접근법인 뉴럴 네트워크 기반 접근법은 그래프에 적용할 때 실패할 수 밖에 없습니다.

이제 그래프 뉴럴 네트워크에 필요한 특성은 알아냈고, 어떤 방법을 통해 이 특성을 구현할지에 대해서 이야기 해보도록 하겠습니다.

그래프 컨볼루션 네트워크

Main idea for GCN

그래프 컨볼루션 네트워크(Graph Convolution Network) 의 기본 아이디어는 노드의 이웃이 연산 그래프를 정의한다는 점입니다. 노드의 이웃은 노드에 대한 연산 그래프를 정의하고 정보를 전파하고 변환하는데, 어떤 방법으로 그래프 전체에 정보를 전파하여 노드 피처를 계산할 수 있을까요?

Aggregating neighbors

Every node defines a computation graph based on its neighborhood

우선 각 이웃의 정보를 집계하는 것이 중요합니다. 이를 위해 국소적인 네트워크 이웃을 기반으로 노드 임베딩을 생성할 수 있습니다. 이때 노드는 뉴럴 네트워크를 이용해서 이웃 노드로부터 정보를 집계합니다. 그리고 네트워크의 이웃이 연산 그래프를 정의한다는 아이디어에서 모든 노드에 대해 위와 같은 연산 그래프를 정의할 수 있습니다. 이때 위 그림처럼 각각의 연산 그래프는 서로 다른 구조를 가지게 됩니다.

Overview of a model

이렇게 만든 모델은 임의의 깊이를 가지게 됩니다. 노드는 각 레이어마다 임베딩을 갖게 되고, 노드 $v$에 대한 레이어 0의 임베딩은 입력 피처인 $x_v$가 됩니다. 그리고 레이어 $k$의 임베딩은 해당 노드로부터 $k$ 만큼 떨어져 있는 노드로부터의 정보를 갖고 있습니다.

이제 이웃 노드로부터 정보를 집계해야 하는데 이때 각각의 연산 그래프가 서로 다른 구조를 갖고 있으므로 집계한 결과는 여러 레이어에 걸쳐서 노드의 정보를 모으게 됩니다. 여기서 이웃 노드의 정보를 집계해 평균을 계산하는 것이 가장 간단한 방법이 됩니다. 그리고 마지막엔 뉴럴 네트워크를 적용하게 되죠. 위 그림에서는 레이어 2 이전에서 이웃 노드로부터 온 메시지를 평균 계산하고 레이어 2에서 뉴럴 네트워크를 적용합니다. 이 내용을 수식으로 나타낸다면 복잡하지만 아래와 같습니다.

\[\begin{aligned} h_v^0 & = x_v \\ h_v^{(k+1)} &= \sigma \left( W_k \sum_{u \in N(v)} \frac{h_u^{(k)}}{|N(v)|} + B_k h_v^{(k)} \right), \quad \forall k \in \{ 0, \cdots, K-1 \} \\ z_v &= h_v^{(K)} \end{aligned}\]
  • $h_v^0 = x_v$
    • 0 번째 레이어의 임베딩은 노드 피처와 같은 값으로 초기화합니다.
  • $\sigma$
    • ReLU 같은 비선형 활성함수를 사용합니다.
  • $\sum_{u \in N(v)} \frac{h_u^{(k)}}{|N(v)|}$
    • 이웃 노드의 이전 레이어 임베딩의 평균
  • $h_v^{(k)}$
    • $k$ 번째 레이어 에서의 노드 $v$의 임베딩
  • $K$
    • 전체 레이어의 수
  • $z_v$
    • $k$ 번째 레이어에서의 이웃 집계 후 임베딩
  • $W_k$
    • 이웃 집계에 대한 가중치 행렬
  • $B_k$
    • 히든 벡터 변환을 위한 가중치 행렬
  • 위 식에서 합 연산은 모두 순열 불변성을 갖는 풀링(pooling)이나 집계 방식입니다.

이런 방식의 GCN에서 주어진 노드에 대해 임베딩을 계산하는 연산은 순열 불변성을 가집니다. 위에서 언급했듯 이전 레이어 임베딩의 평균을 계산하는 등의 합 연산은 모두 순열 불변성을 가지죠. 그리고 그래프의 모든 노드를 고려할 때 GCN 연산은 순열 등변성을 가집니다.

Permutation-equivariance of GCN

순열 등변성을 갖는 이유

  1. 입력 노드 피처와 아웃풋 임베딩의 결과가 정렬되어 있습니다.
  2. GCN에서 주어진 노드에 대한 임베딩 연산은 순열 불변성을 갖습니다.
  3. 따라서 순서를 섞은 후 주어진 노드의 입력 노드 피처 행렬은 바뀌고, 임베딩 연산의 결과 위치는 동일합니다.

모델 학습

지금까지 정의한 GCN을 학습하기 위해선 임베딩에 대한 손실 함수를 정의해야 합니다. 우선 위 수식에서 학습 가능한 가중치 행렬은 $W_k$와 $B_k$ 입니다. 이 가중치 행렬과 최종 노드 임베딩인 $h_v^{(K)}$를 손실 함수에 넣고 SGD를 수행해서 가중치 파라미터를 학습할 수 있습니다.

희소 행렬(Sparse matrix) 연산을 통해 많은 집계 연산들을 효율적으로 수행할 수 있습니다. 우선 이웃 노드의 이전 레이어 임베딩의 평균을 행렬로 나타내보도록 하겠습니다. $H^{(k)} = \left[h_1^{(k)}, \cdots, h_{|V|}^{(k)}\right]^T$ 라고 할 때 다음이 성립합니다.

\[\sum_{u \in N_v} h_u^{(k)} = A_{v, :} H^{k}\]

그리고 $D$를 다음과 같이 diagonal matrix로 정의하겠습니다.

\[D_{v, v} = \text{Deg}(v) = |N(v)|\]

그렇다면 이 행렬의 역행렬인 $D^{-1}$도 반드시 diagonal하고 다음이 성립합니다.

\[D_{v, v}^{-1} = \frac{1}{|N(v)|}\]

따라서 이웃 노드의 이전 레이어 임베딩의 평균은 행렬로 다음과 같이 나타낼 수 있습니다.

\[\sum_{u \in N(v)} \frac{h_u^{(k-1)}}{|N(v)|} \implies H^{(k+1)} = D^{-1} A H^{(k)}\]

위 내용을 활용하면 다음과 같은 식을 얻을 수 있습니다.

\[\begin{aligned} h_v^{(k+1)} = H^{(k+1)} &= \sigma \left( W_k \sum_{u \in N(v)} \frac{h_u^{(k)}}{|N(v)|} + B_k h_v^{(k)} \right) \\ &= \sigma \left( \tilde{A} H^{(k)} W_k^T + H^{(k)}B_k^T \right) \quad \text{where } \tilde{A} = D^{-1}A \end{aligned}\]

마지막 식에서 첫 항은 이웃 집계에 대한 식이고 두 번째 항은 자기 변환에 대한 식입니다. 실제 연산 시 $\tilde{A}$가 희소 행렬이기 때문에 행렬 곱 연산을 효율적으로 수행할 수 있습니다. 주의해야 할 사항은 모든 GNN이 위처럼 행렬 형태로 쓰이지 않을 수 있다는 점입니다. 만약 집계 함수가 복잡한 형태라면 이렇게 행렬 형태로 쓰지 못할 수 있습니다.

만약 GNN을 학습할 때 지도 학습이라면 노드 임베딩 $z_v$와 노드 라벨 $y$에 대해서 다음의 손실 함수를 최소화하면 됩니다.

\[\min_\Theta \mathcal{L}(y, f(z_v))\]

이때 손실함수 $\mathcal{L}$은 $y$가 실수라면 L2, $y$가 범주형이라면 크로스 엔트로피를 사용할 수 있습니다. 지도 학습의 스키마는 다소 단순한 반면 비지도 학습의 경우 사용 가능한 라벨이 없어 조금 복잡합니다. 이때 그래프 구조를 활용하여 손실 함수를 정의할 수 있습니다.

지금까지 계속해서 강조했던 내용은 유사한 노드는 유사한 임베딩을 가진다는 점입니다. 그래서 비지도 학습의 경우 다음과 같이 손실 함수를 정의합니다.

\[\mathcal{L} = \sum_{z_u, z_v} \text{CE} (y_{u, v}, \text{DEC}(z_u, z_v))\]
  • $y_{u, v}$ : 노드 유사도
    • 랜덤 워크, 행렬 분해 등의 노드 유사도를 사용합니다.
  • $\text{DEC}$ : 내적 같은 디코더

위 내용을 정리하여 모델을 디자인하는 방법을 요약해서 살펴보겠습니다.

  1. 이웃 집계 함수를 정의한다.
  2. 임베딩에 대한 손실 함수를 정의한다.
  3. 노드 집합, 즉 연산 그래프 배치에 대해 학습을 진행한다.
  4. 필요에 따라 노드에 대한 임베딩을 생성한다.

귀납 가능성(Inductive Capability)

Shared parameters in GNN

모든 노드에 대해서 동일한 집계 파라미터가 공유됩니다. 모델 파라미터의 수는 $|V|$에 대해 선형 이하이며, 확인하지 않은 노드로도 일반화가 가능합니다. 즉 하나의 그래프에 대해서 학습을 하면 새로운 그래프에 대해서도 일반화할 수 있습니다. 예를 들어 유기체 A에 대한 단백질 상호작용 그래프를 학습하고, 새로 수집한 유기체 B에 대한 데이터에 대해 임베딩을 생성할 수 있게 되는거죠. 대부분의 시나리오에서 확인하지 않은 노드는 계속 생성됩니다. 레딧이나 유튜브 등에서도 마찬가지죠.

GNN과 CNN의 비교

CNN with 3x3 filter

위와 같은 CNN을 수식으로 쓰면 다음과 같습니다.

\[h_v^{(l+1)} = \sigma \left( \sum_{u \in N(v) \cup \{ v \}} W_l^u h_u^{(l)} \right), \quad \forall l \in \{ 0 , \cdots, L-1 \}\]

이때 $N(v)$는 픽셀 $v$의 여덟 개의 이웃 픽셀이 됩니다.

A graph with 8 neighbors

그리고 위와 같은 그래프에 대한 GNN을 수식으로 나타내면 다음과 같습니다.

\[h_v^{(l+1)} = \sigma \left( W_l \sum_{u \in N(v)} \frac{h_u^{(l)}}{|N(v)|} + B_l h_v^{(l)} \right), \quad \forall l \in \{0, \cdots, L-1 \}\]

CNN의 수식을 GNN과의 비교를 위해 다시 써보겠습니다.

\[h_v^{(l+1)} = \sigma \left( \sum_{u \in N(v)} W_l^u h_u^{(l)} + B_l h_v^{(l)} \right), \quad \forall l \in \{ 0 , \cdots, L-1 \}\]

GNN이 가진 CNN과의 가장 큰 차이는 이미지의 픽셀 $v$가 주어졌을 때 서로 다른 이웃 $u$에 대해 서로 다른 $W_l^{u}$를 학습할 수 있다는 점입니다. 중앙 픽셀에 대한 상대 위치를 사용해 아홉 개의 이웃에 대한 순서를 마음대로 고를 수 있기 때문이죠.

GNN과 CNN의 차이는 다음과 같이 요약할 수 있습니다.

  • CNN은 고정된 이웃 크기와 순서를 가진 특수한 GNN으로 볼 수 있습니다.
    • 필터의 크기는 CNN에 의해 미리 정의됩니다.
    • GNN의 장점은 각 노드마다 다른 Degree를 가진 임의의 그래프를 처리한다는 점입니다.
  • CNN은 순열 불변하거나 순열 등변하지 않습니다.
    • 픽셀의 순서를 바꾸면 다른 결과를 반환합니다.


This post is licensed under CC BY 4.0 by the author.