Post

CS224W - (14) Deep Generative Models for Graphs



CS224W - (14) Deep Generative Models for Graphs

들어가며

우리는 지금까지 그래프를 학습하는 방법에 대해서 알아보았습니다. 그러면 이런 그래프는 어떻게 생성되었을까요?

이제는 그래프 생성 모델(graph generative models) 을 이용해서 진짜 같은 그래프를 생성하고자 합니다. 그래프 생성 모델은 신약 개발, 소셜 네트워크 모델링 등에 많이 사용됩니다. 우리가 그래프 생성 모델에 대해 배워야 하는 이유는 다음과 같습니다.

  • 그래프에 대한 인사이트를 얻을 수 있음
  • 그래프가 앞으로 어떻게 바뀌어 나갈지 예측할 수 있음
  • 일반적인 그래프에도 시뮬레이션을 수행해볼 수 있음
  • 주어진 그래프가 정상인지 비정상인지 확인할 수 있음

그래프 생성 분야는 아래와 같은 과정을 통해 발전해 왔습니다.

  1. 실세계 그래프의 특성 이해
  2. 전통적인 그래프 생성 모델
  3. 딥러닝을 활용한 그래프 생성 모델

지금까지는 그래프를 임베딩하여 인코딩하는 과정에 대해 다루었다면 이번 포스트는 임베딩으로부터 그래프를 생성하는 디코딩 문제를 다루게 됩니다.

그래프 생성을 위한 머신러닝

그래프 생성 태스크는 크게 두 가지로 나뉩니다.

  1. 실제 같은 그래프 생성
    • 주어진 그래프와 유사한 형태의 그래프를 생성합니다.
  2. 목표 지향적 그래프 생성
    • 주어진 목적과 제약 조건에 맞는 그래프를 생성합니다.
    • 신약 분자 생성/최적화에서 많이 사용합니다.

우선 그래프로부터 생성 모델을 학습하는 것이 목표입니다. $p_\text{data}$는 우리가 잘 알지 못하는 데이터 분포이지만 여기에서 $x_i \sim p_\text{data}(\mathbf{x})$를 샘플링할 수 있습니다. $\theta$로 파라미터화된 모델 $p_\text{model}(\mathbf{x}; \theta)$이 있을 때 우리는 이 모델을 $p_\text{data}(\mathbf{x})$에 근사시키고자 합니다. 우리의 목표는 크게 두 개입니다.

  1. 밀도 추정(Density estimation)
    • $p_\text{model}(\mathbf{x}; \theta)$를 $p_\text{data}(\mathbf{x})$에 가깝도록 만들어야 합니다.
  2. 샘플링(Sampling)
    • $p_\text{model}(\mathbf{x}; \theta)$를 샘플링하여 그래프를 생성할 수 있어야 합니다.

우선 밀도 추정을 위해선 Maximum Likelihood Estimation을 사용해야 합니다.

\[\theta^\ast = \arg\max_\theta \mathbb{E}_{\mathbf{x} \sim p_\text{data}} \log p_\text{model}(\mathbf{x} \mid \theta)\]

관측된 데이터 $\mathbf{x}_i \sim p_\text{data}$에 대해 가능한 모든 파라미터 $\theta$ 중에서 $\sum_i \log p_\text{model}(\mathbf{x}_i, \theta^ast)$가 가장 높은 값을 갖도록 하는 $\theta^\ast$ 를 찾습니다. 다시 말해서 관측된 데이터를 생성했을 가능성이 가장 높은 모델을 찾습니다.

두 번째로 모델 $p_\text{model}(\mathbf{x}; \theta)$로부터 샘플링하기 위해선 일반적으로 정규 분포에서 간단한 노이즈를 샘플링합니다.

\[\mathbf{z}_i \sim N(0, 1)\]

그 다음 어떤 함수 $f(\cdot)$으로 노이즈를 변환합니다.

\[\mathbf{x}_i = f(\mathbf{z}_i; \theta)\]

이때 $\mathbf{x}_i$는 복잡한 분포를 따릅니다. 그리고 $f(\cdot)$은 deep neural network를 사용해서 학습합니다.

Auto-regressive model은 밀도 추정과 샘플링을 모두 수행합니다. 이때 Chain rule을 사용하는데요. 결합 분포(joint distribution)은 조건부 분포(conditional distributions)의 곱이 됩니다.

\[p_\text{model}(\mathbf{x}; \theta) = \prod_{t=1}^n p_\text{model}(x_t \mid x_1, \cdots, x_{t-1}; \theta)\]

