from abc import ABC, abstractmethod
from typing import List, NamedTuple, Optional, TypeVar, Generic, Sequence, Union
from attr import define
import re
import torch
from experimaestro import Config
from xpm_torch.utils.utils import Initializable
from xpm_torch.utils import to_device
from xpmir.text.utils import lengthToMask
T = TypeVar("T")


def opt_slice(
    x: Optional[Sequence[T]], ix: Union[int, slice, list, torch.Tensor]
) -> Optional[Sequence[T]]:
    """Index ``x`` with ``ix``, passing ``None`` through unchanged.

    Tensors are indexed natively by torch. Plain sequences (e.g. Python
    lists) are indexed so as to mirror torch semantics, so that all the
    fields of a batch stay aligned:

    - a list of indices (or a 1-D integer tensor) gathers the selected items;
    - a list of ``bool`` (or a 1-D boolean tensor) is a mask keeping the
      items whose flag is ``True``;
    - an integer or a slice is applied directly.

    Arguments:
        x: the sequence or tensor to index, or ``None``
        ix: an integer, a slice, a list of indices/booleans, or a tensor

    Returns:
        ``None`` if ``x`` is ``None``, otherwise the selected element(s).

    Raises:
        IndexError: if a boolean mask does not have the same length as ``x``.
    """
    if x is None:
        return None
    if isinstance(x, torch.Tensor):
        return x[ix]

    # A 1-D tensor index cannot be used on a Python list directly: convert it
    # to a list of ints (index tensor) or of bools (mask tensor).
    if isinstance(ix, torch.Tensor) and ix.dim() > 0:
        ix = ix.tolist()

    if isinstance(ix, list):
        if ix and all(isinstance(flag, bool) for flag in ix):
            # torch treats a list of bools as a mask, not as indices 0/1
            if len(ix) != len(x):
                raise IndexError(
                    f"boolean mask of length {len(ix)} does not match "
                    f"sequence of length {len(x)}"
                )
            return [item for item, keep in zip(x, ix) if keep]
        return [x[i] for i in ix]
    return x[ix]
class TokenizedTexts(NamedTuple):
    """Structured Tokenized texts output, that can be cast to device"""

    tokens: Optional[List[List[str]]]
    """The list of tokens"""

    ids: torch.LongTensor
    """A matrix containing the token IDs"""

    lens: List[int]
    """the lengths of each text (in tokens)"""

    mask: Optional[torch.LongTensor]
    """The mask for the ids matrix"""

    token_type_ids: Optional[torch.LongTensor] = None
    """Type of each token"""

    def __len__(self):
        """Returns ``len(self.ids)``, the size of the first dimension of ``ids``"""
        return len(self.ids)

    def __getitem__(self, ix):
        """Indexes every field with ``ix`` along its first dimension

        ``None`` fields stay ``None``; plain lists are indexed with
        ``opt_slice`` so that they stay aligned with the tensors.
        """
        return TokenizedTexts(
            opt_slice(self.tokens, ix),
            self.ids[ix],
            opt_slice(self.lens, ix),
            opt_slice(self.mask, ix),
            opt_slice(self.token_type_ids, ix),
        )

    def to(self, device: torch.device):
        """Returns a copy whose tensors are on ``device``

        Returns ``self`` when ``ids`` is already on ``device``. ``tokens`` and
        ``lens`` are shared with the copy, not duplicated.

        Arguments:
            device: the target device (a ``torch.device`` or a string such
                as ``"cpu"``)
        """
        # torch.device instances are not interned, so an identity test
        # (``is``) would almost never hold: compare the devices by value.
        if (
            isinstance(device, (torch.device, str))
            and torch.device(device) == self.ids.device
        ):
            return self
        return TokenizedTexts(
            self.tokens,
            self.ids.to(device),
            self.lens,
            to_device(self.mask, device),
            to_device(self.token_type_ids, device),
        )

    def subset(self, ids: torch.Tensor) -> "TokenizedTexts":
        """Returns a subset of the tokenized texts

        Arguments:
            ids: a boolean mask or a list of indices to select
        """
        if isinstance(ids, torch.Tensor):
            # A boolean mask selects its True entries; an index tensor selects
            # one entry per element (summing it would add up index values).
            length = int(ids.sum().item()) if ids.dtype == torch.bool else ids.numel()
        elif len(ids) > 0 and all(isinstance(flag, bool) for flag in ids):
            # A list of bools is a mask, as in torch indexing
            length = sum(ids)
        else:
            length = len(ids)

        return TokenizedTexts(
            # NOTE(review): ``ids`` selects entries of the outer ``tokens``
            # list here, but the *second* dimension of the tensors below —
            # confirm which dimension the selection is meant to apply to.
            [self.tokens[i] for i in ids] if self.tokens else None,
            self.ids[:, ids],
            # Every text gets the number of selected columns as its length
            [length for _ in self.lens],
            self.mask[:, ids] if self.mask is not None else None,
            self.token_type_ids[:, ids] if self.token_type_ids is not None else None,
        )
