# Source code for xpmir.conversation.learning

from functools import cached_property
from datamaestro.record import Record
import numpy as np
from datamaestro_text.data.conversation import (
    ConversationDataset,
    ConversationHistoryItem,
    EntryType,
)
from experimaestro import Param

from xpmir.learning.base import BaseSampler
from xpmir.utils.iter import RandomSerializableIterator


class DatasetConversationEntrySampler(BaseSampler):
    """Uses a conversation dataset and topic records entries

    Iterating over the sampler yields an endless stream of records. Each
    record is the entry of a randomly chosen user-query node, taken from a
    randomly chosen conversation, and updated with a
    :class:`ConversationHistoryItem` built from that node's history.
    """

    dataset: Param[ConversationDataset]
    """The conversation dataset"""

    @cached_property
    def conversations(self):
        """The conversations of the dataset, materialized once as a list"""
        return list(self.dataset.__iter__())

    def __post_init__(self):
        super().__post_init__()

    def __iter__(self) -> RandomSerializableIterator[Record]:
        """Returns a random iterator over user-query entries

        The iterator is driven by ``self.random``. While iterating, a
        :class:`ValueError` is raised if the dataset contains no
        conversation, or once every conversation has been found to contain
        no user query.
        """

        def generator(random: np.random.RandomState):
            conversations = self.conversations
            if not conversations:
                raise ValueError(
                    "Cannot sample entries: the conversation dataset is empty"
                )

            # Indices of the conversations found to have no user query; used
            # to stop instead of looping forever when nothing can be sampled
            without_user_query = set()

            while True:
                # Pick a random conversation
                conversation_ix = random.randint(0, len(conversations))
                conversation = conversations[conversation_ix]

                # Pick a random topic record entry
                nodes = [
                    node
                    for node in conversation
                    if node.entry()[EntryType] == EntryType.USER_QUERY
                ]

                if not nodes:
                    # randint(0) would raise an obscure error: skip this
                    # conversation and draw another one instead
                    without_user_query.add(conversation_ix)
                    if len(without_user_query) == len(conversations):
                        raise ValueError(
                            "Cannot sample entries: no conversation "
                            "contains a user query"
                        )
                    continue

                node = nodes[random.randint(len(nodes))]

                # Attach the conversation history to the sampled entry
                yield node.entry().update(ConversationHistoryItem(node.history()))

        return RandomSerializableIterator(self.random, generator)