Source code for xpmir.conversation.learning

from functools import cached_property
from typing import Iterator, List

import numpy as np
from datamaestro.record import Record
from datamaestro_text.data.conversation import (
    ConversationDataset,
    ConversationHistoryItem,
    EntryType,
)
from experimaestro import Config, Param

from xpmir.learning.base import BaseSampler, SampleIterator
from xpmir.utils.iter import RandomSerializableIterator


class DatasetConversationBase(Config):
    """Shared configuration exposing the user-query records of conversations.

    Every user-query entry found in the configured datasets is turned into a
    record enriched with the conversation history leading up to it.
    """

    datasets: Param[List[ConversationDataset]]
    """The conversation datasets"""

    @cached_property
    def records(self):
        """List of user-query records, each updated with its history.

        The list is built on first access and then cached on the instance.
        Records appear in dataset order, then conversation order, then node
        order within each conversation.
        """
        collected = []
        for dataset in self.datasets:
            for conversation in dataset.__iter__():
                # Select this conversation's user-query nodes first, and only
                # then compute each one's history (same order as before).
                user_queries = [
                    candidate
                    for candidate in conversation
                    if candidate.entry[EntryType] == EntryType.USER_QUERY
                ]
                collected.extend(
                    query.entry.update(ConversationHistoryItem(query.history()))
                    for query in user_queries
                )
        return collected
class DatasetConversationIterator(SampleIterator, DatasetConversationBase):
    """Sample iterator yielding every conversation record, in order."""

    def __iter__(self) -> Iterator[Record]:
        # This is a generator function: ``self.records`` is only accessed
        # once the first item is requested.
        for record in self.records:
            yield record
class DatasetConversationEntrySampler(BaseSampler, DatasetConversationBase):
    """Randomly samples user-query records from conversation datasets.

    Iteration is delegated to :class:`DatasetConversationEntrySamplerIterator`,
    which draws record indices from the provided random state.
    """

    def __iter__(self) -> RandomSerializableIterator[Record]:
        """Return a serializable iterator seeded with this sampler's random
        state; ``get_iterator`` is used as the factory for the underlying
        iterator."""
        state = self.random
        return RandomSerializableIterator(state, self.get_iterator)

    def get_iterator(self, random: np.random.RandomState):
        """Build the record-drawing iterator bound to ``random``.

        :param random: random state used to pick record indices
        """
        return DatasetConversationEntrySamplerIterator(sampler=self, random=random)
class DatasetConversationEntrySamplerIterator(Iterator[Record]):
    """Endless iterator drawing random records from a sampler.

    Each call to ``next`` picks one of ``sampler.records`` uniformly at random
    (with replacement); it never raises ``StopIteration``.
    """

    def __init__(
        self, sampler: DatasetConversationEntrySampler, random: np.random.RandomState
    ):
        # sampler: object whose ``records`` list is sampled from
        # random: random state used to draw indices (may be None until set)
        self.sampler = sampler
        self.random = random

    def __next__(self) -> Record:
        """Return a randomly chosen record.

        :raises ValueError: if the random state is not set, or if there are no
            records to sample from
        """
        if self.random is None:
            raise ValueError(
                "Random state is not initialized. Call the iterator first."
            )

        records = self.sampler.records
        if not records:
            # Without this check numpy raises an opaque ``low >= high`` error
            # from ``randint(0, 0)``; keep the same exception type but explain.
            raise ValueError(
                "Cannot sample from an empty set of conversation records"
            )

        # ``randint``'s upper bound is exclusive, so every index is valid.
        return records[self.random.randint(0, len(records))]