CS224W - (7) Theory of Graph Neural Networks
GNN의 표현력은 어느 정도일까?
The graph assumes all the nodes share the same feature.
그래프에서 노드의 색상이 피처를 나타낸다고 가정해봅시다. 위 그림의 그래프는 모든 노드가 같은 피처를 갖고 있습니다. GNN은 이런 그래프에서 서로 다른 그래프 구조를 얼마나 잘 구분할까요?
우선 그래프에서 각 노드 주변의 국소적인 이웃 구조를 고려합니다.
- 노드 1과 노드 5는 서로 다른 node degree를 갖고 있기 때문에 다른 이웃 구조를 갖고 있음을 알 수 있습니다.
- 노드 1과 노드 4는 같은 node degree를 갖고 있습니다. 그럼에도 두 노드는 다른 이웃 구조를 갖고 있습니다. 두 노드의 이웃이 서로 다른 node degree를 갖고 있기 때문입니다.
- 노드 1과 노드 2는 동일한 이웃 구조를 갖고 있습니다. 두 노드는 그래프에서 대칭이기 때문입니다.
연산 그래프
서로 다른 그래프 구조를 잘 구분하기 위에서는 위의 많은 경우의 수를 모두 구분해줘야 합니다. GNN은 이런 국소적인 이웃 구조를 잘 잡아낼 수 있는데 그 원동력은 바로 연산 그래프(computational graph) 입니다.
Computational graph for node 1
GNN은 각 레이어에서 이웃 노드의 임베딩을 집계합니다. GNN은 이웃에 의해 정이된 연산 그래프를 통해 노드 임베딩을 생성하죠. 아까 그래프에서 노드 1에 대한 연산 그래프는 위 그림에서 오른쪽과 같습니다. 그런데 노드 2에 대한 연산 그래프는 어떻게 생겼을까요?
Computational graph for node 2
노드 1과 노드 2는 대칭인데다 모든 노드가 동일한 피처를 갖고 있기 때문에 동일한 연산 그래프가 나옵니다. GNN은 노드 번호를 신경쓰지 않으므로 GNN은 노드 1과 노드 2에 대해 동일한 노드 임베딩을 생성합니다. 즉 GNN은 노드 1과 노드 2를 구분하지 못합니다.
Computational graphs are identical to rooted subtree structures around each node.
이처럼 GNN의 노드 임베딩은 루트 서브트리 구조(rooted subtree structures) 를 잡아냅니다. 따라서 가장 표현력이 높은 GNN은 서로 다른 루트 서브트리 구조를 서로 다른 노드 임베딩에 매핑하게 됩니다. 이걸 수학적으로 표현하면 GNN의 매핑이 injective할 때 가장 표현력이 높다고 할 수 있습니다.
Subtrees of the same depth can be recursively characterized from the leaf nodes to the root nodes.
위 그림에서 알 수 있듯이 동일한 깊이의 서브트리는 리프 노드에서부터 루트 노드로, 즉 아래에서 위 방향으로 재귀적인 방식을 통해 특성화할 수 있습니다. 만약 GNN의 각 집계 단계에서 이웃 정보를 완벽하게 유지할 수 있다면 생성되는 노드 임베딩이 서로 다른 루트 서브트리를 구분할 수 있습니다. 다시 말해서 표현력이 뛰어난 GNN은 각 단계에서 injective한 이웃 집계 함수를 사용합니다.
가장 강력한 GNN 설계하기
GNN의 표현력은 GNN이 사용하는 이웃 집계 함수에 의해서 결정됩니다. 다시 말해서 표현력 높은 이웃 집계 함수를 사용할 수록 GNN의 표현력은 높아집니다. 위에서 언급했듯 그 집계 함수가 injective하면 되는데요. 이 부분에 대해서 보다 이론적으로 알아보도록 하겠습니다.
이웃 집계 함수는 멀티셋에 대한 함수(a function over a multi-set) 로 생각할 수 있습니다. 멀티셋이란 반복되는 원소들로 구성된 집합입니다. 이 관점에서 GCN과 GraphSAGE를 살펴보겠습니다.
GCN
GCN은 이웃 집계 함수로 평균 풀링(mean-pool)을 사용하고 그 다음 선형 함수와 ReLU 활성화 함수를 사용합니다.
Theorem [Xu et al. ICLR 2019]
GCN의 집계 함수는 동일한 색깔 비율을 가진 멀티셋을 구분하지 못합니다.
위 그림과 같은 상황에서 GCN의 집계 함수인 평균 풀링은 두 멀티셋을 구분할 수 없습니다. 평균은 즉 비율이기 때문에 같은 비율을 갖고 있는 멀티셋은 같은 평균값을 반환하기 때문입니다.
GraphSAGE
GraphSAGE 는 MLP를 적용한 다음 맥스 풀링(max-pool)을 적용합니다.
Theorem [Xu et al. ICLR 2019]
GraphSAGE의 집계 함수는 동일한 고유 색상으로 구성된 서로 다른 멀티셋을 구분하지 못합니다.
위 그림과 같이 모든 멀티셋이 동일한 고유 색상을 갖고 있는 경우 맥스 풀링을 적용할 때 동일한 결과를 반환합니다.
최상의 표현력을 가진 GNN 디자인하기
이처럼 두 모델의 이웃 집계 함수는 모두 injective하지 않습니다. 따라서 GCN과 GraphSAGE는 모두 최상의 표현력을 가지지 못하죠.
Theorem [Xu et al. ICLR 2019]
\[\Phi\left( \sum_{x \in S} f(x) \right)\]
Any injective multi-set function can be expressed aswhere $\Phi$ and $f$ are non-linear functions.
위 정리에 대한 증명을 직관적으로 보기 위해서 $f$가 색깔에 대한 원 핫 인코딩을 생성한다고 가정하겠습니다. 이러한 원 핫 인코딩의 합은 입력된 멀티셋에 대한 모든 정보를 유지합니다.
최상의 표현력을 가진 GNN을 디자인하기 위해서 이웃 집계 함수로 뉴럴 네트워크를 써보면 어떨까요?
Theorem (Universal Approximation Theorem)
적절한 비선형성(ReLU or 시그모이드)을 가진 충분히 큰 차원의 1개 히든 레이어의 MLP는 모든 연속 함수를 임의의 정확도로 근사할 수 있다.
위 정리를 이용하여 $\Phi$와 $f$를 정할 수 있습니다.
\[\text{MLP}_\Phi \left( \sum_{x \in S} \text{MLP}_f(x) \right)\]이때 히든 레이어의 차원은 보통 100에서 500 정도면 충분합니다. 위와 같은 집계 함수를 사용하는 GNN이 바로 Graph Isomorphism Network(GIN) 입니다. 위 이웃 집계 함수는 injective하기 때문에 GCN이나 GraphSAGE처럼 실패하는 경우 없이 최상의 표현력을 가집니다.
Graph Isomorphism Network (GIN)
이제 GIN에 대해서 더 자세히 알아보도록 하겠습니다. GIN은 이전 포스트들에서 언급했던 Weisfeiler-Lehman Kernel과 큰 연관이 있습니다. WL 커널은 color refinement algorithm을 활용하고, 다음 식을 활용합니다.
\[c^{(k+1)}(v) = \text{HASH} \left( \left\{ c^{(k)}(v), \left\{ c^{(k)}(u) \right\}_{u \in N(v)} \right\} \right).\]GIN은 신경망을 사용해 위의 injective한 해시 함수를 모델링합니다. 위 식에서 $c^{(k)}(v)$가 로트 노드의 피처, ${ c^{(k)}(u) }_{u \in N(v)}$가 이웃 노드의 색이 됩니다.
Theorem [Xu et al. ICLR 2019]
\[\left(c^{(k)}(v), \{ c^{(k)}(u) \}_{u \in N(v)} \right)\]
Any injective function over the tuplecan be modeled as
\[\text{MLP}_\Phi \left( (1 + \epsilon) \cdot \text{MLP}_f (c^{(k)}(v)) + \sum_{u \in N(v)} \text{MLP}_f (c^{(k)}(u)) \right).\]
위에서 $f$는 원 핫 인코딩을 하는 함수라고 하였으니 모두 더하는 합 연산은 injective합니다. 이제 $\text{MLP}_\Phi$는 원 핫 입력 피처를 다음 레이어에 제공하는 역할을 합니다.
결국 GIN은 다음의 식을 이용해 노드 벡터를 반복적으로 업데이트합니다.
\[c^{(k+1)}(v) = \text{GINConv} \left( \left\{ c^{(k)}(v), \left\{ c^{(k)}(u) \right\}_{u \in N(v)} \right\} \right).\]여기서 $\text{GINConv}$는 서로 다른 입력을 다른 임베딩으로 매핑해주는 미분 가능한 해시 함수가 됩니다. GIN을 $K$번 반복하면 $c^{(k)}(v)$가 $K$-hop neighborhood의 구조를 요약하게 됩니다.
지금까지의 내용을 정리하면 GIN은 WL 커널의 미분 가능한 뉴럴 네트워크 버전이라고 할 수 있습니다. 추가로 GIN이 WL 그래프 커널 대비 갖는 장점도 있습니다.
- 노드 임베딩은 저차원이므로 다른 노드와의 유사도를 계산하기에 용이합니다.
- MLP가 학습 가능한 파라미터로 구성되어 있기 때문에 다운스트림 태스크에 맞춰 파인 튜닝도 가능합니다.
GIN과 WL 커널의 관계로 인해 이 두 가지의 표현력은 동일한 수준입니다. 만약 GIN을 이용해 두 그래프를 구분할 수 있다면 역시 WL 커널로도 구분할 수 있게 됩니다. 반대도 동일하구요. WL 커널은 1992년에 이론적으로나 경험적으로나 실세계의 대부분 그래프를 구분할 수 있다고 증명되었습니다. 따라서 GIN도 역시 대부분의 그래프를 구분한다고 할 수 있죠.
일이 계획대로 진행되지 않을 때
일반적인 팁
- 데이터 전처리가 매우 중요합니다.
- 노드 어트리뷰트는 매우 다양할 수 있기 때문에 이런 경우 정규화를 실시합니다.
- 옵티마이저로는 ADAM이 learning rate에 대해 강건한 모습을 보입니다.
- 활성화 함수
- ReLU가 보통은 잘 작동합니다.
- LeakyReLU나 PReLU가 좋은 대안이 될 수 있습니다.
- 마지막 아웃풋 레이어에는 활성화 함수를 넣지 않습니다.
- 모든 레이어에 bias term을 넣습니다.
- 임베딩 차원
- 32, 64, 128이 좋은 시작점이 됩니다.
네트워크 디버깅
- 손실 함수나 정확도가 학습 중 수렴하지 않는다면
- 파이프라인을 확인해보세요.
- Learning rate 같은 하이퍼파라미터를 조정해보세요.
- 가중치 파라미터의 초기화에 주의해주세요.
- 손실 함수를 잘 살펴보세요.
- 모델 개발에서 중요한 점
- 학습 데이터에 대한 오버피팅
- 학습/검증 손실 함수 값에 대한 모니터링