티스토리 뷰

이번에는.. 어느 분야에나 적용되고 있는 Attention을 살펴보고자 한다!
조금 충격을 먹은게, GNN분야를 공부하는데, 튜토리얼에 Attention 개념이 나와서 충격받았다...
 
나름 생각한 것도 적어 놓을태니 Attention은 잘 살펴보길 바란다.
 
혹시, GCN, SAGE 공식 설명을 보지않은 분은 보고오는걸 추천한다.. 왜냐면.. 수식에 대한 거부감을 줄이기 위해....?
https://sims-solve.tistory.com/70

GNN 수학식 뜯어보기 - 2. Graph SAGE

이번 포스팅에서는 Graph SAGE라 불리는 layer는 어떻게 node를 업데이트 하는지 수식을 살펴보고자 한다. 혹시, GCN 수식을 보지 못한 사람은 첫번째 포스팅을 보고 수식에 대한 두려움을 없애고 오는

sims-solve.tistory.com

https://sims-solve.tistory.com/69

GNN 수학식 뜯어보기 - 1. GCN

GNN을 공부하다 보면 반드시 접하게 되는 3가지 레이어가 있다. 1. GCN 2. Graph SAGE 3. GAT(Graph Attation) 공부를 하다보면.. 수학식이 나오면 뭔가 좌절감이 들고 못할 것만 같은 생각이 드는데, 이번 기

sims-solve.tistory.com

별로 어렵지 않은 수식이니 누구나 이해할 수 있다..!
 
그럼 GAT의 간단한 설명을 한번 해보겠다.
 

위 그림처럼 설명할 수 있다고 생각한다.
즉, 노드 v(노드 0)를 업데이트 할때, 노드 u(노드 2,3,4,5)에 대한 가중치를 구하여 서로 다른 영향력으로 노드 v에 반영하겠다는 것.
 
그럼 얼마나 영향을 줄지 나타내는 Attention값은 어떻게 구할까?

이처럼 구하여 노드를 업데이트 할 때 사용한다고 한다.
하지만, 막상 DGL에 구현된 GAT layer는 조금 다르다.
GAT코드를 살펴보자면..

 def forward(self, graph, feat, get_attention=False):
        
            src_prefix_shape = dst_prefix_shape = feat.shape[:-1]
            h_src = h_dst = self.feat_drop(feat)
            
            #=======================1============================
            feat_src = feat_dst = self.fc(h_src).view(
                *src_prefix_shape, self._num_heads, self._out_feats)
          #=======================2============================
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({'ft': feat_src, 'el': el})
            graph.dstdata.update({'er': er})
            
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))
            
            # compute softmax
            #=======================3============================
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
            # message passing
            graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            
            # bias
            if self.bias is not None:
                rst = rst + self.bias.view(
                    *((1,) * len(dst_prefix_shape)), self._num_heads, self._out_feats)
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata['a']
            else:
                return rst

가장 중요하다고 생각한 부분만 뽑아보았다. 1,2,3 주석을 넣은 곳을 자세히 보길 바란다.
 
'주석 1'은 위 넣어놓은 식에서 step1에 해당하는 곳이다. DGL에서도 별반 다르지 않다.
 
'주석 2'는 식에서 step2에 해당하는 부분이다. 하지만, 식과는 조금 다른 모습을 볼 수 있다.
식에서는 concat을 한 후 W를 곱하지만, DGL 코드에서는 W를 곱한 후, concat를 하는 형식이다.
(좀더 자세히 들어가면.. src 노드에 해당하는 피쳐와  dst 노드에 해당하는 피쳐를 구해 각 행마다 합을 하고, 구한 src,dst 값을 더해 softmax를 취해 attention 값을 구한다. +) 이렇게 하면 구하고자 하는 것은 똑같으나, 좀 더 효율적이라고 한다.)
 
'주석 3'은 식 step3에 해당하는 부분이다. graph와 e를 넘겨주어, 노드 v에 연결되어있는 노드 u에 대한 것만 가중치를 구한다.
 
그 후 구한 가중치를 feature에 dot(행렬곱)을 하고, 결과값으로 나온 피쳐들을 모두 더해주는 방식으로 노드 v를 업데이트 한다.
 
조금 복잡할 수 있다.. 처음에는 복잡하나, 시간이 지나면 이해가 될 것이다.
 
이처럼, Attention 기법을 통해 graph을 학습 시킬수 있는 방식이 있다!
DGL에서는 GAT layer를 이용해 손쉽게 이용가능하니, 시도해보는것을 추천한다.
 
 
*추가적으로 GAT의 Attention에 대해 말하고자 하는 것이 있다.
이 글은 GAT 레이어에 사용된 attn_l, attn_r에 대한 개인적인 생각이다. 넘겨도 무방함!
 
NLP에서 Attention 은 Q, K ,V를 이용하여 구하곤 했다.. 하지만, GNN에서 Attention은 Q,K,V개념이 나오지 않는다.
하지만 코드를 보면서 어쩌면 Q, K ,V가 존재하지 않을까 생각했다.
하나씩 말해보고자 한다.
 
Q > NLP에서는 단어 feature를 의미한다. GNN에서는 노드 feature가 되겠다. ( NLP에서와 동일)
 
K > NLP에서는 단어 feature를 K로도 사용하지만, GNN에서는 attn_l . attn_r이 그 역할을 한다고 생각한다.
DGL GAT 레이어를 보면, 피쳐 * (attn_l or attn_r)를 하는 코드를 볼 수 있는데, 이것이 NLP 에서 Q * K 를 하는 것이랑 비슷하다.
GAT레이어에서는 src*attn_l , dst * attn_r을 하는 것을 볼 수 있는데, 생각해보면 src,dst는 노드의 feature이므로, 학습가능한 attn_l,attn_r을 두어 'src', 'dst'라는 의미,방향성을 내포하는 벡터를 만들도록 학습하는 것이 아닌가 싶다.
 
V > Value는 NLP에서처럼 softmax를 거친 Attention score를 feature에 곱하여 이루어진다.(단, 엣지로 연결되어 있을경우만)
GAT에서도 fn.u_mul_e를 통해 동일한 의미의 역할을 수행한다.
 
어느 분야에서든 Attention의 큰 의미(가중치를 두어 참조하겠다)는 바뀌지 않고, Q,K,V를 사용하여 진행된다고 생각한다.
 

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함