Source code for recwizard.modules.unicrs.prompt_gpt2

from typing import Tuple, Optional, Any

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers import Conv1D, GPT2Config
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_conv1d_layer
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2PreTrainedModel, logger
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map


# Adapted from github.com/RUCAIBox/UniCRS/blob/master/src/model_gpt2.py
# Maybe we can use a framework like OpenPrompt in the future.

class GPT2Attention(nn.Module):
    """GPT-2 multi-head attention that can prepend prompt key/value states.

    In self-attention mode the keys and values may be extended on the left
    with ``prompt_embeds``; those extra positions are visible to every query
    (they are exempt from the causal mask).
    """

    def __init__(self, config, is_cross_attention=False):
        """Create the projections, dropouts and causal-mask buffers.

        Args:
            config: model configuration; reads ``max_position_embeddings``,
                ``hidden_size``, ``num_attention_heads``,
                ``scale_attn_weights``, ``attn_pdrop`` and ``resid_pdrop``.
            is_cross_attention: build a separate query projection (``q_attn``)
                and a key/value-only ``c_attn``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of the head count.
        """
        super().__init__()

        n_positions = config.max_position_embeddings
        # Lower-triangular matrix: row i has ones in columns 0..i.
        lower_triangle = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.uint8))
        self.register_buffer("bias", lower_triangle.view(1, 1, n_positions, n_positions))
        # Score assigned to positions hidden by the causal mask.
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Projection layers are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence.
        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from ``c_attn`` and ``c_proj``."""
        if len(heads) == 0:
            return

        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # The same columns are addressed in each of the three split_size-wide
        # chunks of the fused c_attn projection.
        index_attn = torch.cat([index + chunk * self.split_size for chunk in range(3)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        n_removed = len(heads)
        per_head_width = self.split_size // self.num_heads
        self.split_size = per_head_width * (self.num_heads - n_removed)
        self.num_heads = self.num_heads - n_removed
        self.pruned_heads = self.pruned_heads.union(heads)

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        leading_dims = tensor.size()[:-1]
        per_head = tensor.view(*leading_dims, num_heads, attn_head_size)
        # (batch, head, seq_length, head_features)
        return per_head.permute(0, 2, 1, 3)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        seq_major = tensor.permute(0, 2, 1, 3).contiguous()
        merged_shape = seq_major.size()[:-2] + (num_heads * attn_head_size,)
        return seq_major.view(merged_shape)

    def _attn(self, query, key, value, prompt_len=0, attention_mask=None, head_mask=None):
        """Compute attention weights and the weighted sum of ``value``.

        Args:
            query, key, value: per-head tensors as produced by ``_split_heads``.
            prompt_len: number of leading key positions that are prompt slots;
                they are always visible regardless of the causal mask.
            attention_mask: additive mask applied to the raw scores.
            head_mask: multiplicative mask applied after dropout.

        Returns:
            ``(attn_output, attn_weights)``.
        """
        # (batch_size, head, query_len, key_len)
        scores = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            scores = scores / (float(value.size(-1)) ** 0.5)

        if not self.is_cross_attention:
            # Only the "normal" (self) attention layer applies the causal mask.
            n_query = query.size(-2)
            n_key = key.size(-2)
            if prompt_len > 0:
                # The causal window is computed over real tokens only.
                n_key -= prompt_len

            causal_mask = self.bias[:, :, n_key - n_query: n_key, :n_key].bool()
            if prompt_len > 0:
                # Prompt slots sit on the left and are visible to every query.
                always_visible = causal_mask.new_ones(list(causal_mask.shape[:-1]) + [prompt_len])
                causal_mask = torch.cat([always_visible, causal_mask], dim=-1)

            scores = torch.where(causal_mask, scores, self.masked_bias.to(scores.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            scores = scores + attention_mask

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        return torch.matmul(attn_weights, value), attn_weights

    def forward(
        self,
        hidden_states,
        layer_past=None,
        prompt_embeds=None,  # (2, batch_size, head_num, left_prompt_len, head_size)
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run attention over ``hidden_states``.

        When ``layer_past`` is given it is concatenated in front of the new
        keys/values and ``prompt_embeds`` is *not* concatenated again (only its
        length is used for masking). Otherwise ``prompt_embeds[0]`` /
        ``prompt_embeds[1]`` are prepended to the keys / values.

        Returns:
            ``(attn_output, present)`` plus ``attn_weights`` when
            ``output_attentions`` is set; ``present`` is ``None`` unless
            ``use_cache`` is ``True``.

        Raises:
            ValueError: if ``encoder_hidden_states`` is passed to a layer that
                was not built with ``is_cross_attention=True``.
        """
        if encoder_hidden_states is None:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
        else:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                )
            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            cached_key, cached_value = layer_past
            key = torch.cat([cached_key, key], dim=-2)
            value = torch.cat([cached_value, value], dim=-2)
        elif prompt_embeds is not None:
            key = torch.cat([prompt_embeds[0], key], dim=-2)
            value = torch.cat([prompt_embeds[1], value], dim=-2)

        present = (key, value) if use_cache is True else None

        prompt_len = 0 if prompt_embeds is None else prompt_embeds.shape[-2]

        attn_output, attn_weights = self._attn(
            query, key, value, prompt_len, attention_mask, head_mask
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.resid_dropout(self.c_proj(attn_output))

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
class GPT2Block(nn.Module):
    """One pre-LayerNorm GPT-2 layer: self-attention, optional cross-attention, MLP."""

    def __init__(self, config):
        """Build the sub-layers described by ``config``.

        The cross-attention sub-layer (and its LayerNorm) only exists when
        ``config.add_cross_attention`` is truthy.
        """
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon
        # Feed-forward width defaults to four times the hidden size.
        inner_dim = 4 * width if config.n_inner is None else config.n_inner

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(width, eps=eps)

        if config.add_cross_attention:
            self.crossattention = GPT2Attention(config, is_cross_attention=True)
            self.ln_cross_attn = nn.LayerNorm(width, eps=eps)

        self.mlp = GPT2MLP(inner_dim, config)

    def forward(
        self,
        hidden_states,
        layer_past=None,
        prompt_embeds=None,  # (2, batch_size, head_num, prefix_len, head_size)
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Apply the layer to ``hidden_states``.

        ``prompt_embeds`` is forwarded to the self-attention sub-layer only.

        Returns:
            ``(hidden_states, present, attentions..., cross_attentions...)``
            when ``use_cache`` is truthy; the ``present`` entry is dropped
            otherwise.

        Raises:
            ValueError: if ``encoder_hidden_states`` is given but the block has
                no cross-attention sub-layer.
        """
        # --- self-attention -------------------------------------------------
        skip = hidden_states
        self_attn_out = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            prompt_embeds=prompt_embeds,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # self_attn_out: a, present, (attentions)
        extras = self_attn_out[1:]
        hidden_states = self_attn_out[0] + skip

        # --- optional cross-attention ---------------------------------------
        if encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            skip = hidden_states
            cross_out = self.crossattention(
                self.ln_cross_attn(hidden_states),
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = skip + cross_out[0]
            # Keep the cross-attention weights only (present when requested).
            extras = extras + cross_out[2:]

        # --- feed-forward ---------------------------------------------------
        skip = hidden_states
        hidden_states = skip + self.mlp(self.ln_2(hidden_states))

        if not use_cache:
            # Drop the `present` slot when no cache was requested.
            extras = extras[1:]

        return (hidden_states,) + extras  # hidden_states, present, (attentions, cross_attentions)
class GPT2Model(GPT2PreTrainedModel):
    """GPT-2 transformer stack whose self-attention layers accept prompt key/value states.

    ``prompt_embeds`` is indexed by layer and handed to each block, which
    prepends it to that layer's keys and values.
    """

    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        """Create embeddings, ``config.num_hidden_layers`` blocks and the final LayerNorm."""
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    def parallelize(self, device_map=None):
        """Spread the blocks over CUDA devices according to ``device_map``.

        When ``device_map`` is ``None`` one is generated for all visible CUDA
        devices. Embeddings go to the first device, ``ln_f`` to the last.
        """
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        # Load onto devices
        for k, v in self.device_map.items():
            cuda_device = "cuda:" + str(k)
            for block in v:
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    def deparallelize(self):
        """Move every sub-module back to the CPU and clear the device map."""
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        for index, block in enumerate(self.h):
            self.h[index] = block.to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module."""
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        prompt_embeds=None,  # (layer_num, 2, batch_size, head_num, prompt_len, head_dim)
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the transformer stack.

        Args:
            input_ids / inputs_embeds: exactly one of the two must be given.
            past_key_values: per-layer ``(key, value)`` cache from a previous
                call. When the cache was produced with ``prompt_embeds`` it
                already contains the prompt slots.
            prompt_embeds: per-layer prompt key/value states; ``prompt_embeds[i]``
                is passed to block ``i``.
            attention_mask: 2D mask over the *tokens* (1 = attend). A block of
                ones covering the prompt slots is prepended automatically.
            use_cache, output_attentions, output_hidden_states, return_dict:
                default to the corresponding ``config`` values when ``None``.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions`` or, when
            ``return_dict`` is falsy, a tuple of the non-``None`` fields.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, or if a mask is given with a non-positive batch size.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        prompt_len = prompt_embeds.shape[-2] if prompt_embeds is not None else 0

        # GPT2Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            if prompt_embeds is not None:
                # Prompt slots are always attendable. Build the ones with the
                # mask's own dtype/device so the concat needs no type promotion.
                prompt_attention_mask = attention_mask.new_ones((batch_size, prompt_len))
                attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=-1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
            if prompt_embeds is not None:
                # A cache built with prompts stores the prompt slots in front of
                # the token keys; they must not count as token positions, or the
                # default position ids would be shifted by prompt_len.
                past_length = max(past_length - prompt_len, 0)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing and use_cache:
            # Decide this before the output containers are created so that
            # `presents` is None (not an empty tuple) when caching is disabled.
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            layer_prompt = prompt_embeds[i] if prompt_embeds is not None else None

            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # The prompt is concatenated with this layer's keys/values, so
                # it has to live on the same device as well.
                if layer_prompt is not None:
                    layer_prompt = layer_prompt.to(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if use_checkpointing:
                # Imported here because `import torch` alone does not guarantee
                # that the `torch.utils.checkpoint` submodule is loaded.
                from torch.utils.checkpoint import checkpoint

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # inputs follow GPT2Block.forward's positional order up
                        # to encoder_attention_mask; the flags are appended.
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,  # layer_past: no cache under gradient checkpointing
                    layer_prompt,  # prompt_embeds: must occupy its own slot
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    prompt_embeds=layer_prompt,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # The block omits the `present` slot when use_cache is off.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
class PromptGPT2LMHead(GPT2PreTrainedModel):
    """Prompt-aware GPT-2 with a language-modelling head.

    Wraps :class:`GPT2Model` and projects its hidden states to vocabulary
    logits (unless ``conv=False`` is passed to :meth:`forward`).
    """

    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
    config_class = GPT2Config

    def __init__(self, config):
        """Create the transformer and the (bias-free) output projection."""
        super().__init__(config)
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.init_weights()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the transformer over CUDA devices; the head goes to the first device."""
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.transformer.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Move the whole model back to the CPU."""
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        """Return the output projection module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the output projection module."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self, input_ids, past=None, prompt_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        Args:
            input_ids: all token ids generated so far.
            past: key/value cache of the previous step. If it is ``None`` the
                cache is also looked up under ``kwargs["past_key_values"]``,
                for callers that pass it by that name.
            prompt_embeds: prompt key/value states, forwarded unchanged.
            **kwargs: may carry ``token_type_ids``, ``attention_mask``,
                ``position_ids`` and ``use_cache``.

        Returns:
            dict of arguments for :meth:`forward` (always with ``conv=True``).
        """
        if past is None:
            past = kwargs.get("past_key_values", None)

        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            # NOTE: explicitly supplied position_ids are discarded here, as in
            # the original implementation.
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "prompt_embeds": prompt_embeds,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "conv": True,
        }

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        prompt_embeds=None,  # (layer_num, 2, batch_size, head_num, prompt_len, head_dim)
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        conv=True,
        labels=None,
        return_dict=True,
    ) -> CausalLMOutputWithCrossAttentions:
        r"""
        Run the transformer and, when ``conv`` is truthy, the LM head.

        conv (`bool`, defaults to `True`):
            When truthy the hidden states are projected to vocabulary logits and
            a loss is computed if `labels` is given. When falsy the final hidden
            states are returned in the `logits` field and no loss is computed.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Positions whose label is `-100` or equal to `config.pad_token_id` (when that is
            set) are ignored by the loss.
        return_dict (`bool`, defaults to `True`):
            `None` falls back to `config.use_return_dict`. When falsy a plain tuple
            `(loss?, logits, *other transformer outputs)` is returned instead of a
            `CausalLMOutputWithCrossAttentions`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always ask the transformer for a ModelOutput so its fields can be
        # read by name; the caller's return_dict choice is honoured below.
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            prompt_embeds=prompt_embeds,  # (layer_num, 2, batch_size, head_num, prompt_len, head_dim)
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = transformer_outputs.last_hidden_state

        loss = None
        if conv:
            # Set device for model parallelism
            if self.model_parallel:
                torch.cuda.set_device(self.transformer.first_device)
                logits = logits.to(self.lm_head.weight.device)

            logits = self.lm_head(logits)

            if labels is not None:
                # move labels to correct device to enable model parallelism
                labels = labels.to(logits.device)
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Map padding onto the loss' default ignore index (-100) so
                # both are skipped, and so a config without pad_token_id does
                # not pass ignore_index=None to CrossEntropyLoss.
                pad_token_id = self.config.pad_token_id
                if pad_token_id is not None:
                    shift_labels = shift_labels.masked_fill(shift_labels == pad_token_id, -100)
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs.to_tuple()[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor, **kwargs) -> Tuple[Tuple[Any, ...], ...]:
        """Re-index every cached key/value tensor along dim 0 with ``beam_idx``."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )