import logging
import dataclasses
from typing import Optional, List, NamedTuple, Tuple

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Config,
)
from experimaestro import Param, LightweightTask

from xpmir.learning import ModuleInitOptions, ModuleInitMode
from xpmir.text.encoders import TokenizedTexts
from xpmir.distributed import DistributableModel
from . import (
    ConditionalGenerator,
    GenerateOptions,
    BeamSearchGenerationOptions,
    StepwiseGenerator,
)


class GeneratorForwardOutput(NamedTuple):
    """The forward output of the generative retrieval"""

    logits: torch.Tensor
    past_key_values: Optional[torch.Tensor] = None


class FullSequenceGenerationOutput(NamedTuple):
    """The output for the generate method"""

    sequences: torch.Tensor
    """The returned sequences
    shape: [bs * num_sequences, max_depth]"""

    output_mask: torch.Tensor
    """A mask for the output sequences
    shape: [bs * num_sequences, max_depth]"""

    transition_scores: Optional[torch.Tensor] = None
    """The conditional probabilities (log, normalized) of the tokens in the sequences
    shape: [bs * num_sequences, max_depth]"""

    all_scores: Optional[Tuple[torch.Tensor]] = None
    """All the probabilities (log, normalized), as a tuple of length max_depth
    where each tensor has shape [bs * num_sequences, vocab_size]"""

    sequence_scores: Optional[torch.Tensor] = None
    """The log-probability of each full sequence
    shape: [bs * num_sequences]"""


class T5ConditionalGenerator(ConditionalGenerator, DistributableModel):
    """Conditional text generation based on a HuggingFace T5 model"""

    hf_id: Param[str]
    """The HuggingFace identifier (to configure the model)"""

    def stepwise_iterator(self) -> StepwiseGenerator:
        return T5StepwiseGenerator(self)

    def __initialize__(self, options: ModuleInitOptions):
        assert options.mode != ModuleInitMode.RANDOM, "Random mode not handled (yet)"

        super().__initialize__(options)

        # Easy and hacky way to get the device
        self._dummy_params = nn.Parameter(torch.Tensor())

        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id, use_fast=True)
        self.config = AutoConfig.from_pretrained(self.hf_id)
        self.model = self.initialize_model(options)

        self.pad_token_id = self.model.config.pad_token_id
        self.decoder_start_token_id = self.model.config.decoder_start_token_id
        self.eos_token_id = self.model.config.eos_token_id

        self.encoder = self.model.get_encoder()

    def initialize_model(self, options: ModuleInitOptions):
        return T5ForConditionalGeneration(self.config)

    @property
    def device(self):
        return self._dummy_params.device

    def batch_tokenize(
        self,
        texts: List[str],
        batch_first=True,
        maxlen=None,
        mask=False,
    ) -> TokenizedTexts:
        """Tokenize the input texts"""
        if maxlen is None:
            maxlen = self.tokenizer.model_max_length
        else:
            maxlen = min(maxlen, self.tokenizer.model_max_length)

        assert batch_first, "Batch first is the only option"

        r = self.tokenizer(
            list(texts),
            max_length=maxlen,
            truncation=True,
            padding=True,
            return_tensors="pt",
            return_length=True,
            return_attention_mask=mask,
        )
        return TokenizedTexts(
            None,
            r["input_ids"].to(self.device),
            r.get("length", None),
            r["attention_mask"].to(self.device) if mask else None,
            r.get("token_type_ids", None),
        )
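
    # A minimal usage sketch (hypothetical names; assumes `generator` is an
    # initialized T5ConditionalGenerator):
    #
    #     tokenized = generator.batch_tokenize(["a query", "another query"], mask=True)
    #     tokenized.ids   # LongTensor of shape [2, seq_len], on generator.device
    #     tokenized.mask  # attention mask of the same shape (since mask=True)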

    def encode(self, texts: List[str]):
        """Return the encoder output and the input mask for the given texts

        Pre-computing these accelerates the autoregressive generation
        procedure, since the encoder only needs to be run once."""
        tokenized = self.batch_tokenize(texts, maxlen=512, mask=True)
        encoder_output = self.encoder(
            tokenized.ids,
            attention_mask=tokenized.mask,
            return_dict=True,
        )
        return encoder_output, tokenized.mask
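
    # Sketch: the encoder output is computed once per batch of texts and can
    # be reused at every decoding step (see T5StepwiseGenerator below):
    #
    #     encoder_output, mask = generator.encode(["a query"])
    #     encoder_output.last_hidden_state  # shape [1, seq_len, d_model]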

    def forward(
        self,
        encoder_attention_mask,  # shape [bs, seq] with 0 or 1
        encoder_outputs,
        decoder_input_ids=None,  # if given, shape [bs]
        past_key_values=None,
    ):
        """Get the logits for the next token from the decoder"""
        bs = encoder_outputs.last_hidden_state.shape[0]
        if past_key_values is None:
            # First step: start from the decoder start token
            decoder_input_ids = (
                torch.ones((bs, 1), dtype=torch.long).to(self.device)
                * self.decoder_start_token_id
            )
        else:
            if decoder_input_ids is None:
                raise ValueError("decoder_input_ids of the previous step is not given")
            decoder_input_ids = decoder_input_ids.unsqueeze(-1)

        # Do a forward pass to get the next token; returns three non-None
        # values: past_key_values, last_hidden_state, encoder_last_hidden_state
        decoder_output = self.model(
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        logits = decoder_output.logits[:, -1, :]  # shape [bs, vocab_size]
        return GeneratorForwardOutput(
            logits=logits, past_key_values=decoder_output.past_key_values
        )
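
    # Sketch of an incremental (greedy) decoding loop built on this forward
    # method; `generator` and `max_depth` are hypothetical names:
    #
    #     encoder_output, mask = generator.encode(["a query"])
    #     out = generator(mask, encoder_output)  # first step: decoder start token
    #     tokens = out.logits.argmax(-1)         # shape [bs]
    #     for _ in range(max_depth - 1):
    #         out = generator(mask, encoder_output, tokens,
    #                         past_key_values=out.past_key_values)
    #         tokens = out.logits.argmax(-1)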

    def generate(
        self, inputs: List[str], options: Optional[GenerateOptions] = None
    ) -> FullSequenceGenerationOutput:
        tokenized = self.batch_tokenize(inputs, mask=True)

        if isinstance(options, BeamSearchGenerationOptions):
            generate_options_kwargs = dataclasses.asdict(options)
            res = self.model.generate(
                input_ids=tokenized.ids,
                attention_mask=tokenized.mask,
                **generate_options_kwargs,
            )
        else:
            raise NotImplementedError(
                f"Generation options are not supported for {options.__class__}"
            )

        if options.return_dict_in_generate:
            output_mask = torch.where(res.sequences != self.pad_token_id, 1, 0).to(
                self.device
            )
            if self.pad_token_id == self.decoder_start_token_id:
                output_mask[:, 0] = 1
            if not options.output_scores:
                return FullSequenceGenerationOutput(
                    sequences=res.sequences,
                    output_mask=output_mask,
                )
            else:
                # For older versions of transformers, this should be
                # compute_transition_beam_scores
                transition_scores = self.model.compute_transition_scores(
                    res.sequences,
                    res.scores,
                    res.beam_indices,
                    # for beam search, the logits are already normalized
                    normalize_logits=False,
                )
                sequence_scores = torch.sum(transition_scores, dim=-1)
                return FullSequenceGenerationOutput(
                    sequences=res.sequences,
                    output_mask=output_mask,
                    transition_scores=transition_scores,
                    sequence_scores=sequence_scores,
                )
        else:
            output_mask = torch.where(res != self.pad_token_id, 1, 0).to(self.device)
            if self.pad_token_id == self.decoder_start_token_id:
                output_mask[:, 0] = 1
            return FullSequenceGenerationOutput(
                sequences=res,
                output_mask=output_mask,
            )
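
    # Sketch: beam search generation. The BeamSearchGenerationOptions fields
    # are assumed (not confirmed here) to mirror HuggingFace `generate`
    # keyword arguments such as num_beams and num_return_sequences:
    #
    #     options = BeamSearchGenerationOptions(
    #         num_beams=4, num_return_sequences=4,
    #         return_dict_in_generate=True, output_scores=True,
    #     )
    #     output = generator.generate(["a query"], options)
    #     output.sequences        # [bs * num_return_sequences, max_depth]
    #     output.sequence_scores  # log-probability of each returned sequence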

    def batch_decode(self, generate_output: FullSequenceGenerationOutput) -> List[str]:
        """Decode the sequences into meaningful texts"""
        return self.tokenizer.batch_decode(
            generate_output.sequences, skip_special_tokens=True
        )

    def distribute_models(self, update):
        self.encoder = update(self.model.get_encoder())
        self.model = update(self.model)


class T5StepwiseGenerator(StepwiseGenerator):
    def __init__(self, id_generator: ConditionalGenerator):
        super().__init__()
        # The generator used to produce the next step's tokens
        self.id_generator = id_generator

    def init(self, texts: List[str]):
        """Initialize the inner state for further iterations, and return the
        initial decoder input tokens"""
        self.encoder_output, self.attention_mask = self.id_generator.encode(texts)
        self.past_key_values = None

    def step(self, decoder_input_tokens) -> torch.Tensor:  # input shape [bs]
        """Return the distribution over the next tokens (B x V) by performing
        a stepwise iteration"""
        forward_output: GeneratorForwardOutput = self.id_generator(
            self.attention_mask,
            self.encoder_output,
            decoder_input_tokens,
            past_key_values=self.past_key_values,
        )
        self.past_key_values = forward_output.past_key_values
        return forward_output.logits
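

# Sketch: driving the stepwise generator by hand (greedy decoding); assumes
# `generator` is an initialized T5ConditionalGenerator and `max_depth` is the
# maximum sequence length (both hypothetical names):
#
#     stepwise = generator.stepwise_iterator()
#     stepwise.init(["a query"])
#     tokens = stepwise.step(None).argmax(-1)  # first step: decoder start token
#     for _ in range(max_depth - 1):
#         tokens = stepwise.step(tokens).argmax(-1)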


class T5ForIdentifierGeneration(T5ForConditionalGeneration):
    """T5-based identifier generation

    The class modifies T5 to use a custom vocabulary in the decoder
    """

    def __init__(self, config: T5Config, decoder_outdim: int):
        # decoder_outdim does not include the EOS and PAD tokens
        self.decoder_outdim = decoder_outdim

        # Modification of the config according to our needs
        config.pad_token_id = self.decoder_outdim + 1
        config.decoder_start_token_id = self.decoder_outdim + 1
        config.eos_token_id = self.decoder_outdim

        # Keep the config at hand
        self.config = config
        super().__init__(self.config)

        # Modify the LM head: it scores the identifier tokens and EOS
        # (PAD is never generated)
        self.lm_head = nn.Linear(
            self.lm_head.in_features, self.decoder_outdim + 1, bias=False
        )

        # Keep the full vocabulary on the encoder side
        encoder_embeddings = nn.Embedding(self.config.vocab_size, self.config.d_model)
        self.config.vocab_size = self.decoder_outdim + 1
        self.get_encoder().set_input_embeddings(encoder_embeddings)

        # Modify the decoder vocabulary: one more token for PAD (which also
        # serves as the decoder start token)
        decoder_embeddings = nn.Embedding(
            self.decoder_outdim + 2, self.config.d_model, padding_idx=decoder_outdim + 1
        )
        self.get_decoder().set_input_embeddings(decoder_embeddings)
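
        # Worked example of the resulting token layout for decoder_outdim=10
        # (e.g. digit identifiers):
        #   - identifier tokens: ids 0..9
        #   - EOS: id 10 (decoder_outdim)
        #   - PAD / decoder start: id 11 (decoder_outdim + 1)
        # The decoder embeddings thus hold 12 entries, while the LM head only
        # scores the 11 tokens that can actually be generated.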

    def forward(self, **kwargs):
        return super().forward(**kwargs)


class T5IdentifierGenerator(T5ConditionalGenerator):
    """Generates token identifiers based on T5 models"""

    decoder_outdim: Param[int] = 10
    """The decoder output dimension for the T5 model, used to rebuild the
    lm_head and the decoder embeddings; this number does not include the
    PAD token and the EOS token
    """

    def initialize_model(self, options: ModuleInitOptions):
        return T5ForIdentifierGeneration(self.config, self.decoder_outdim)


class T5ForConditionalCustomGeneration(T5ForConditionalGeneration):
    """T5-based model with a custom output vocabulary"""

    def __init__(self, config: T5Config, decoder_outdim: int):
        # decoder_outdim does not include the EOS and PAD tokens
        self.config = config
        self.decoder_outdim = decoder_outdim
        config.decoder_start_token_id = self.decoder_outdim - 1
        super().__init__(self.config)

        # Modify the LM head
        self.lm_head = nn.Linear(
            self.lm_head.in_features, self.decoder_outdim, bias=False
        )

        # Keep the full vocabulary on the encoder side
        encoder_embeddings = nn.Embedding(self.config.vocab_size, self.config.d_model)
        self.get_encoder().set_input_embeddings(encoder_embeddings)

        # Modify the decoder vocabulary
        self.config.vocab_size = self.decoder_outdim
        decoder_embeddings = nn.Embedding(self.decoder_outdim, self.config.d_model)
        self.get_decoder().set_input_embeddings(decoder_embeddings)

    def forward(self, **kwargs):
        return super().forward(**kwargs)
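
# Note (an observation from the code above, not a documented contract): the
# decoder start token id is set to decoder_outdim - 1, i.e. the *last* entry
# of the custom vocabulary, so the final token of the list passed to
# T5CustomOutputGenerator effectively plays the role of the start token.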


class T5CustomOutputGenerator(T5ConditionalGenerator):
    """Generates tokens from a custom output vocabulary, based on T5 models"""

    #: List of tokens for the output
    tokens: Param[List[str]]

    def initialize_model(self, options: ModuleInitOptions):
        return T5ForConditionalCustomGeneration(self.config, len(self.tokens))


class LoadFromT5(LightweightTask):
    """Load parameters from a pre-trained T5 model"""

    t5_model: Param[T5ConditionalGenerator]
    """The target model"""

    def execute(self):
        self.t5_model.initialize(ModuleInitMode.DEFAULT.to_options())

        # Load the pre-trained T5 model from the checkpoint
        logging.info("Loading HuggingFace T5 from checkpoint %s", self.t5_model.hf_id)
        t5_model = T5ForConditionalGeneration.from_pretrained(self.t5_model.hf_id)

        # Adapt the state_dict for the lm_head and the decoder embeddings
        state_dict = t5_model.state_dict()
        if isinstance(self.t5_model, T5IdentifierGenerator):
            # Just forget about the lm_head weights
            del state_dict["lm_head.weight"]

            # ... and use a randomly initialized T5 decoder
            decoder_key_names = [name for name in state_dict.keys() if "decoder" in name]
            for name in decoder_key_names:
                del state_dict[name]
        elif isinstance(self.t5_model, T5CustomOutputGenerator):
            # Get the token ids from the tokenizer
            token_ids = []
            for token in self.t5_model.tokens:
                ids = self.t5_model.tokenizer.encode(token, add_special_tokens=False)
                if len(ids) != 1:
                    raise ValueError(f"Token {token} is made of {len(ids)} subtokens")
                token_ids.append(ids[0])

            # ... and restrict the output vocabulary to the possible tokens
            state_dict["lm_head.weight"] = t5_model.lm_head.weight.detach()[
                (tuple(token_ids),)
            ]
            state_dict["decoder.embed_tokens.weight"] = t5_model.lm_head.weight.detach()[
                (tuple(token_ids),)
            ]

        logging.info("Loading the state dict into the custom T5")
        self.t5_model.model.load_state_dict(state_dict, strict=False)
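

# Usage note (a sketch, not the definitive wiring): LoadFromT5 is meant to be
# run after the generator object has been created, so that the pre-trained
# weights are copied into it, e.g. as an initialization task within an
# experimaestro experiment:
#
#     generator = T5IdentifierGenerator(hf_id="t5-base", decoder_outdim=10)
#     load = LoadFromT5(t5_model=generator)
#     # ... then schedule `load` as a (lightweight) pre-task of the learning
#     # task; the exact submission call depends on the experiment setup.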