import logging
from transformers import AutoConfig, AutoTokenizer, T5ForConditionalGeneration, T5Config
from experimaestro import Param, LightweightTask
from typing import Optional, List, NamedTuple
import torch
from torch import nn
from xpmir.learning import ModuleInitOptions, ModuleInitMode
from xpmir.letor.records import TokenizedTexts
from xpmir.distributed import DistributableModel
from . import IdentifierGenerator, StepwiseGenerator
class GeneratorForwardOutput(NamedTuple):
    """The forward output of the generative retrieval

    Bundles the last-step decoder logits together with the (optional)
    key/value cache used to speed up autoregressive decoding.
    """

    # Logits over the decoder vocabulary for the last generated position
    # (note: annotated with torch.Tensor, the type — torch.tensor is a factory)
    logits: torch.Tensor

    # Cache returned by the HF decoder (``past_key_values``); None on the
    # first decoding step
    past_key_values: Optional[torch.Tensor] = None
class CustomOutputT5(T5ForConditionalGeneration):
    """T5-based identifier generation

    The class modifies T5 to use a custom vocabulary in the decoder:
    identifier tokens ``0 .. decoder_outdim - 1``, an EOS token
    (``decoder_outdim``) and a pad token that doubles as the decoder
    start token (``decoder_outdim + 1``).
    """

    def __init__(self, config: T5Config, decoder_outdim: int):
        # not including the eos and pad
        self.decoder_outdim = decoder_outdim

        # modification of the config according to our needs: pad and
        # decoder-start share one id, EOS sits just before them
        # (must happen BEFORE super().__init__, which reads the config)
        config.pad_token_id = self.decoder_outdim + 1
        config.decoder_start_token_id = self.decoder_outdim + 1
        config.eos_token_id = self.decoder_outdim

        # save
        self.config = config
        super().__init__(self.config)

        # Modify LM head: predict identifier tokens + EOS only
        # (+1, not +2: the pad token is never predicted)
        self.lm_head = nn.Linear(
            self.lm_head.in_features, self.decoder_outdim + 1, bias=False
        )

        # Make the input embedding has the name of encoder.embed_tokens;
        # built with the *original* vocab_size since the encoder still
        # consumes natural-language token ids
        encoder_embeddings = nn.Embedding(self.config.vocab_size, self.config.d_model)
        self.get_encoder().set_input_embeddings(encoder_embeddings)

        # From here on the config reports the decoder-side vocabulary size
        # (must happen AFTER the encoder embeddings were created above)
        self.config.vocab_size = self.decoder_outdim + 1

        # Modify the decoder vocabulary (+2 = identifiers + eos + pad;
        # pad is the padding_idx so its embedding stays zero)
        decoder_embeddings = nn.Embedding(
            self.decoder_outdim + 2, self.config.d_model, padding_idx=decoder_outdim + 1
        )
        self.get_decoder().set_input_embeddings(decoder_embeddings)

    def forward(self, **kwargs):
        # Plain delegation to T5ForConditionalGeneration.forward
        return super().forward(**kwargs)
class T5StepwiseGenerator(StepwiseGenerator):
    """Step-by-step identifier generation driven by a T5-based generator

    Wraps an :class:`IdentifierGenerator` and keeps the decoding state
    (encoder output, attention mask and key/value cache) between steps.
    """

    def __init__(self, id_generator: IdentifierGenerator):
        super().__init__()
        # Model used to produce the distribution over the next tokens
        self.id_generator = id_generator

    def init(self, texts: List[str]):
        """Initialize some inner states for further iterations, and return
        the initial decoder input tokens"""
        encoded = self.id_generator.encode(texts)
        self.encoder_output, self.attention_mask = encoded
        # No cache yet: the first step starts from scratch
        self.past_key_values = None

    def step(self, decoder_input_tokens) -> torch.Tensor:  # input shape [bs]
        """Returns the distribution over next tokens (BxV) by performing a
        stepwise iteration"""
        output: GeneratorForwardOutput = self.id_generator(
            self.attention_mask,
            self.encoder_output,
            decoder_input_tokens,
            past_key_values=self.past_key_values,
        )
        # Keep the updated cache around for the following step
        self.past_key_values = output.past_key_values
        return output.logits
