Source code for xpmir.mlm.samplers

from typing import List, Optional

import numpy as np

from datamaestro_text.data.ir import DocumentStore

from xpmir.letor.records import DocumentRecord
from xpmir.letor.samplers import Sampler
from experimaestro import Param

from xpmir.utils.iter import RandomSerializableIterator
from xpmir.utils.utils import easylog
import logging

logger = easylog()

logger = logging.getLogger()
logger.setLevel(logging.INFO)


[docs]class MLMSampler(Sampler): """Sample texts from various sources This sampler can be used for Masked Language Modeling to sample from several datasets. """ datasets: Param[List[DocumentStore]] """Lists of datasets to sample from""" _stores: List[DocumentStore] def initialize(self, random: Optional[np.random.RandomState]): super().initialize(random) def record_iter(self) -> RandomSerializableIterator[DocumentRecord]: def iter(random: np.random.RandomState): while True: # We could imagine setting weights here to give more importance # to one dataset document_count = [dataset.documentcount for dataset in self.datasets] choice = random.randint(0, len(self.datasets)) if document_count[choice] < 10_000_000: document = self.datasets[choice].document_int( self.random.randint(0, self.datasets[choice].documentcount) ) yield document else: # FIXME: it makes the iter not fully serializable yield from self.datasets[choice].iter() return RandomSerializableIterator(self.random, iter)