여기서 $\mathbf{x}$가 벡터라면 $x_t$는 $t$ 번째 차원이 되고, $\mathbf{x}$가 만약 문장이라면 $x_t$는 $t$ 번째 단어가 됩니다. 그래프의 경우에는 $x_t$는 $t$ 번째 행동이 됩니다. 행동이란 노드를 추가한다거나 엣지를 추가하는 행위를 일컫습니다.

GraphRNN: 진짜같은 그래프 생성하기

Main idea of GraphRNN

GraphRNN은 위 그림처럼 노드와 엣지를 순차적으로 추가하여 그래프를 생성합니다.

노드 순서 $\pi$를 가진 그래프 $G$는 노드와 에지를 추가하는 시퀀스 $S^\pi$로 고유하게 매핑할 수 있습니다. 여기서 $S^\pi$는 노드 레벨과 엣지 레벨로 나뉩니다. 노드 레벨은 한 번에 하나의 노드를 추가합니다. 엣지 레벨은 존재하는 노드 사이에 엣지를 추가합니다. 그래프에 노드 오더링을 적용했을 때 시퀀스의 시퀀스가 나옵니다. 이때 노드 순서는 임의로 선택됩니다.

인접 행렬 관점에서 행은 노드 레벨의 시퀀스, 열은 엣지 레벨의 시퀀스를 의미합니다.

지금까지 우리는 그래프 생성 문제를 시퀀스 생성 문제로 변환했습니다. 이제 모델은 두 개의 과정이 필요합니다.

  1. 새로운 노드의 상태를 생성하는 것 (노드 레벨 시퀀스)
  2. 새로운 노드의 상태에 기반하여 노드의 엣지를 생성하는 것 (엣지 레벨 시퀀스)

이 과정을 수행하기 위해 Recurrent Neural Network (RNN) 을 사용합니다. RNN은 시퀀셜 데이터에 대해 사용합니다. 입력 시퀀스를 받아 hidden state를 업데이트하고, 이때 hidden state는 RNN의 입력값의 모든 정보를 요약하여 갖고 있습니다. 이 업데이트는 RNN 셀에서 수행됩니다.

An RNN cell

  • $s_t$: $t$ 스텝 후 RNN의 state
  • $x_t$: $t$ 스텝에서의 입력값
  • $y_t$: $t$ 스텝에서의 아웃풋

RNN 셀은 다음의 순서로 업데이트를 수행합니다.

  1. Hidden state 업데이트
    • $s_t = \sigma(W \cdot x_t + U \cdot s_{t-1})$
  2. 예측 반환
    • $y_t = V \cdot s_t$

An overview of GraphRNN

GraphRNN은 노드 레벨의 RNN과 엣지 레벨의 RNN이 있습니다. 노드 레벨의 RNN에서 엣지 레벨의 RNN을 위한 initial state를 생성하고 엣지 레벨의 RNN에서 새로운 노드가 이전 노드와 연결될 것인지 예측하는 2단계의 수행 절차는 기존 RNN과 비슷한 점이 있습니다.

RNN for sequence generation

RNN과 유사하게 시퀀스를 생성할 때 이전의 아웃풋을 그 다음 스텝의 입력값으로 사용합니다. 그리고 최초 입력 시퀀스의 초기화는 start of sequence token (SOS) 를 사용합니다. 이때 SOS는 모든 값이 0이나 1로 이루어진 벡터를 사용합니다. 그리고 시퀀스 생성의 중단은 end of sequence token (EOS) 를 사용합니다. EOS는 RNN의 아웃풋으로 나타나는데 만약 이 값이 0이면 생성을 계속 진행하고 1이면 생성을 중단합니다. 이 방법이 나쁜건 아니지만 모델이 너무 결정론적(deterministic) 입니다.

우리의 목적은 다음의 식을 모델링하는 것이었습니다.

\[p_\text{model}(\mathbf{x}; \theta) = \prod_{t=1}^n p_\text{model}(x_t \mid x_1, \cdots, x_{t-1}; \theta)\]

$y_t = p_\text{model}(x_t \mid x_1, \cdots, x_{t-1}; \theta)$ 라고 둘 때 우리는 $x_{t-1}$를 $y_t: x_{t+1} \sim y_t$로부터 샘플링을 해야 합니다. 그렇게 되면 RNN 각 셀에서의 결과값은 엣지에 대한 확률값이 되고 위 그림과 같은 형태가 됩니다.

Teacher forcing

