Source code for xpmir.conversation.models.cosplade

from typing import List, Tuple, Optional
from attr import define
from datamaestro_text.data.conversation.base import EntryType
import torch
import sys
from experimaestro import Param
from datamaestro.record import Record
from datamaestro_text.data.ir import TextItem
from datamaestro_text.data.conversation import (
    AnswerEntry,
    ConversationHistoryItem,
)
from xpmir.conversation.learning.reformulation import (
    ConversationRepresentationEncoder,
)
from xpmir.text.encoders import (
    RepresentationOutput,
    TextsRepresentationOutput,
)
from xpmir.letor.trainers.alignment import AlignmentLoss
from xpmir.neural.splade import SpladeTextEncoderV2
from xpmir.utils.logging import easylog

logger = easylog()


@define
class CoSPLADEOutput(RepresentationOutput):
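    """Output of the CoSPLADE encoder

    The inherited ``value`` holds the final representation, i.e. the sum of
    ``q_queries`` (representation of the query history) and ``q_answers``
    (aggregated representation of past query/answer pairs)."""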
    q_queries: torch.Tensor
    q_answers: torch.Tensor


class AsymetricMSEContextualizedRepresentationLoss(
    AlignmentLoss[CoSPLADEOutput, TextsRepresentationOutput]
):
    """Computes the asymmetric loss for CoSPLADE

    The MSE is only computed on the tokens that appear in the gold (target)
    representation, hence the asymmetry."""

    def __call__(self, input: CoSPLADEOutput, target: TextsRepresentationOutput):
        # Builds up the list of (source, token) indices of tokens appearing
        # in the gold output
        sources = []
        tokens = []
        ids_cpu = target.tokenized.ids.cpu()
        for ix, (ids, length) in enumerate(zip(ids_cpu, target.tokenized.lens)):
            # .tolist() so that set() deduplicates token IDs by value
            for token_id in set(ids[:length].tolist()):
                sources.append(ix)
                tokens.append(token_id)

        # Compute the squared differences on the selected tokens
        difference = torch.nn.functional.mse_loss(
            input.value[sources, tokens],
            target.value[sources, tokens],
            reduction="none",
        )

        # Aggregate the per-token differences by source, then average
        loss = torch.zeros(
            len(target.value), dtype=target.value.dtype, device=target.value.device
        )
        sources_pt = torch.tensor(
            sources, device=target.value.device, dtype=torch.long
        )
        return loss.scatter_add(0, sources_pt, difference).mean()
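
# Example (editor's sketch, not part of the library): exercising the loss on
# toy tensors. The ``tokenized`` structure is faked with a SimpleNamespace
# exposing ``ids`` and ``lens``; in xpmir it would be a TokenizedTexts object,
# and the loss would be built through experimaestro (possibly requiring
# ``.instance()`` before being called).
#
#   from types import SimpleNamespace
#
#   batch, vocab = 2, 8
#   input = CoSPLADEOutput(
#       value=torch.rand(batch, vocab),
#       q_queries=torch.rand(batch, vocab),
#       q_answers=torch.rand(batch, vocab),
#   )
#   target = TextsRepresentationOutput(
#       value=torch.rand(batch, vocab),
#       tokenized=SimpleNamespace(
#           ids=torch.tensor([[1, 2, 2, 0], [3, 4, 5, 6]]),
#           lens=torch.tensor([3, 4]),
#       ),
#   )
#   # Only tokens {1, 2} (row 0) and {3, 4, 5, 6} (row 1) contribute
#   loss = AsymetricMSEContextualizedRepresentationLoss()(input, target)
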
class CoSPLADE(ConversationRepresentationEncoder):
    """CoSPLADE model"""

    history_size: Param[int] = 0
    """Size of history to take into account (0 for infinite)"""

    queries_encoder: Param[SpladeTextEncoderV2[List[List[str]]]]
    """Encoder for the query history (the first one being the current one)"""

    history_encoder: Param[SpladeTextEncoderV2[Tuple[str, str]]]
    """Encoder for (query, answer) pairs"""

    def __initialize__(self, options):
        super().__initialize__(options)
        self.queries_encoder.initialize(options)
        self.history_encoder.initialize(options)

    @property
    def dimension(self):
        return self.queries_encoder.dimension

    def forward(self, records: List[Record]):
        queries: List[List[str]] = []
        query_answer_pairs: List[Tuple[str, str]] = []
        pair_origins: List[int] = []
        history_size = self.history_size or sys.maxsize

        # Process each conversation record
        for ix, c_record in enumerate(records):
            # Adds q_n, q_1, ..., q_{n-1}
            queries.append(
                [c_record[TextItem].text]
                + [
                    entry[TextItem].text
                    for entry in c_record[ConversationHistoryItem].history
                    if entry[EntryType] == EntryType.USER_QUERY
                ]
            )

            # Builds the list of (query, answer) couples, keeping at most
            # history_size of them for this record
            answer: Optional[AnswerEntry] = None
            pair_count = 0
            for item in c_record[ConversationHistoryItem].history:
                entry_type = item[EntryType]
                if entry_type == EntryType.USER_QUERY and answer is not None:
                    query_answer_pairs.append((item[TextItem].text, answer.answer))
                    pair_origins.append(ix)
                    pair_count += 1
                    if pair_count >= history_size:
                        break
                elif entry_type == EntryType.SYSTEM_ANSWER:
                    if (answer := item.get(AnswerEntry)) is None:
                        logger.warning("Answer record has no answer entry")
                else:
                    # Ignore anything which is not a query/answer pair
                    answer = None

        # (1) encodes the queries
        q_queries = self.queries_encoder(queries).value

        # (2) encodes the past (query, answer) pairs (if any) and sums the
        # representations belonging to the same conversation
        q_answers = torch.zeros_like(q_queries)
        if query_answer_pairs:
            x_pairs = self.history_encoder(query_answer_pairs).value
            q_ix = torch.tensor(
                pair_origins, dtype=torch.long, device=q_queries.device
            )
            q_ix = q_ix.unsqueeze(-1).expand(x_pairs.shape)
            q_answers.scatter_add_(0, q_ix, x_pairs)

        return CoSPLADEOutput(q_queries + q_answers, q_queries, q_answers)
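
# Example (editor's sketch, not part of the library): configuring CoSPLADE
# within an experimaestro experiment. The encoder values below are
# placeholders; the actual SpladeTextEncoderV2 parameters (tokenizer, adapter,
# aggregation) depend on the experiment.
#
#   queries_encoder = ...  # SpladeTextEncoderV2 consuming List[List[str]]
#   history_encoder = ...  # SpladeTextEncoderV2 consuming (query, answer) tuples
#   cosplade = CoSPLADE(
#       history_size=3,  # use the three most recent (query, answer) pairs
#       queries_encoder=queries_encoder,
#       history_encoder=history_encoder,
#   )
#   # forward() consumes conversation records; the returned CoSPLADEOutput
#   # has value = q_queries + q_answers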