class T5IdentifierGenerator(IdentifierGenerator, DistributableModel):
    """generate the id of the token based on t5-based models"""

    hf_id: Param[str]
    """The HuggingFace identifier (to configure the model)"""

    decoder_outdim: Param[int] = 10
    """The decoder output dimension for the t5 model, use it to
    rebuild the lm_head and the decoder embedding, this number
    doesn't include the pad token and the eos token
    """

    def stepwise_iterator(self) -> StepwiseGenerator:
        # Each call returns a fresh stepwise generator bound to this model
        return T5StepwiseGenerator(self)

    def __initialize__(self, options: ModuleInitOptions):
        """Build the tokenizer and the custom-vocabulary T5 model

        :raises AssertionError: when random initialization is requested
            (not supported yet)
        """
        assert options.mode != ModuleInitMode.RANDOM, "Random mode not handled (yet)"
        super().__initialize__(options)

        # Easy and hacky way to get the device
        self._dummy_params = nn.Parameter(torch.Tensor())

        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id, use_fast=True)
        self.config = AutoConfig.from_pretrained(self.hf_id)
        self.model = CustomOutputT5(self.config, self.decoder_outdim)

        # Shortcuts to the special token ids configured by CustomOutputT5
        self.pad_token_id = self.model.config.pad_token_id
        self.decoder_start_token_id = self.model.config.decoder_start_token_id
        self.eos_token_id = self.model.config.eos_token_id

    @property
    def device(self):
        """Device on which the module parameters live"""
        return self._dummy_params.device

    def batch_tokenize(
        self,
        texts: List[str],
        batch_first=True,
        maxlen=None,
        mask=False,
    ) -> TokenizedTexts:
        """Tokenize the input text

        :param texts: the texts to tokenize
        :param batch_first: must be True (the only supported layout)
        :param maxlen: optional cap on sequence length, always clamped to
            the tokenizer's model maximum
        :param mask: whether to also return the attention mask
        """
        if maxlen is None:
            maxlen = self.tokenizer.model_max_length
        else:
            maxlen = min(maxlen, self.tokenizer.model_max_length)

        assert batch_first, "Batch first is the only option"

        r = self.tokenizer(
            list(texts),
            max_length=maxlen,
            truncation=True,
            padding=True,
            return_tensors="pt",
            return_length=True,
            return_attention_mask=mask,
        )
        return TokenizedTexts(
            None,
            r["input_ids"].to(self.device),
            r.get("length", None),
            # the mask is only present when requested
            r["attention_mask"].to(self.device) if mask else None,
            r.get("token_type_ids", None),
        )

    def encode(self, texts: List[str]):
        """Returns the encoder_output and the input mask for the given text,
        which could accelerate the autoregressive generation procedure
        (the encoder runs only once per batch of texts)"""
        encoder = self.model.get_encoder()
        tokenized = self.batch_tokenize(texts, maxlen=512, mask=True)
        encoder_output = encoder(
            tokenized.ids,
            attention_mask=tokenized.mask,
            return_dict=True,
        )
        return encoder_output, tokenized.mask

    def forward(
        self,
        encoder_attention_mask,  # shape [bs, seq] with 0 or 1
        encoder_outputs,
        decoder_input_ids=None,  # if given, shape [bs]
        past_key_values=None,
    ):
        """Get the logits from the decoder for the next token

        :raises ValueError: when a key/value cache is given without the
            previous step's decoder input ids
        """
        bs = encoder_outputs.last_hidden_state.shape[0]
        if past_key_values is None:
            # First step: feed the decoder start token to every sequence
            decoder_input_ids = (
                torch.ones((bs, 1), dtype=torch.long).to(self.device)
                * self.decoder_start_token_id
            )
        elif decoder_input_ids is None:
            raise ValueError("decoder_input_ids of the previous layer is not given")
        else:
            decoder_input_ids = decoder_input_ids.unsqueeze(-1)

        # Do a forward pass to get the next token
        # returns three not None values:
        # past_key_values, last_hidden_state, encoder_last_hidden_state
        decoder_output = self.model(
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )

        # Only the distribution over the last position is of interest
        logits = decoder_output.logits[:, -1, :]  # shape [bs, decoder_outdim+1]
        return GeneratorForwardOutput(
            logits=logits, past_key_values=decoder_output.past_key_values
        )

    def distribute_models(self, update):
        # Let the distribution strategy wrap/replace the inner HF model
        self.model = update(self.model)
class LoadFromT5(LightweightTask):
    """Load parameters from a pre-trained T5 model"""

    t5_model: Param[T5IdentifierGenerator]
    """the target"""

    def execute(self):
        """Copy the pre-trained T5 weights into the target model

        The LM head and all decoder weights are dropped from the loaded
        state dict so that the custom (identifier-vocabulary) decoder of
        :class:`CustomOutputT5` keeps its random initialization.
        """
        self.t5_model.initialize(None)

        # Load from checkpoint
        logging.info("Loading huggingface T5 from checkpoint %s", self.t5_model.hf_id)

        # Load the pre-trained model
        t5_model = T5ForConditionalGeneration.from_pretrained(self.t5_model.hf_id)

        # Change the state_dict for the lm_head and the decoder embedding
        state_dict = t5_model.state_dict()
        del state_dict["lm_head.weight"]

        # use random initialized t5 decoder: remove all decoder weights
        decoder_key_names = [name for name in state_dict.keys() if "decoder" in name]
        for name in decoder_key_names:
            del state_dict[name]

        # strict=False: the target's decoder/lm_head keys are absent on purpose
        logging.info("Loading state dict into CustomOutputT5")
        self.t5_model.model.load_state_dict(state_dict, strict=False)