학습 시에는 teacher forcing을 사용하여 위와 같이 RNN 셀의 입력값과 아웃풋을 실제값으로 바꿔줍니다. 그러면 전 단계에서 잘못된 값을 보내주더라도 올바른 값을 넣어줘서 잘못 학습되는 상황을 막습니다. 손실 함수는 binary cross entropy를 사용하여 다음의 식을 최소화하도록 합니다.

\[\mathcal{L} = - \left( y_1^\ast \log(y_1) + (1 - y^\ast) \log(1 - y_1) \right)\]

지금까지의 내용을 요약하자면 다음과 같습니다.

  1. 노드 레벨의 RNN을 실행하고 그 아웃풋을 이용하여 엣지 레벨의 RNN을 초기화합니다.
  2. 엣지 레벨의 RNN을 실행하여 이전 노드와 새로운 노드의 연결 여부를 예측합니다.
  3. 엣지 레벨의 RNN의 마지막 hidden state를 사용하여 다음 단계의 노드 레벨의 RNN을 실행합니다.
  4. 만약 엣지 레벨의 RNN의 아웃풋이 EOS라면 그래프 생성을 중단합니다.

그래프 생성 규모 키우기 및 평가하기

그래프 생성 규모 키우기

GraphRNN에선 어떤 노드든 기존에 있던 노드와 연결이 가능합니다. 만약 엣지 생성이 너무 많아진다면 전체 인접 행렬을 생성할 때 조밀한 행렬이 생성되고 복잡도가 늘어나게 됩니다.

특히 위와 같이 5번 노드를 생성하여 연결할 때 그림 오른쪽의 순서대로 생성하게 된다면 불필요하게 긴 생성 방법이 됩니다. 이런 문제를 해결하기 위해 BFS에 기반을 둔 노드 오더링으로 문제를 어느 정도 해결할 수 있습니다.

BFS 노드 오더링에 따르면 노드 1은 노드 4와 연결되어 있지 않기 때문에 노드 1의 이웃에 대한 탐색은 끝났고, 따라서 노드 5는 절대 노드 1과 연결될 일이 없습니다. 따라서 기존의 $n-1$ 번의 단계를 거치지 않고 두 번의 단계에 대한 메모리만 사용할 수 있습니다.

BFS reduces the number of steps for edge generation

BFS 노드 오더링을 사용하면 가능한 노드 오더링의 수를 줄일 수 있습니다. $O(n!)$ 에서 서로 다른 BFS 오더링 수만큼 줄일 수 있죠. 그리고 엣지 생성 단계도 줄일 수 있습니다.

평가 방안

생성한 그래프의 유사도는 어떻게 계산할 수 있을까요? 크게 두 가지로 시각적으로 유사한지 확인하거나 그래프의 통계량을 통해 유사도를 계산하는 방법이 있습니다.

시각적 유사도

그래프 통계량 유사도

조금 더 엄밀한 비교를 위해서는 그래프 통계량을 사용해야 합니다. 일반적인 그래프 통계량이라고 하면 degree의 분포, 군집 계수(clustering coefficient)의 분포, orbit count 통계량 등이 있습니다. 각 통계량은 모두 확률 분포입니다. 학습 그래프의 통계량 집합과 생성한 그래프의 통계량을 비교할 때 앞으로 다룰 두 가지 방법을 사용할 수 있습니다.

Earth Mover Distance (EMD)

첫 번째로 Earth Mover Distance (EMD) 는 두 분포 사이의 유사도를 비교할 때 사용합니다. EMD는 어떤 흙 무더기를 다른 흙 무더기로 옮기는 데 필요한 최소한의 노력을 측정하는 것과 비슷합니다. 위 그림과 같이 왼쪽의 분포에서 오른쪽 분포로 옮길 때 얼마나 많은 노력이 필요한지, 즉 얼마나 차이가 나는지를 측정합니다.

\[\text{WORK}(F, \mathbf{x}, \mathbf{y}) = \sum^m_{i=1} \sum^n_{j=1} f_{ij} d_{ij}\]

Maximum Mean Discrepancy (MMD)

두 번째로 Maximum Mean Discrepancy (MMD) 는 분포 사이의 거리를 피처의 평균 임베딩 사이의 거리로 표현하는 방법입니다.

\[\begin{aligned} \text{MMD}^2(p \| q) &= \mathbb{E}_{x, y \sim p}[k(x, y)] + \mathbb{E}_{x, y \sim q}[k(x, y)] \\ &-\mathbb{E}_{x \sim p, y \sim q} [k(x, y)] \end{aligned}\]


This post is licensed under CC BY 4.0 by the author.