SimCTG: 텍스트 생성에서 Contrastive 학습 추론에 대한 설명, 구현

개시하다


안녕하십니까?
이미지, 자연 언어, 사운드 관련 머신러닝 연구개발과 MLOps를 진행 중이다.만약 기계 학습에 관한 상의가 있다면 의 계정에 DM!
이것은 지금까지의기계 학습 보도의 총결이다.
이 글은 A Contrastive Framework for Neural Text Generation에서 제시된 SimCTG(a SIMple Contrastive frame work for neural Text Generation)를 글 생성 모델이 텍스트를 생성하는 부자연성과 원하지 않는 단어의 중복을 억제하는 방법으로 소개했다.
또한 참고논문의 실제 부호로 Enceoder Decorder 형식의 문장 생성 모델 T5에 SimCTG를 적용해 보았습니다. 다음은 설치 방법을 소개합니다.
논문의 실제 부호 중 GPT-2 설치에 적합한 것이 있으니 GPT-2를 사용하고자 하는 분들은 참고여기.하십시오.

SimCTG가 필요한 이유


최근 데코더의 모델은 입력문+과거에 생성된 단어 열에서 아래 단어 열을 예측하는 Auto-regressive의 생성이 진행되고 있다.이때 계산량의 관계로 인해 Gready Search와 Beam Search 등이 사용됩니다.
참조: 텍스트 생성 중인 decoding 기술: Greedy search, Beam search, Top-K, Top-p
그러나 이러한 검색 방법을 사용하면 같은 단어를 반복하는 등의 문제가 발생할 수 있다.따라서 huggingface의 문장 생성 함수generete에는 no_repeat_ngram_sizerepetition_penalty의 매개 변수가 있는데 이 매개 변수에 따라 중복을 억제할 수 있지만 매개 변수에 따라 같은 단어열이 전혀 나타나지 않아 부자연스러운 생성이 발생할 수 있다.
다음 그림(논문에서 Fig.1)에서 (a)는 GPT-2로 생성된 각 영패의 특징 벡터(Transformer 최종층)의 여현 유사도 행렬이다.보시다시피 문장의 영패 간의 유사도는 0.95 이상으로 서로의 특징 방향이 매우 가깝다는 것을 알 수 있다.이것은 서로 다른 단계에서 영패를 반복적으로 생성하는 원인 중의 하나로 여겨진다.논문에서 이러한 특징의 벡터 편파는 각방향 이성이라고 불린다.
Fig.1
이상적으로 생성된 영패의 특징을 식별할 수 있도록 위의 그림(b)에서 보듯이 영패 간의 유사도를 낮춰야 한다.각 영패의 특징 벡터가 편파적이지 않아야 한다는 것이다.(각 방향의 동성이어야 한다.)
따라서 가장 유사한 추정(MLS)에서 언어 모델을 배우고 가장 가능성이 있는 서열을 디코딩하는 전통적인 방법에 대해 각 동성 특징의 벡터를 식별하고 학습하는 SimCTG 손실을 촉진하는 한편, SimCTG 학습의 새로운 디코딩 방법을 보충하는 Contrastive Search(대조 탐색)를 제시했다.

SimCTG


SimCTG는 최대 유사 추정 손실(교차 엔트로피 손실)과 대비 손실(Constrastive Looss)으로 구성된다.
만약 단어가 x로 열거된다면 가장 유사한 손실은 다음과 같다.이전 단어 열 (x {i}) 에서 생성된 단어는 x {i}입니다.i의 확률 분포p{theeta}(x) 최대화 손실\mathcal{L]{MLE}입니다.
\begin{equation}
\mathcal{L}_{MLE} = -\frac{1}{|x|}\sum_{i=1}^{|x|}\log p_{\theta} (x_i | x_{ < i})
\tag{1}
\end{equation}
손실에 대비하여\mathcal{L}{CL} 유사도 함수 s를 사용하여 다음 공식을 사용합니다.이것은 같은 단어의 특징 벡터의 유사도를 높이고 서로 다른 단어의 특징 벡터의 유사도 손실을 줄이는 것이다.
\begin{equation}
\mathcal{L}_{CL} =\frac{1}{|x|\times (|x| - 1)}\sum_{i=1}^{|x|}\sum_{j=1, j\neq i}^{|x|}
\max\{ 0,\rho - s(h_{x_i}, h_{x_i}) + s(h_{x_i}, h_{x_j})\}
\tag{2}
\end{equation}
여유\rho\in[1,1]은 하이퍼매개변수, h{xi}는 토큰(xi)의 특징 벡터입니다.유사도 함수
\begin{equation}
s(h_{x_i}, h_{x_j}) =\frac{h_{x_i}^Th_{x_j}}{||h_{x_i}||\cdot || h_{x_j} ||}
\tag{3}
\end{equation}
자, 계산해.
상기 손실 사용, SimCTG 손실
\begin{equation}
\mathcal{L}_{CTG} =\mathcal{L}_{MLE} +\mathcal{L}_{CL}
\tag{4}
\end{equation}
.
SimCTG 손실을 사용하여 학습하면 같은 영패의 특징 벡터는 더욱 가까워지고 서로 다른 영패 사이의 특징 벡터는 멀어진다.

