Transformer: Scaled Dot-Product Attention 메모

1 Scaled Dot-Product Attention



Scaled Dot-Production Attention의 Attention 함수는 Query, Key, Value를 입력으로 하는 이하의 함수이다.


그림에서 보면 다음과 같습니다.


2 코드



Tensorflow 튜토리얼에 기재된 Scaled Dot-Product Attention 메소드의 구현은 다음과 같습니다.
import tensorflow as tf

#############################################
#
# Scaled Dot Product Attention
#   Attention(Q, K, V) = softmax( Q*K.T / sqrt(d)) * V
#
def scaled_dot_product_attention(q, k, v):
    # Q * K.T
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # Q*K.T / sqrt(d)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # softmax( Q*K.T / sqrt(d) )
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # softmax( Q*K.T / sqrt(d) ) * V
    output = tf.matmul(attention_weights, v)

    return output, attention_weights


Q와 K의 전치의 내적을 계산
matmul_qk = tf.matmul(q, k, transpose_b=True)

Q와 K의 전치의 내적을 루트 dk로 나누기
scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

softmax 계산
attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

softmax 결과와 V의 내적 계산
output = tf.matmul(attention_weights, v)

도면 중의 Mask(opt.)는 옵션 때문에 생략하고 있다.

2.2 테스트 코드



K, V, Q 값을 지정하여 scaled_dot_product_attention 메소드를 호출하는 코드.
#############################################
#
# Test "Scaled Dot Product Attention" method
#
k = tf.constant([[10, 0, 0],
                 [ 0,10, 0],
                 [ 0, 0,10],
                 [ 0, 0,10]], dtype=tf.float32)

v = tf.constant([[    1, 0],
                 [   10, 0],
                 [  100, 5],
                 [ 1000, 6]], dtype=tf.float32)

q = tf.constant([[0, 10, 0]], dtype=tf.float32)

print('---input---')
print(k)
print(v)
print(q)

result, attention_w = scaled_dot_product_attention(q,k,v)

print('---result---')
print(attention_w)
print(result)

k, v, q를 입력하여 scaled_dot_product_attention 메소드를 호출하여 출력 결과를 표시합니다.

3 실행



테스트 코드를 실행하면 다음과 같은 결과가 된다.

(1.00000e+01 9.276602e-25)가 출력 결과이다.

4 환경





버전


파이썬
3.7.4

Tensorflow
2.3.1


5 참고





URL


Transformer 논문
htps : // 아 rぃ v. rg/아bs/1706.03762

Tensorflow 튜토리얼
htps //w w. 천식 rfぉw. rg / thoo ls / xt / tran s fur r

좋은 웹페이지 즐겨찾기