티스토리 뷰

Pytorch

[pytorch] RNN 이해하기

Sims. 2023. 1. 4. 18:49

딥러닝 분야를 접하면 반드시 접하게 되는 CNN, RNN이 있다.

비전과 자연어의 기본 모델이라, 어떻게 작동이 되는지 처음부터 잘 알아두면 큰 틀의 과정을 잘 이해할 수 있다고 생각된다.

 

여기서는 RNN에 대해 이야기 해보려고 한다.

어떻게 사용하는지 사용법을 알고싶다면, 빠르게 다른 포스팅을 찾는게 더 좋을 수 있다.

 

정말 어떻게 계산되는지 설명해보고자 한다.

RNN의 이론을 보면 이와같은 이미지를 안볼수가 없다.

대부분 이런 이미지를 통해 RNN이라 설명한다. 하지만, 개인적으로 너무 설명이 생략된 것이 많다고 생각한다.

 

개인적으로 이런 그림이 보다 직관적이라 생각한다.

A,B,C,D는 단어 embedding (vector)이다. 하지만, 보지 못한 것이 하나 있다.

맨 왼쪽 '0' 이라는 노드.. 궁금하지 않았는가? RNN은 이전 상태 h(t-1)과 현재상태 h(t)를 가지고 연산을 하는데,

그럼 맨처음 (여기서는 A)의 이전상태(h(-1))는 무엇이 있어야 하는지 늘 궁금했다.

만약 없다면, 당연히 에러가 나올것이다. 이런 궁금증은 일반적인 강의는 설명해주지 않았다..(왜???)

 

결론은 A 단어의 vector size만큼의 0 vector를 만들어 계산하게 된다.

즉, A = [1,5,4,2,6] 라는 vector를 가지고 있으면, h(-1)은 [0,0,0,0,0]의 vector가 되는것이다.

이렇게 해주는 이유는.. 간단하다. 단순히 어떤 W(weight)를 곱하더라도 0 vector이기에 전혀 영향을 주지 않는다.

즉, A단어만 영향을 줄 수 있다는 것. 그 후로는 값이 0이 아니므로 우리가 알던 RNN의 방식으로 잘 진행된다.

 

여기까지 이해했다면, RNN은 다 이해했다고 해도 무방하다.

결국 f(h(t) + h(t-1) + b)의 식으로 우리가 잘 알고 있는 방식을 통해 연산을 진행한다.

 

그럼 손으로 계산하여 RNN을 완벽히 이해해보자.

단순 3개의 단어로 구성된 문장의 RNN을 진행한 것을 수기로 구해보았다. 그리고 정말 torch의 RNN을 이용한 결과값을 살펴보면..

거의 유사한 값이 나온 것을 볼 수 있다! (손 계산은 소숫점을 많이 없애다보니 조금의 오차가 존재함)

물론, 실험을 위해 동일한 weight를 사용하고, relu를 사용했고, bias는 넣지 않았다.

 

지금까지 RNN을 정말 완벽히 이해했다고 자부할 정도로 살펴보았다. 이 이해를 바탕으로 자연어 처리를 학습해 나가면 도움이 될것이다.

 

**참고로 코드의 결과가 왜 2가지가 나오냐면,, torch의 RNN은 return을 2가지 한다. 첫번째는 각 output을 모아놓은 것이고,

두번째는 맨 마지막 output만 나타내는 것이다.

이것을 이용하여 다대다, 다대일과 같은 task를 수행할 수 있다.

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/05   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
글 보관함