Contrastive Search


손실을 대조한 심CTG 학습 모델을 사용해 다음 단어를 예측하면 이전 단어와 열이 다른 단어일 가능성이 크다.이러한 지식을 흡수해 기존 방법과 마찬가지로 탐색 확률이 높은 단어 외에도 이전 단어와 열의 싱크로율이 낮은 단어를 찾는 Contrastive Search를 제시했다.
예측 확률이 높은 k개의 단어 집합을 V^k로 설정하면 Contastive Search를 통해 단어 선택
\begin{equation}
x_t =\mathrm{argmax}_{v\in V^k}\{ (1 -\alpha)\times p_{\theta}(v | x_{ < t})
-\alpha\times (\mathrm{max}\{ s(h_v, h_{x_j}) : 1\leq j\leq t - 1\})\}
\tag{5}
\end{equation}
진행합니다.

결실


논문의 테이블.보다논문에서 보듯이 SimCTG에서 Contrastive Search를 가장 잘 사용합니다.다른 한편에서는 두 가지 기법을 쓸 필요가 있음을 알 수 있다.

T5의 SimCTG 학습 설치


이번에는 [일본어 모드 첨부] 2021년에 자연어 처리된 사람에게 추천하고 싶은 사전 학습 완료 모드.에서 제공한 T5 이전 학습 사례코드를 바탕으로 확장했다.기사 제목 (기사 요약) 의 링크에서 Google Colab의 트레이닝 코드를 볼 수 있습니다.또한 본 글은 간단하게 보기 위해 수정된 SimCTG의 중요한 부분만 열거하였다.
우선 유사도 계산의 부분이다.여기서 T5 디코더의 최종 레이어는 피쳐 벡터로 사용됩니다.분할, 단어 열수, 특징 벡터로 구성된 행렬이기 때문에 각 단어의 특징 벡터의 유사도를 계산하고 분할수를 계산한다×단어 열수×단어 열을 출력하는 유사도 행렬입니다.
def t5_cosine_simirality(outputs):
    last_hidden_state = outputs.decoder_hidden_states[-1] # torch.Size([3, 64, 768])
    norm_rep = last_hidden_state / last_hidden_state.norm(dim=2, keepdim=True)
    cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2)) # torch.Size([3, 64, 64])
    return cosine_scores
