ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [NLP] 오차역전파법 역전파 backpropagation
    ML/NLP 2020. 4. 6. 02:38

     

     

     

    앞서 신경망 학습에 대해 포스팅할 때는 가중치 매개변수의 기울기(가중치 매개변수에 대한 손실 함수의 기울기)를 수치 미분을 사용해 구했다.

     

    수치 미분은 단순하고 구현하기도 쉽지만 계산 시간이 오래 걸린다. 이번에는 가중치 매개변수의 기울기를 효율적으로 계산하는 '오차역전파법 backpropagation'을 배워보겠다.

     

    계산 그래프를 통해 공부해보겠다. 계산 그래프(computational graph)는 계산 과정을 그래프로 나타낸 것이다. 여기서 그래프는 우리가 잘 아는 그래프 자료구조로, 복수의 노드(node)와 에지(edge)로 표현된다.

     

     

     

     

    계산 그래프로 풀다


    문제1. 슈퍼에서 1개에 100원인 사과를 2개 샀습니다. 이때 지불 금액을 구하세요. 단, 소비세가 10% 부과됩니다.

     

    그림 1

    [그림 1]과 같이 처음에 사과의 100원이 'x2' 노드로 흐르고, 200원이 되어 다음 노드로 전달된다.

    이제 200원이 'x1.1' 노드를 220원이 된다. 따라서 이 계산 그래프에 따르면 최종 답은 220원이 된다.

     

     

     

    문제2. 슈퍼에서 사과를 2개, 귤 3개 샀습니다. 사과는 1개에 100원, 귤은 1개 150원입니다. 소비세 10%일 때 지불 금액을 구하세요.

    그림 2

    이 문제에는 덧셈 노드인 '+'가 새로 등장하여 사과와 귤의 금액을 합산한다. 계산 그래프는 왼쪽에서 오른쪽으로 계산을 진행한다.

    회로에 전류가 흐르듯 계산 결과가 왼쪽에서 오른쪽으로 전달된다고 생각하면 된다. 계산 결과가 오른쪽 끝에 도착하면 거기서 끝난다.

    그래서 [그림 2]에서의 답은 715원이다.

     

     

    계산 그래프를 이용한 문제풀이는 다음 흐름으로 진행한다.

    1. 계산 그래프를 구성한다.

    2. 그래프에서 계산을 왼쪽에서 오른쪽으로 진행한다.

     

    여기서 2번째 '계산을 왼쪽에서 오른쪽으로 진행'라는 단계를 순전파(forward propagation)라고 한다. 순전파는 계산 그래프의 출발점부터 종착점으로의 전파이다. 반대로 '오른쪽에서 왼쪽으로 진행' 단계는 역전파(backward propagation)라고 한다. 

     

    역전파는 이후에 미분을 계산할 때 중요한 역할을 한다.

     

     

     

    국소적 계산


    계산 그래프의 특징은 '국소적 계산'을 전파함으로써 최종 결과를 얻는다는 점에 있다.

     

    결국 전체에서 어떤 일이 벌어지든 상관없이 자신과 관계된 정보만으로 결과를 출력할 수 있다는 것이다.

     

    예를 들어보겠다. 가령 슈퍼마켓에서 사과 2개를 포함한 여러 식품을 구입하는 경우를 생각해보자.

     

    그림 3

    [그림 3]에서는 여러 식품을 구입하여 총금액이 4,000원이 되었다. 여기서 핵심은 각 노드에서의 계산은 국소적 계산이라는 점이다. 가령 사과와 그 외의 물품 값을 더하는 계산(4,000 + 200 -> 4,200)은 4,000이라는 숫자가 어떻게 계산되었느냐와는 상관없이, 단지 두 숫자를 더하면 된다는 뜻이다. 각 노드는 자신과 관련한 계산(이 예에서는 입력된 두 숫자의 덧셈) 외에는 아무것도 신경 쓸 게 없다.

     

    이처럼 계산 그래프는 국소적 계산에 집중한다. 전체 계산이 아무리 복잡해도 각 단계에서 하는 일은 해당 노드의 '국소적 계산'이다. 국소적인 계산은 단순하지만, 그 결과를 전달함으로써 전체를 구성하는 복잡한 계산을 해낼 수 있다.

     

     

     

     

    왜 계산 그래프로 푸는가?


    1. 국소적 계산: 전체가 아무리 복잡해도 각 노드에서는 단순한 계산에 집중하여 문제를 단순화할 수 있다.

    2. 중간 계산 결과를 모두 보관할 수 있다.

    3. 역전파를 통해 '미분'을 효율적으로 계산할 수 있다.

     

     

    계산 그래프의 역전파를 설명하기 위해 문제1을 다시 보겠다.

     

    문제 1은 사과를 2개 사서 소비세를 포함한 최종 금액을 구하는 것이다.

     

    여기서 가령 사과 가격이 오르면 최종 금액에 어떤 영향을 끼치는지 알고 싶다고 하자.

     

    이는 '사과 가격에 대한 지불 금액의 미분'을 구하는 문제에 해당한다. 기호로 나타낸다면 사과 값을 x, 지불 금액을 L이라 했을 때 L/x을 구하는 것이다.

     

    이 미분 값은 사과 값이 '아주 조금' 올랐을 때 지불 금액이 얼마나 증가하느냐를 표시한 것이다.

     

    역전파를 하면 '사과 가격에 대한 지분 금액의 미분'을 구할 수 있다.

     

    그림 4

    [그림 4]과 같이 역전파는 순전파와는 반대 방향의 화살표(굵은 선)로 그린다.

     

    이 전파는 '국소적 미분'을 전달하고 그 미분 값은 화살표의 아래에 적는다. 

     

    이 예에서 역전파는 오른쪽에서 왼쪽으로 '1 -> 1.1 -> 2.2' 순으로 미분 값을 전달한다.

     

    이 결과로부터 '사과 가격에 대한 지불 금액의 미분' 값은 2.2라 할 수 있습니다.

     

    사과가 1원 오르면 최종 금액은 2.2원 오른다는 뜻이다.

     

     

     

     

     

    연쇄법칙


    역전파는 '국소적인 미분'을 순방향과는 반대인 오른쪽에서 왼쪽으로 전달한다.

     

    또한, 이 '국소적 미분'을 전달하는 원리는 연쇄법칙(chain rule)에 따른 것이다.

     

    역전파의 예를 하나 살펴보자.

     

    그림 5

    [그림 5]과 같이 역전파의 계산 절차는 신호 E에 노드의 국소적 미분(x/y)을 구한다는 뜻이다.

     

    그리고 이 국소적인 미분을 상류에서 전달된 값(이 예에서는 E)에 곱해 앞쪽 노드로 전달하는 것이다.

     

    이것이 역전차의 계산 순서인데, 이러한 방식을 따르면 목표로 하는 미분 값을 효율적으로 구할 수 있다는 것이 이 전파의 핵심이다.

     

    왜 그런 일이 가능한가는 연쇄법칙의 원리로 설명할 수 있다.

     

     

     

     

     

    연쇄법칙이란?


    연쇄법칙을 설명하려면 우선 합성 함수부터 알아야 한다.

     

    합성 함수란 여러 함수로 구성된 함수이다. 예를 들어 z = (x+y)^2이라는 식은 [그림 6]과 같이 두 개의 식으로 구성된다.

     

    그림 6

    연쇄법칙은 합성 함수의 미분에 대한 성질이며, 다음과 같이 정의된다.

     

    합성 함수의 미분은 합성 함수를 구성하는 각 함수의 미분의 곱으로 나타낼 수 있다.

     

    이것이 연쇄법칙의 원리이다. z/x (x에 대한 z의 미분)은 z/∂t (t에 대한 z의 미분)과 ∂t/x (x에 대한 t의 미분)의 곱으로 나타낼 수 있다.

     

    수식으로 나타내면 다음과 같다.

     

    그림 7

     

     

     

     

    연쇄법칙과 계산 그래프


    연쇄법칙 계산을 계산 그래프로 나타내보자. 2제곱 계산을 '**2' 노드로 나타내면 [그림 8]처럼 그릴 수 있다.

     

    그림 8

    [그림 8]과 같이 계산 그래프의 역전파는 오른쪽에서 왼쪽으로 신호를 전파한다.

     

    역전파의 계산 절차에서는 노드로 들어온 입력 신호에 그 노드의 국소적 미분(편미분)을 곱한 후 다음 노드로 전달한다.

     

    예를 들어 '**2' 노드에서의 역전파를 보자.

     

    입력은 z/z이며, 이에 국소적 미분인 z/∂t (순전파 시에는 입력이 t이고출력이 z이므로 이 노드에서 (국소적)미분은 z/∂t이다)를 곱하고 다음 노드로 넘긴다.

     

    한 가지, [그림 8]에서 역전파의 첫 신호인 z/z의 값은 결국 1이라서 앞의 수식에서는 언급하지 않았다.

     

     

    주목할 것은 맨 왼쪽 역전파이다. 이 계산은 연쇄법칙에 따르면 z/z * z/∂t * ∂t/∂x= z/∂t * ∂t/∂x= z/∂x가 성립되어 'x에 대한 z의 미분'이 된다.

     

    즉, 역전파가 하는 일은 연쇄법칙의 원리와 같다는 것이다.

     

    [그림 8]에 [그림 9]의 결과를 대입하면 [그림 10]이 되며, z/∂x는 2(x+y) 임을 구할 수 있다.

     

    그림 9

     

    그림 10

     

     

     

     

     

     

    덧셈 노드의 역전파


    z = x+y라는 식을 대상으로 그 역전파를 살펴보자.

     

    우선 z = x+y의 미분은 다음과 같이 해석적으로 계산할 수 있다.

     

    그림 11

    [그림 11]에서와 같이 z/∂x와 z/∂y는 모두 1이 된다. 이를 계산 그래프로는 [그림 12]처럼 그릴 수 있다.

     

    그림 12

     

     

    [그림 12]와 같이 역전파 때는 상류에서 전해진 미분(이 예에서는  ∂L/∂z)에 1을 곱하여 하류로 흘린다.

     

    즉, 덧셈 노드의 역전파는 1을 곱하기만 할 뿐이므로 입력된 값을 그대로 다음 노드로 보내게 된다.

     

     

    이 예에서는 상류에서 전해진 미분 값을  ∂L/∂z이라 했는데, 이는 최종적으로 L이라는 값을 출력하는 큰 계산 그래프를 가정하기 때문이다.

     

    z = x + y 계산은 그 큰 계산 그래프의 중간 어딘가에 존재하고, 상류로부터 ∂L/∂z 값이 전해진 것이다.

     

    그리고 다시 하류로 ∂L/∂x과 ∂L/∂y 값을 전달하는 거다.

     

     

     

     

     

    곱셈 노드의 역전파


    z = xy라는 식을 생각해보죠. 이 식의 미분은 다음과 같다.

     

    그림 13

    [그림 13]에서 계산 그래프는 다음과 같이 그릴 수 있다.

     

    그림 14

    곱셈 노드 역전파는 상류의 값에 순전파 때의 입력 신호들을 '서로 바꾼 값'을 곱해서 하류로 보낸다.

     

    서로 바꾼 값이란 [그림 14]처럼 순전파 때 x였다면 역전파에서는 y, 순전파 때 y였다면 역전파에서는 x로 바꾼다는 의미다.

     

     

    덧셈의 역전파에서는 상류 값을 그대로 흘려보내서 순방향 입력 신호의 값을 필요하지 않았으나

     

    곱셈의 역전파는 순방향 입력 신호의 값이 필요하다.

     

    그래서 곱셈 노드를 구현할 때는 순전파의 입력 신호를 변수에 저장해둔다.

     

     

     

     

    사과 쇼핑의 예


    사과 쇼핑 예를 다시 보자.

     

    이 문제에서 사과의 가격, 사과의 개수, 소비세라는 세 변수 각각이 최종 금액에 어떻게 영향을 주느냐를 풀고자 한다.

     

    이는 '사과 가격에 대한 지불 금액의 미분', '사과 개수에 대한 지불 금액의 미분', '소비세에 대한 지불 금액의 미분'을 구하는 것에 해당한다.

     

    이를 계산 그래프의 역전파를 사용해서 풀면 [그림 15]처럼 된다.

     

     

    그림 15

     

    곱셈 노드의 역전파에서는 입력 신호를 서로 바꿔서 하류로 흘린다.

     

    [그림 15]의 결과를 보면 사과 가격의 미분은 2.2, 사과 개수의 미분은 110, 소비세의 미분은 200이다.

     

    이는 소비세와 사과 가격이 같은 양만큼 오르면 최종 금액에는 소비세가 200의 크기로, 사과 가격이 2.2 크기로 영향을 준다고 해석할 수 있다.

     

    단, 이 예에서 소비세와 사과 가격은 단위가 다르니 주의해야 한다 (소비세 1은 100%, 사과 가격 1은 1원).

     

     

     

     

     

    단순한 계층 구현하기


    '사과 쇼핑' 예를 파이썬으로 구현하자.

     

    계산 그래프의 곱셈 노드를 'MulLayer', 덧셈 노드를 'AddLayer'라는 이름으로 구현한다.

     

     

     

    모든 계층은 forward()와 backward()라는 공통의 메서드(인터페이스)를 갖도록 구현할 것이다.

     

    forward()는 순전파, backward()은 역전파를 처리한다.

     

     

    1. 곱셈 계층 'MulLayer'

     

     

     

    class MulLayer:
        def __init__(self):
            self.x = None
            self.y = None
            
        def forward(self, x, y):
            self.x = x
            self.y = y
            out = x * y
            
            return out
        
        def backward(self, dout):
            dx = dout * self.y # dout은 상류에서 넘어온 미분을 의미
            dy = dout * self.x
            
            return dx, dy

    이를 사용해서 앞서 본 '사과 쇼핑'을 구현해보자.

     

    사과 2개 구입

    apple = 100
    apple_num = 2
    tax = 1.1
    
    # 계층들
    mul_apple_layer = MulLayer()
    mul_tax_layer = MulLayer()
    
    # 순전파
    apple_price = mul_apple_layer.forward(apple, apple_num)
    price = mul_tax_layer.forward(apple_price, tax)
    
    print(price) # 220
    Out:    220.00000000000003

     

    # 역전파
    dprice = 1
    dapple_price, dtax = mul_tax_layer.backward(dprice)
    dapple, dapple_num = mul_tax_layer.backward(dapple_price)
    
    print(dapple, dapple_num, dtax) # 2.2 110 200
    Out:    1.2100000000000002 220.00000000000003 200

    backward()의 호출 순서는 forward() 때와는 반대이다. 또 backward()가 받는 인수는 '순전파의 출력에 대한 미분'임에 주의하자.

     

     

    2. 덧셈 계층 'AddLayer'

    class AddLayer:
        def __init__(self):
            pass
        
        def forward(self, x, y):
            out = x + y
            
            return out
        
        def backward(self, dout):
            dx = dout * 1
            dy = dout * 1
            
            return dx, dy

     

     

    이번엔 덧셈 계층과 곱셈 계층을 사용하여 사과 2개와 귤 3개를 사는 상황을 구현해보자.

    사과 2개와 귤 3개 구입

     

    apple = 100
    apple_num = 2
    orange = 150
    orange_num = 3
    tax = 1.1
    
    # 계층들
    mul_apple_layer = MulLayer()
    mul_orange_layer = MulLayer()
    add_apple_orange_layer = AddLayer()
    mul_tax_layer = MulLayer()
    
    # 순전파
    apple_price = mul_apple_layer.forward(apple, apple_num)
    orange_price = mul_orange_layer.forward(orange, orange_num)
    all_price = add_apple_orange_layer.forward(apple_price, orange_price)
    price = mul_tax_layer.forward(all_price, tax)
    
    
    # 역전파
    dprice = 1
    dall_price, dtax = mul_tax_layer.backward(dprice)
    dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)
    dorange, dorange_num = mul_orange_layer.backward(dorange_price)
    dapple, dapple_num = mul_apple_layer.backward(dapple_price)
    
    print(price)
    print(dapple_num, dapple, dorange, dorange_num, dtax)
    Out:   
    715.0000000000001
    110.00000000000001 2.2 3.3000000000000003 165.0 650

     

    댓글

dokylee's Tech Blog