class Tokenizer(Config):
    """
    Represents a vocabulary and a tokenization method

    **Deprecated**: Use TokenizerBase instead
    """

    def tokenize(self, text: str):
        """
        Meant to be overridden to provide vocab-specific tokenization when
        necessary, e.g., BERT's WordPiece tokenization

        The default implementation lowercases the text, replaces every
        character outside ``[a-z0-9]`` by a space and splits on whitespace.
        """
        text = text.lower()
        text = re.sub(r"[^a-z0-9]", " ", text)
        return text.split()

    def pad_sequences(self, tokensList: List[List[int]], batch_first=True, maxlen=0):
        """Pads sequences of token ids into a single ``LongTensor``

        Arguments:
            tokensList: one list of token ids per text
            batch_first: if True the output has shape ``(len(tokensList),
                width)``, otherwise ``(width, len(tokensList))``
            maxlen: maximum number of tokens kept per text; ``None`` or ``0``
                means no limit. The width is ``min(maxlen, longest sequence)``.

        Returns:
            a tuple ``(ids, lens)`` where ``ids`` is padded with 0 and moved to
            ``self._dummy_params.device``, and ``lens`` gives the number of
            tokens actually stored for each text (i.e. clipped to the width,
            so that a mask built from ``lens`` matches ``ids``)
        """
        # NOTE(review): padding is hard-coded to 0 rather than using
        # ``pad_tokenid`` — confirm this is intended for all subclasses.
        padding_value = 0
        full_lens = [len(s) for s in tokensList]
        longest = max(full_lens, default=0)
        # ``None`` and ``0`` (the default) both mean "no limit"
        maxlen = min(maxlen, longest) if maxlen else longest
        lens = [min(n, maxlen) for n in full_lens]

        n_texts = len(tokensList)
        shape = (n_texts, maxlen) if batch_first else (maxlen, n_texts)
        out_tensor = torch.full(shape, padding_value, dtype=torch.long)
        for i, (tokens, n) in enumerate(zip(tokensList, lens)):
            values = torch.tensor(tokens[:n], dtype=torch.long)
            if batch_first:
                out_tensor[i, :n] = values
            else:
                out_tensor[:n, i] = values

        # NOTE(review): ``_dummy_params`` is not defined in this class; it is
        # presumably provided by subclasses — confirm.
        return out_tensor.to(self._dummy_params.device), lens

    def batch_tokenize(
        self, texts: List[str], batch_first=True, maxlen=None, mask=False
    ) -> TokenizedTexts:
        """
        Returns tokenized texts

        Arguments:
            texts: the texts, each tokenized with :meth:`tokenize` and mapped
                to ids with :meth:`tok2id`
            batch_first: passed to :meth:`pad_sequences`
            maxlen: maximum number of tokens per text (``None`` for no limit),
                passed to :meth:`pad_sequences`
            mask: Whether a mask should be computed
        """
        toks = [self.tokenize(text) for text in texts]
        tokids, lens = self.pad_sequences(
            [[self.tok2id(t) for t in tok] for tok in toks],
            batch_first=batch_first,
            maxlen=maxlen,
        )
        # Keep the mask on the same device as the ids
        _mask = (
            lengthToMask(torch.LongTensor(lens)).to(tokids.device) if mask else None
        )
        return TokenizedTexts(toks, tokids, lens, _mask)

    @property
    def pad_tokenid(self) -> int:
        raise NotImplementedError()

    def tok2id(self, tok: str) -> int:
        """
        Converts a token to an integer id
        """
        raise NotImplementedError()

    def id2tok(self, idx: int) -> str:
        """
        Converts an integer id to a token
        """
        raise NotImplementedError()

    def lexicon_size(self) -> int:
        """
        Returns the number of items in the lexicon
        """
        raise NotImplementedError()
# Type variable for the inputs accepted by a tokenizer (see TokenizerBase.tokenize)
TokenizerInput = TypeVar("TokenizerInput")
# Type variable for the tokenizer output; bound to TokenizedTexts, so any
# concrete output type must be TokenizedTexts or a subtype of it
TokenizerOutput = TypeVar("TokenizerOutput", bound=TokenizedTexts)
@define
class TokenizerOptions:
    """Options that can be passed to ``TokenizerBase.tokenize``"""

    # Presumably the maximum number of tokens per input; ``None`` presumably
    # means no limit — confirm against the tokenizer implementations
    max_length: Optional[int] = None
    # Presumably whether the output should include a mask — confirm
    return_mask: Optional[bool] = True
    # Presumably whether the output should include the lengths — confirm
    return_length: Optional[bool] = True
class TokenizerBase(
    Config, Initializable, Generic[TokenizerInput, TokenizerOutput], ABC
):
    """Base tokenizer

    Generic over the input type (``TokenizerInput``) and the output type
    (``TokenizerOutput``, bound to ``TokenizedTexts``). Concrete tokenizers
    must implement :meth:`tokenize`, :meth:`vocabulary_size`, :meth:`tok2id`
    and :meth:`id2tok`.
    """

    def __initialize__(self):
        super().__initialize__()

    @abstractmethod
    def tokenize(
        self, inputs: TokenizerInput, options: Optional[TokenizerOptions] = None
    ) -> TokenizerOutput:
        """Encodes the inputs

        Arguments:
            inputs: the inputs to encode
            options: tokenization options; how ``None`` is handled is left to
                the implementation
        """
        ...

    @abstractmethod
    def vocabulary_size(self) -> int:
        """
        Returns the number of tokens
        """
        ...

    @abstractmethod
    def tok2id(self, tok: str) -> int:
        """
        Converts a token to an integer id
        """
        ...

    @abstractmethod
    def id2tok(self, idx: int) -> str:
        """
        Converts an integer id to a token
        """
        ...