미리 배운 모델의 유사도 행렬은 다음과 같다.(매우 길기 때문에 앞의 5행 5열만 있다.)
실제로 대각 부분 이외의 유사도도 높다는 것을 알 수 있다.
[0.9999, 0.9807, 0.9654, 0.9587, 0.9528, 
[0.9807, 0.9999, 0.9805, 0.9755, 0.9687,
[0.9654, 0.9805, 0.9999, 0.9816, 0.9712,
[0.9587, 0.9755, 0.9816, 0.9999, 0.9878, 
[0.9528, 0.9687, 0.9712, 0.9878, 0.9999,
다음은 손실을 비교하는 부분이다.이것은 코드와 같다.논문 코드를 누르는 스타!
def compute_contrastive_loss(score_matrix, margin):
    '''
        margin: predefined margin to push similarity score away
        score_matrix: bsz x seqlen x seqlen; cosine similarity matrix
        input_ids: bsz x seqlen
    '''
    bsz, seqlen, _ = score_matrix.size()
    gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) #対角成分を取り出す bsz x seqlen
    gold_score = torch.unsqueeze(gold_score, -1) 
    assert gold_score.size() == torch.Size([bsz, seqlen, 1])

    difference_matrix = gold_score - score_matrix
    assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen])
    loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen
    loss_matrix = torch.nn.functional.relu(loss_matrix)
    cl_loss = torch.mean(loss_matrix)
    return cl_loss
다음은 학습 부분입니다.Pytorch Lightning을 사용하고 있기 때문에 참고할 수 있는 이동 학습 코드를 사용하십시오. 그 모델 부분의 확장입니다.
T5의 숨겨진 레이어를 내보내기 위해 설정output_hidden_states=True 되었습니다._step 부분에서 손실을 추가로 계산하고 있다.
class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # 最新版のplでは、hparamsのupdate方法が変更された。
        self.save_hyperparameters(hparams)
        # 事前学習済みモデルの読み込み
        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)
        # トークナイザーの読み込み
        self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path, is_fast=True)

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, 
                decoder_attention_mask=None, labels=None):
        """順伝搬"""
        return self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            output_hidden_states=True # decoderの隠れ層を出力するためTrueに設定。
        )

    def _step(self, batch):
        """ロス計算"""
        labels = batch["target_ids"]
        labels[labels[:, :] == self.tokenizer.pad_token_id] = -100

        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch['target_mask'],
            labels=labels,
        )

        mle_loss = outputs.loss
        cosine_scores = t5_cosine_simirality(outputs) # 類似度計算
        margin = 0.5 # margin
        cl_loss = compute_contrastive_loss(cosine_scores, margin) # 対照損失計算

        loss = mle_loss + cl_loss # SimCTG 損失
        loss = loss.mean()
        return loss
상기 변경된 코드를 통해 학습한 결과는 다음과 같다.(마찬가지로 길기 때문에 앞의 5행 5열만 있습니다.)
대각 성분의 유사도가 높아 서로 다른 영패 간의 유사도, 즉 대각 성분을 제외하고는 유사도가 작아진 것을 발견했다.
[0.9999, 0.2982, -0.5032, -0.3650, -0.5413,
[0.2982, 0.9999, 0.2020, 0.2578, 0.2829,
[-0.5032, 0.2020, 0.9999, 0.6498, 0.6361,
[-0.3650, 0.2578, 0.6498, 1.0, 0.6117, 0.7059,
[-0.5413, 0.2829, 0.6361, 0.6117, 0.9999

T5의 Contrastive Search 설치


우선 추론의 대략적인 절차다.중요한 부분은 앞의 단어열에서 다음 단어ContrastiveDecodingOneStepFast의 함수를 예측한 다음에 설명하는 것이다.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
input_ids = next(iter(data_loader))["source_ids"] # データローダから読み込み
model = T5ForConditionalGeneration.from_pretrained(T5_MODEL_DIR) # T5_MODEL_DIRに学習したモデルの重みなどが入っていると仮定
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL_DIR, is_fast=True) # Tokenizer

batch_size, seqlen = input_ids.size()

generated = [[] for _ in range(batch_size)] # バッチごとの生成文格納するリストです。
is_eos = [False for _ in range(batch_size)] # バッチごとの文章生成終了(EOS)判定リストです。
# past_key_valuesをモデルの入力とすることで、モデルの入力を一つ前の単語のみにできる(huggingfaceの実装のドキュメント見てください。)
past_key_values = None 
last_hidden_states = None
logits = None

decoding_len = 512 # 文章の最大長さ
beam_width = 3 # 実装では、ビームサーチの結果にContrastive Searchを行なっていた。
alpha = 0.5 # (1 - a) * 確率最大 + a * 類似度

input_ids.to(device)
model.eval()
for step in range(decoding_len):
    next_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
        model,
        input_ids,
        beam_width,
        alpha,
        past_key_values,
        last_hidden_states,
        tokenizer,
        logits,
        device,
        first_step=step == 0, #最初のステップのみ扱い違う
    )
    tokens = next_ids.squeeze(dim=-1).tolist()
    for idx, t in enumerate(tokens): #バッチ数分
        # EOSの場合スキップ
        if is_eos[idx]:
            continue
        if t == tokenizer.eos_token_id:
            is_eos[idx] = True
            continue
        generated[idx].append(t)
다음은 ContrastiveDecodingOneStepFast의 설치입니다.논문의 실현은 GPT-2판이지만 T5는 인코더·디코더 모델이기 때문에 디코더의 입력이 필요하다.또 GPU를 사용하다가 메모리 오류로 변해 그 처리를 추가했다.논문 코드.거의 논문의 실장과 같다.
def ContrastiveDecodingOneStepFast(
    model, 
    ids, 
    beam_width, 
    alpha, 
    past_key_values,
    last_hidden_states,
    vocab,
    logit_for_next_step,
    device,
    first_step=False,
    
    ):
    # input_ids: [B, S]
    model.eval()
    if first_step:
        # T5では、最初のステップのみpad tokenをデコーダの入力とする。
        with torch.no_grad():
            bsz, _ = ids.size()
            ids = ids.to(device)
            decoder_inputs = torch.tensor([[0] for _ in range(bsz)]).to(device) # pad_token_idでstart
            output = model(
                input_ids=ids, 
                decoder_input_ids=decoder_inputs,
                past_key_values=past_key_values,
                use_cache=True,
                output_hidden_states=True
            )
            del decoder_inputs
        past_key_values = output.past_key_values
        last_hidden_states = output.decoder_hidden_states[-1].cpu()    # [B, S, E]
        logit_for_next_step = output.logits[:, -1, :].cpu()    # [B, V]
    bsz, seqlen, embed_dim = last_hidden_states.size()
    p = random.uniform(0, 1)

    next_probs = F.softmax(logit_for_next_step, dim=-1)

    # 最大確率のk候補を取得
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, V:38000くらい] -> [B, K:beam_width]
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # 候補の確率を取得 [B, K] 
    
    # モデルの入力のpast_keyをバッチx候補のタプルに修正
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)

    # 次の単語を予測
    input_ids = top_k_ids.view(-1, 1).to(device) # [B*K , 1]
    with torch.no_grad():
        output = model(
            input_ids=input_ids, 
            decoder_input_ids=input_ids,
            past_key_values=past_key_values,
            output_hidden_states=True,
            use_cache=True,
        )
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :].cpu()    # [B*K, V]
    next_hidden = output.decoder_hidden_states[-1].cpu()    # [B*K, 1, E]
    context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)    # [B*K, S, E]

    # バッチごとの最大スコアの単語を選択
    # 実際の式の部分を計算 (下の関数)
    selected_idx = ranking_fast(
        context_hidden, 
        next_hidden, 
        top_k_probs,
        alpha,
        beam_width,
    )  # [B]

    # 次のステップの準備
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)    # [B, 1]
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))    # [B, K, E]
    next_hidden = next_hidden[range(bsz), selected_idx, :]    # [B, E]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)    # [B, S, E]
    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]    # [B, V]
    
    # GPUのメモリ解放
    if device != torch.device("cpu"):
        del model, input_ids, top_k_ids, top_k_probs
        torch.cuda.empty_cache()
    return next_id, past_key_values, last_hidden_states, logits 


def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    '''バッチごとの最大スコアの単語を選択
        context_hidden: bsz*beam x seqlen x embed_dim
        next_hidden: bsz*beam x 1 x embed_dim
        next_top_k_probs: bsz x beam
    '''
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)    # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
    # 式の部分
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores #
    scores = torch.stack(torch.split(scores, beam_width)) # バッチごとに戻す [B, K]
    selected_idx = scores.max(dim=-1)[1] # [B]
    return selected_idx

def enlarge_past_key_values(past_key_values, beam_width):
    # モデルの入力のpast_keyをバッチx候補のタプルに修正
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            # item is the key and value matrix
            bsz, num_head, seq_len, esz = item.size()
            item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

def select_past_key_values(past_key_values, beam_width, selected_idx):
    '''バッチx候補から最大スコアのpast_keyを選択
    select_idx: [B]
    '''
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam//beam_width)
            item = torch.stack(torch.split(item, beam_width, dim=0))    # [B, K, num_head, seq_len, esz] 
            item = item[range(bsz), selected_idx, :, :, :]   # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values
이상은 T5의 Constrive Search 설치입니다.

총결산


나는 문장 생성의 개선 방법을 찾고 있다. 우연히 발견한 논문이지만 이것은 간단하게 실시할 수 있고 효과가 좋은 방법이라고 생각한다.재현 실험을 한 적은 없지만, 앞으로 꾸준히 사용하고자 하는 수법이다.
개인적으로 Contrastive Learning은 자신이 좋아하는 분야에 SimSiam에 관한 기사를 쓰고 있다.앞으로도 이 분야에 눈을 돌리고 싶다.

좋은 웹페이지 즐겨찾기