Source Code Walkthrough of the Neural Pipeline Family for Task-Oriented Dialogue Systems


This post was revised on <2022-03-26 Sat>.

This post is based on reading two codebases on GitHub, SOLOIST and AuGPT (here). Both are neural-pipeline approaches, so I analyze them together.

Before diving into the source code, it is probably worth first understanding what a task-oriented dialogue system (TOD) is and what the neural pipeline approach is; if I forget these two concepts, I go back and read this note.

This article, then, focuses on the code implementation of the models above!

1. Overall Structure

Broadly speaking, any deep learning codebase consists of the following main parts:

  • Data preprocessing: parse and load the experimental data, apply a series of transformations to it, and finally organize it into formatted inputs. For PyTorch, a "formatted input" means data that a DataLoader can load. As for what is actually passed to the model, apart from a few control parameters that may be bool or scalar values, almost every piece of data should be a tensor. Padding variable-length sequences into a single batch-level tensor is a step that NLP code always has to consider.
  • The deep learning model code. Most current deep learning approaches are not model-free, so model design is almost unavoidable. That said, with the rise of pretrained models there is not much model code left to write. This post mainly works with GPT-2 and other transformer-family models, so there is no architectural innovation and the Hugging Face implementation can be reused directly. If I become interested in designing transformer models, this note records how to read and modify the transformers source code.
  • The training code. In the traditional supervised learning framework, a supervised learner has four basic components: data, model, loss function, and optimization. Because data and model are so important, they are discussed separately, while the loss and the optimization can both live in the training code. In general, the training code is the "glue": it records how the forward and backward passes are run, stores experiment results and model parameters, handles visualization, and so on.
  • The evaluation code. It is largely similar to the training code, but for NLG tasks the two still differ quite a bit, so it is listed as a separate part. When modifying existing code I once wasted a precious week simply because I forgot to update the test-time code. At test time, the main work is to build a prefix for autoregressive generation, apply rule-based post-processing to the generated text, and finally evaluate the post-processed text against the ground truth to produce the final results (see the sketch right after this list).
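
To make the test-time part concrete, here is a minimal sketch of prefix-based autoregressive generation with the transformers generate API (the prefix string and decoding settings are made up for illustration, not the exact ones used by SOLOIST/AuGPT):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

# hypothetical prefix: dialogue history followed by the "Belief state :" prompt
prefix = "User : i need a cheap restaurant Belief state :"
input_ids = tokenizer.encode(prefix, return_tensors='pt')

with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_length=input_ids.shape[1] + 40,   # generate at most 40 new tokens
        do_sample=False,                      # greedy decoding
        pad_token_id=tokenizer.eos_token_id,
    )

# only the continuation after the prefix is the model's prediction
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:]))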

2. The Data Part

For deep learning, and especially deep learning in the transformer era, innovation now comes mostly from the training tasks and the way the data is organized rather than from the model architecture; a classmate of mine was even questioned by reviewers for designing a custom architecture for a particular NLP task. The core idea of the neural pipeline approach is to fold all the basic components of a TOD into one single NLG model, so arranging the data properly is the most important part of all.

The rest of this section first describes what GPT-2's inputs and outputs look like, then what the inputs and outputs of the basic TOD modules are, and finally, building on both, how to feed those modules through GPT-2. One might ask: why must it be GPT-2? It does not have to be; T5, BART, and similar models can solve the same problem. GPT-2 is used here only as an example, and other cases can be adapted accordingly.

2.1. The Form of GPT-2's Input

Strictly speaking, GPT-2's input consists of three main parts:

  1. Token encoding. This can be understood as the semantic content of the input sentence. Each example actually passed into the model is a list of indices: the tokenizer splits the input sentence into tokens (words or subwords, depending on the tokenization method), each token corresponds uniquely to one row of the vocabulary, and that row number is the index. For a model like GPT-2, the index is best thought of not as a number but as a one-hot vector that is 1 at the index and 0 everywhere else. GPT-2 maps this sparse vector through an already-trained embedding layer (essentially a fully connected layer) into a dense real-valued vector (e.g. from ~50k vocabulary entries down to 768 dimensions), which is the actual token embedding. The transformers library also lets you pass a list of embeddings directly to represent a sentence, in which case this embedding layer is bypassed.
  2. Position encoding. Since self-attention is inherently order-agnostic, positional information has to be added to establish order. GPT-2 uses absolute positions: roughly, any integer from 0 to max_seq_len - 1 serves as input and is then turned into an embedding (classically via sinusoidal functions).
  3. Type encoding. Used to distinguish the type of each token. For example, if A and B are chatting, the two speakers can be given two different types. (A small sketch of all three inputs follows this list.)
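
A minimal sketch of how these three inputs could be built for a two-speaker exchange (the 0/1 speaker-type ids here are illustrative, not the exact scheme used in the repos):

import torch
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

utt_a = "hello , i am looking for a hotel ."   # speaker A
utt_b = "sure , which area do you prefer ?"    # speaker B
ids_a = tokenizer.encode(utt_a)
ids_b = tokenizer.encode(utt_b)

input_ids = torch.tensor([ids_a + ids_b])                             # token encoding
token_type_ids = torch.tensor([[0] * len(ids_a) + [1] * len(ids_b)])  # type encoding
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)          # position encoding: 0 .. L-1

print(input_ids.shape, token_type_ids.shape, position_ids.shape)      # all (1, L)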

We can get a feel for this from the forward function of GPT2LMHeadModel:

def forward(
    self,
    input_ids=None, # token encoding
    past=None,
    attention_mask=None,
    token_type_ids=None, # type encoding
    position_ids=None, # position encoding
    head_mask=None,
    inputs_embeds=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
):

Note that there are some other arguments here as well; they are covered in detail in this note.

We can also look at how the token-encoding-to-embedding mapping is implemented in GPT-2:

class GPT2Model(GPT2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)   # token encoding-->token embedding layer
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)# position encoding-->position embedding layer
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

Note that GPT-2 in transformers 3.0.2 does not use sinusoidal functions for the position embedding either; it also uses a PyTorch embedding layer.

GPT-2's text output is even simpler. After the last transformer block, the hidden state passes through a linear classifier layer W, which maps it from the embedding dimension to the vocabulary size. Writing this as Y = W*d, each row of W can be viewed as the representation of one vocabulary entry, i.e. a token embedding, and the resulting Y is a distribution over the vocabulary. The generated token is then sampled from this distribution; greedy decoding, for instance, takes the argmax, i.e. the index with the largest value, and looks up the corresponding vocabulary entry. Intuitively, Y is the vector of similarities between d and every token's representation, and we pick the most similar token. This is essentially the classical pattern-matching idea from computer vision.
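
A minimal sketch of that last step, greedy decoding over the LM head's logits (variable names are mine):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

input_ids = tokenizer.encode("i would like to book a", return_tensors='pt')
with torch.no_grad():
    logits = model(input_ids)[0]            # Y = W * d for every position, shape (1, seq_len, vocab_size)

next_token_logits = logits[0, -1]           # scores over the vocabulary for the next token
next_token_id = torch.argmax(next_token_logits).item()   # greedy: take the argmax
print(tokenizer.decode([next_token_id]))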

2.2. Text Inputs and Outputs in the Neural Pipeline

Now let us look at the input and output of each module in a neural-pipeline TOD.

| module | input | output |
|--------|-------|--------|
| NLU + DST | history | belief state |
| database retrieval | belief state | database matching result |
| decision (i.e. dialogue policy) | belief state and database matching result | dialogue action |
| NLG case 1 | dialogue action | delexicalized response |
| NLG case 2 | belief state and database matching result | delexicalized response |
| lexicalizing | delexicalized response, belief state, database | response |

The table above roughly sketches the inputs and outputs of each module. The three main parts, NLU+DST, decision, and NLG, are what the neural pipeline performs, while database retrieval and slot filling can both be handled with simple rules. I will not explain again what each input and output means, but their data structures can be sketched briefly:

| data structure | type |
|----------------|------|
| history | Vec<String> |
| belief state | HashMap<Domain, HashMap<Slot, Value>> |
| database matching result | HashMap<Domain, Int32> |
| dialogue action | Vec<(Intent, Domain, Slot, Value)> |
| delexicalized response | String |
| response | String |
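
In Python terms (the table uses Rust-like type notation), these roughly correspond to the following type aliases (a sketch; the concrete types of Domain, Slot, and Value are discussed next):

from typing import Dict, List, Tuple

Domain, Slot, Value, Intent = str, str, str, str

History = List[str]                                        # Vec<String>
BeliefState = Dict[Domain, Dict[Slot, Value]]              # HashMap<Domain, HashMap<Slot, Value>>
DatabaseMatchingResult = Dict[Domain, int]                 # HashMap<Domain, Int32>
DialogueAction = List[Tuple[Intent, Domain, Slot, Value]]  # Vec<(Intent, Domain, Slot, Value)>
DelexicalizedResponse = str
Response = str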

What types are Domain, Slot, and Value? That depends on the design of the dialogue system; MultiWOZ is fairly toy-like, so they are all just strings.

So, as you can see, turning all of the structures above into flat strings is the first step towards using a pretrained model; in frontend/backend terms this process would be called serialization and deserialization. Let us go through it.

2.3. Combining the Neural Pipeline with GPT-2

Our final goal is to run the whole pipeline with GPT-2: use the history as input to output the belief state, then query the database to get the database matching result, feed all three in to produce the dialogue actions, and finally generate the delexicalized response.

The overall architecture can be depicted by the figure below:

screenshot_20220326_173041.png

The serialization process differs somewhat across implementations, but it mainly involves two big parts:

  1. how to express the structured information as a sequence;
  2. how to separate the individual sub-modules from each other.

Let us look at how the source code implements this.

How the belief state is serialized:

def format_belief(belief: OrderedDict) -> str:
    assert isinstance(belief, OrderedDict)
    str_bs = []
    for domain, domain_bs in belief.items():
        domain_bs = ', '.join([f'{slot} = {val}' for slot, val in sorted(domain_bs.items(), key=lambda x: x[0])])
        str_bs.extend([domain, '{' + domain_bs + '}'])
    return ' '.join(str_bs)
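
For intuition, this is roughly what it produces (slot values made up):

from collections import OrderedDict

belief = OrderedDict([
    ('restaurant', {'food': 'chinese', 'price range': 'cheap'}),
    ('hotel', {'area': 'centre'}),
])
print(format_belief(belief))
# -> restaurant {food = chinese, price range = cheap} hotel {area = centre}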

How the dialogue acts are serialized (this one I wrote myself):

def format_dialogue_act(acts) -> str:
    # acts: a list of (intent, domain, slot, value) tuples;
    # the value is deliberately dropped from the serialized form
    str_acts = ""
    for act in acts:
        intent, domain, slot, value = act
        # str_acts += f"{intent}, {domain}, {slot}, {value}; "
        str_acts += f"{intent}, {domain}, {slot}; "
    return str_acts[:-2]  # strip the trailing "; "

How the database matching result is serialized:

def default_translate_match(n):
    if n == 0:
        return 'no match'
    if n == 1:
        return '1 match'
    return f'{n} matches'

And how the history is stitched together!

    class InsertLabelsTransformation:
        user_label: str = 'User :'
        sys_label: str = 'System :'
        database_label: str = 'DB :'
        belief_label: str = 'Belief state :'
        dialogue_act_label: str = 'Action :'
        # template_label: str = "Template :"

        def __call__(self, sample: DialogDatasetItem) -> DialogDatasetItem:
            if isinstance(sample, tuple):
                sample = DialogDatasetItem(*sample)
            # Transform context
#######################################################################
            context = sample.context
            context = list(context)
            labels = self.user_label, self.sys_label
            for i in range(len(context) - 1, -1, -1):
                label, other = labels
                context[i] = label + ' ' + context[i]
                labels = other, label
            context = ' '.join(context)
#######################################################################

            # Database
            database = sample.database
            if database is not None:
                database_str = []
                for database_domain, database_count in database.items():
                    database_str.append(database_domain + ' ' +
                                        default_translate_match(database_count))
                database = self.database_label + ' ' + ' , '.join(database_str)

            # Belief state
            belief = sample.belief
            if belief is not None:
                belief = self.belief_label + ' ' + belief

            # dialogue act
            if sample.dialogue_act is not None:
                dialogue_act=sample.dialogue_act
                dialogue_act=self.dialogue_act_label+" "+dialogue_act
            else:
                dialogue_act=None

            # # template
            # template=sample.template
            # if template is not None:
            #     template=self.template_label+" "+ str()
            # return dataclasses.replace(sample, belief=belief,
                                       # database=database, context=context,template=template)

            return dataclasses.replace(sample, belief=belief,
                                       database=database,
                                       dialogue_act=dialogue_act,
                                       context=context)

The part between the two rows of '#' characters that I added as comments above is where the history is serialized.
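
To illustrate that loop (utterances made up): the labels are assigned walking backwards through the context, so the most recent utterance is always prefixed with 'User :':

context = ["i need a cheap restaurant",
           "what kind of food would you like ?",
           "chinese please"]
# after the loop and the join, `context` becomes the single string:
# "User : i need a cheap restaurant System : what kind of food would you like ? User : chinese please"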

With that done, the next question is how to join the individual components together. This is mainly a matter of adding special tokens: the Transformation class above already inserts some prompt text, but that alone is not enough, and some special tokens still have to be added before everything is finally concatenated.

Defining the special tokens is simple: just extend the tokenizer directly:

EOB_TK = '<|eob|>'
EOKB_TK = '<|eokb|>'
EOT_TK = '<|endoftext|>'

SPECIAL_TOKENS = [EOB_TK, EOKB_TK,
                  "<|pd|>","<|pb|>","<|pc|>",
                  "<|pa|>","<|eoda|>","<|eo_turn|>"]

def add_custom_tokens(tokenizer, model):
    tokenizer.add_special_tokens({'additional_special_tokens': SPECIAL_TOKENS})
    model.resize_token_embeddings(len(tokenizer))
    return tokenizer, model

The tokens starting with p and the eo* ones are not part of the neural pipeline itself; they are part of my own work...
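
For reference, a minimal sketch of wiring this up (assuming the standard transformers API and the public 'gpt2' checkpoint):

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# register the special tokens and grow the embedding matrix to match
tokenizer, model = add_custom_tokens(tokenizer, model)
print(tokenizer.convert_tokens_to_ids(['<|eob|>', '<|eokb|>']))  # now valid ids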

Then the operation that actually concatenates everything is defined:

class TokenizerTransformation:
    def __init__(self, tokenizer: transformers.GPT2Tokenizer, max_context_length: int = 500, is_bi=False):
        self.bob, self.eob, self.eokb = tokenizer.convert_tokens_to_ids(
            ['=>', '<|eob|>', '<|eokb|>'])

        self.eos = tokenizer.eos_token_id
        self.tokenizer = tokenizer
        self.max_context_length = max_context_length

    def get_tokens(self, data):
        history, belief, database = data.context, data.belief, data.database
        response, positive = data.response, data.positive

        # Add history
        history = self.tokenizer.encode(history)
        inp = history
        labels = [-100 for _ in history]
        context_end = len(labels)

        # Add belief states
        if belief is not None:
            belief = [self.bob] + self.tokenizer.encode(belief) + [self.eob]
            inp += belief
            labels += belief

        belief_end = len(labels)

        # Add database
        if database is not None:
            database = self.tokenizer.encode(database) + [self.eokb]
            inp += database
            labels += [-100 for _ in database]

        database_end = len(labels)

        # Add response
        if response is not None:
            response = self.tokenizer.encode(response) + [self.eos]
            inp += response
            labels += response

        if positive is not None and not positive:
            labels = [-100 for _ in labels]

        if self.max_context_length > 0:

            old_length = len(inp)
            inp = inp[-self.max_context_length:]
            labels = labels[-self.max_context_length:]

            belief_end = belief_end - (old_length - len(inp))
            context_end = context_end - (old_length - len(inp))
            database_end = database_end - (old_length - len(inp))

        return inp, labels, positive, belief_end, context_end, database_end

    # -100 is mask token for LM
    # transforms into dict {"input_ids", "labels", "binary_labels", "binary_token_ids" }
    # binary_labels are used for task 3
    def __call__(self, data):
        inp, labels, positive, belief_end, context_end, database_end = self.get_tokens(data)
        belief_labels = [x if i < belief_end else -100 for i, x in enumerate(labels)]
        response_labels = [x if i >= belief_end else -100 for i, x in enumerate(labels)]
        return dict(input_ids=inp, belief_labels=belief_labels, response_labels=response_labels,
                    consistency_labels=positive, consistency_token_ids=len(labels) - 1)
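
Putting get_tokens together, a single flattened training example looks roughly like this (made-up values; '=>' is the token mapped to self.bob, and the line breaks are only for readability):

User : i need a cheap restaurant System : what kind of food ? User : chinese
=> Belief state : restaurant {food = chinese, price range = cheap} <|eob|>
DB : restaurant 3 matches <|eokb|>
there are [value_count] cheap [value_food] restaurants , [restaurant_name] is one . <|endoftext|>

In the labels, the history segment and the database segment are filled with -100, while the belief segment (including '=>' and '<|eob|>') and the response segment (including '<|endoftext|>') keep their token ids.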

Notice the frequent use of -100. Strictly speaking, -100 is not a special token but a special label value that is ignored when computing the loss (it is the default ignore_index of the cross-entropy loss). So in the code above the history, which serves only as a prefix, contributes nothing to the loss; likewise, since the database matching result is obtained by retrieval, its labels are also set to -100.
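
A tiny sketch of why -100 works as a mask (it is the default ignore_index of PyTorch's nn.CrossEntropyLoss, which is also what the Hugging Face LM heads rely on):

import torch
import torch.nn as nn

loss_fct = nn.CrossEntropyLoss()           # ignore_index defaults to -100
logits = torch.randn(4, 10)                # 4 positions, vocabulary of size 10
labels = torch.tensor([3, -100, -100, 7])  # positions 1 and 2 are masked out

# only positions 0 and 3 contribute to the averaged loss
print(loss_fct(logits, labels))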

As you can see, all of the concatenation is defined in get_tokens, and the transformation ultimately returns a dict with five entries:

  1. input_ids: the token encoding described earlier
  2. belief_labels: the label values for the belief state, used as the supervision signal
  3. response_labels: likewise, the labels for the response
  4. consistency_labels: a binary (0/1) consistency label, described in detail later in the section on the training tasks
  5. consistency_token_ids: the position of the eos token; in other words, the embedding at the eos token will be fed through a classifier for the binary consistency decision.

These five fields are then wrapped by the collator below, which does the usual batching work required in NLP.

@dataclass
class DataCollatorWithPadding:
    tokenizer: Union[transformers.PreTrainedTokenizer,
                     transformers.PreTrainedTokenizerFast]
    max_length: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor,float]]]) -> Dict[str, torch.Tensor]:
        batch = {
            # "attention_mask":torch.tensor([x["attention_mask"].numpy() for x in features]),
            'consistency_labels': torch.tensor([x['consistency_labels'] for x in features], dtype=torch.float32),
            'consistency_token_ids': torch.tensor([x['consistency_token_ids'] for x in features], dtype=torch.int64),
            # 'input_ids': pad_sequence([torch.tensor(x['input_ids'], dtype=torch.int64) for x in features],
                                      # batch_first=True, padding_value=self.tokenizer.pad_token_id),
            'belief_labels': pad_sequence([torch.tensor(x['belief_labels'], dtype=torch.int64) for x in features],
                                          batch_first=True, padding_value=-100),
        }

        if "position_ids" in features[0]:
            if features[0]["position_ids"] is not None:
                batch["position_ids"]=pad_sequence([x["position_ids"] for x in features],batch_first=True,padding_value=0)
                # print(batch["position_ids"].shape)
                # print(batch["input_ids"].shape)
                # print("=============")

        if "states_ids" in features[0]:
            if features[0]["states_ids"] is not None:
                batch["states_ids"]=pad_sequence([torch.tensor(x["states_ids"],dtype=torch.int64) for x in features],batch_first=True,padding_value=self.tokenizer.pad_token_id)

        if "his_ids" in features[0]:
            if features[0]["his_ids"] is not None:
                batch["his_ids"]=pad_sequence([torch.tensor(x["his_ids"],dtype=torch.int64) for x in features],batch_first=True,padding_value=self.tokenizer.pad_token_id)

        if "response_labels" in features[0]:
            batch["response_labels"]= pad_sequence([torch.tensor(x['response_labels'], dtype=torch.int64) for x in features],batch_first=True, padding_value=-100)
        # else:
        #     batch["response_labels"]=None 

        return batch

What happens here is padding, based on the pad_sequence function. Note that the resulting tensor has shape batch_size * max_seq_len, where max_seq_len is the length of the longest example in this batch, not the model's maximum sequence length.
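
A quick check of that behaviour (made-up lengths):

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
batch = pad_sequence(seqs, batch_first=True, padding_value=-100)
print(batch.shape)   # torch.Size([3, 3]): padded to the longest example in the batch, not to the model max
print(batch)
# tensor([[   1,    2,    3],
#         [   4,    5, -100],
#         [   6, -100, -100]])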

2.4. Hold On a Little Longer: Why Is a TOD So Complex?

If you think the above is already the core of the TOD data code, you are rather underestimating the complexity of task-oriented dialogue systems: we have not even formally started to cover the full process yet.

2.4.1. How Is the TOD Dataset Preprocessed?

Skipped for now. I have not started reading this part of the code yet and will fill it in later.

#!/usr/bin/env python
import sqlite3
import json
import os
import tempfile
import re
import shutil
import requests
import random
import logging
import subprocess
import sys
from collections import defaultdict
from copy import deepcopy
import zipfile
import inspect
from collections import OrderedDict, Counter
from tqdm import tqdm
import numpy as np

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import setup_logging  # noqa: E402


np.set_printoptions(precision=3)
np.random.seed(2)
setup_logging()
logger = logging.getLogger()

# GLOBAL VARIABLES
DATASETS_PATH = os.path.join(os.path.expanduser(os.environ.get('DATASETS_PATH', '~/datasets')), 'soloist')
DICT_SIZE = 400
MAX_LENGTH = 50
DEFAULT_IGNORE_VALUES = ['not mentioned', 'dont care', 'don\'t care', 'dontcare', 'do n\'t care', 'none']
MW_DOMAINS = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
digitpat = re.compile(r'\d+')
timepat = re.compile(r"\d{1,2}[:]\d{1,2}")
pricepat2 = re.compile(r"\d{1,3}[.]\d{1,2}")
timepat = re.compile(r"\d{1,2}[:]\d{1,2}")
label_regex = re.compile(r'\[([\w\d\s]+)\]')
pricepat = re.compile(r"\d{1,3}[.]\d{1,2}")
fin = open(os.path.join(os.path.dirname(__file__), 'mapping.pair'), 'r')
replacements = []
for line in fin.readlines():
    tok_from, tok_to = line.replace('\n', '').split('\t')
    replacements.append((' ' + tok_from + ' ', ' ' + tok_to + ' '))


class Lexicalizer:
    def __init__(self, zipf):
        self.path = zipf.filename

    placeholder_re = re.compile(r'\[(\s*[\w_\s]+)\s*\]')
    number_re = re.compile(r'.*(\d+|one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve)\s$')
    time_re = re.compile(r'((?:\d{1,2}[:]\d{2,3})|(?:\d{1,2} (?:am|pm)))', re.IGNORECASE)

    @staticmethod
    def ends_with_number(s):
        return bool(Lexicalizer.number_re.match(s))

    @staticmethod
    def extend_database_results(database_results, belief):
        # Augment database results from the belief state
        database_results = OrderedDict(database_results)
        if belief is not None:
            for i, (domain, (num_results, results)) in enumerate(database_results.items()):
                if domain not in belief:
                    continue
                if num_results == 0:
                    database_results[domain] = (1, [belief[domain]])
                else:
                    new_results = []
                    for r in results:
                        r = dict(**r)
                        for k, val in belief[domain].items():
                            if k not in r:
                                r[k] = val
                        new_results.append(r)
                    database_results[domain] = (num_results, new_results)
        return database_results

    def __call__(self, text, database_results, belief=None, context=None):
        database_results = Lexicalizer.extend_database_results(database_results, belief)
        result_index = 0
        last_assignment = defaultdict(set)

        def trans(label, span, force=False, loop=100):
            nonlocal result_index
            nonlocal last_assignment
            result_str = None
            current_domain = None
            if '_' in label:
                current_domain = label[:label.index('_')]
                label = label[label.index('_') + 1:]
            if label == 'postcode':
                label = 'post code'

            # No references in the MW 2.0 database
            if label == 'reference':
                return 'YF86GE4J'

            for domain, (count, results) in database_results.items():
                if count == 0:
                    continue
                if current_domain is not None and domain != current_domain and not force:
                    continue
                result = results[result_index % len(results)]
                if label in result:
                    result_str = str(result[label])
                    if result_str == '?':
                        result_str = 'unknown'
                    if label == 'price range' and result_str == 'moderate' and \
                            not text[span[1]:].startswith(' price range') and \
                            not text[span[1]:].startswith(' in price'):
                        result_str = 'moderately priced'
                    if label == 'type':
                        if text[:span[0]].endswith('no ') or text[:span[0]].endswith('any ') or \
                                text[:span[0]].endswith('some ') or Lexicalizer.ends_with_number(text[:span[0]]):
                            if not result_str.endswith('s'):
                                result_str += 's'
                if label == 'time' and ('[leave at]' in text or '[arrive by]' in text) and \
                    belief is not None and 'train' in belief and \
                        any([k in belief['train'] for k in ('leave at', 'arrive by')]):
                    # this is a specific case in which additional [time] slot needs to be lexicalised
                    # directly from the belief state
                    # "The earliest train after [time] leaves at ... and arrives by ..."
                    if 'leave at' in belief['train']:
                        result_str = belief['train']['leave at']
                    else:
                        result_str = belief['train']['arrive by']
                elif force:
                    if label == 'time':
                        if 'leave at' in result or 'arrive by' in result:
                            if 'arrive' in text and 'arrive by' in result:
                                result_str = result['arrive by'].lstrip('0')
                            elif 'leave at' in result:
                                result_str = result['leave at'].lstrip('0')
                        elif context is not None and len(context) > 0:
                            last_utt = context[-1]
                            mtch = Lexicalizer.time_re.search(last_utt)
                            if mtch is not None:
                                result_str = mtch.group(1).lstrip('0')
                if result_str is not None:
                    break
            if force and result_str is None:
                if label == 'reference':
                    result_str = 'YF86GE4J'
                elif label == 'phone':
                    result_str = '01223358966'
                elif label == 'postcode':
                    result_str = 'CB11JG'
                elif label == 'agent':
                    result_str = 'Cambridge Towninfo Centre'
                elif label == 'stars':
                    result_str = '4'

            if result_str is not None and result_str.lower() in last_assignment[label] and loop > 0:
                result_index += 1
                return trans(label, force=force, loop=loop - 1, span=span)

            if result_str is not None:
                last_assignment[label].add(result_str.lower())
            return result_str or f'[{label}]'

        text = Lexicalizer.placeholder_re.sub(lambda m: trans(m.group(1), span=m.span()), text)
        text = Lexicalizer.placeholder_re.sub(lambda m: trans(m.group(1), force=True, span=m.span()), text)
        return text

    def save(self, path):
        shutil.copy(self.path, os.path.join(path, os.path.split(self.path)[-1]))


def clear_whitespaces(text):
    text = re.sub(r'[\s\n\r]+', ' ', text)
    text = ' ' + text + ' '
    text = re.sub(r'\s([,\.:\?\!\']+)', lambda m: m.group(1), text)
    return text.strip()


def insertSpace(token, text):
    sidx = 0
    while True:
        sidx = text.find(token, sidx)
        if sidx == -1:
            break
        if sidx + 1 < len(text) and re.match('[0-9]', text[sidx - 1]) and \
                re.match('[0-9]', text[sidx + 1]):
            sidx += 1
            continue
        if text[sidx - 1] != ' ':
            text = text[:sidx] + ' ' + text[sidx:]
            sidx += 1
        if sidx + len(token) < len(text) and text[sidx + len(token)] != ' ':
            text = text[:sidx + 1] + ' ' + text[sidx + 1:]
        sidx += 1
    return text


def normalize(text):
    # lower case every word
    text = text.lower()

    # replace white spaces in front and end
    text = re.sub(r'^\s*|\s*$', '', text)

    # hotel domain pfb30
    text = re.sub(r"b&b", "bed and breakfast", text)
    text = re.sub(r"b and b", "bed and breakfast", text)

    # normalize phone number
    ms = re.findall(r'\(?(\d{3})\)?[-.\s]?(\d{3})[-.\s]?(\d{4,5})', text)
    if ms:
        sidx = 0
        for m in ms:
            sidx = text.find(m[0], sidx)
            if text[sidx - 1] == '(':
                sidx -= 1
            eidx = text.find(m[-1], sidx) + len(m[-1])
            text = text.replace(text[sidx:eidx], ''.join(m))

    # normalize postcode
    ms = re.findall(
        r'([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})',
        text)
    if ms:
        sidx = 0
        for m in ms:
            sidx = text.find(m, sidx)
            eidx = sidx + len(m)
            text = text[:sidx] + re.sub(r'[,\. ]', '', m) + text[eidx:]

    # weird unicode bug
    text = re.sub(u"(\u2018|\u2019)", "'", text)

    # replace time and and price
    text = re.sub(timepat, ' [value_time] ', text)
    text = re.sub(pricepat, ' [value_price] ', text)
    # text = re.sub(pricepat2, '[value_price]', text)

    # replace st.
    text = text.replace(';', ',')
    text = re.sub(r'$\/', '', text)
    text = text.replace('/', ' and ')

    # replace other special characters
    text = text.replace('-', ' ')
    text = re.sub(r'[\":\<>@\(\)]', '', text)

    # insert white space before and after tokens:
    for token in ['?', '.', ',', '!']:
        text = insertSpace(token, text)

    # insert white space for 's
    text = insertSpace('\'s', text)

    # replace it's, does't, you'd ... etc
    text = re.sub(r'^\'', '', text)
    text = re.sub(r'\'$', '', text)
    text = re.sub(r'\'\s', ' ', text)
    text = re.sub(r'\s\'', ' ', text)
    for fromx, tox in replacements:
        text = ' ' + text + ' '
        text = text.replace(fromx, tox)[1:-1]

    # remove multiple spaces
    text = re.sub(' +', ' ', text)

    # concatenate numbers
    tokens = text.split()
    i = 1
    while i < len(tokens):
        if re.match(r'^\d+$', tokens[i]) and \
                re.match(r'\d+$', tokens[i - 1]):
            tokens[i - 1] += tokens[i]
            del tokens[i]
        else:
            i += 1
    text = ' '.join(tokens)
    return text


def fix_active_domain_and_delex(active_domain, text, delex):
    domains = [x.group(1).split('_')[0] for x in label_regex.finditer(delex)]
    domains = [x for x in MW_DOMAINS if x in domains]
    domain_counter = Counter(domains)
    if domain_counter:
        active_domain = domain_counter.most_common(1)[0][0]

    lresponse = text.lower()
    if 'hotel' in lresponse:
        active_domain = 'hotel'
    if 'train' in lresponse or 'arrive' in lresponse or 'leave' in lresponse:
        active_domain = 'train'
    if 'attraction' in lresponse:
        active_domain = 'attraction'
    if 'police' in lresponse:
        active_domain = 'police'
    if 'restaurant' in lresponse or 'food' in lresponse:
        active_domain = 'restaurant'
    if 'hospital' in lresponse:
        active_domain = 'hospital'
    if 'taxi' in lresponse or 'car' in lresponse:
        active_domain = 'taxi'
    taxi_brands = ["toyota", "skoda", "bmw", 'honda', 'ford', 'audi', 'lexus', 'volvo', 'volkswagen', 'tesla']
    if any(t in lresponse for t in taxi_brands):
        active_domain = 'taxi'

    for match in label_regex.finditer(delex):
        domain, slot = match.group(1).split('_')
        if slot == 'reference':
            active_domain = domain

    if active_domain is not None:
        delex = label_regex.sub(lambda x: f'[{active_domain}_{x.group(1).split("_")[1]}]', delex)
    return active_domain, delex


def prepareSlotValuesIndependent(dbzipf, path):
    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital', 'police']
    dic = []
    dic_area = []
    dic_food = []
    dic_price = []

    # read databases
    for domain in domains:
        try:
            fin = dbzipf.open(os.path.join('db/' + domain + '_db.json'), 'r')
            db_json = json.load(fin)
            fin.close()

            for ent in db_json:
                for key, val in ent.items():
                    if val == '?' or val == 'free':
                        pass
                    elif key == 'address':
                        dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
                        if "road" in val:
                            val = val.replace("road", "rd")
                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
                        elif "rd" in val:
                            val = val.replace("rd", "road")
                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
                        elif "st" in val:
                            val = val.replace("st", "street")
                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
                        elif "street" in val:
                            val = val.replace("street", "st")
                            dic.append((normalize(val), '[' + domain + '_' + 'address' + ']'))
                    elif key == 'name':
                        dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
                        if "b & b" in val:
                            val = val.replace("b & b", "bed and breakfast")
                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
                        elif "bed and breakfast" in val:
                            val = val.replace("bed and breakfast", "b & b")
                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
                        elif "hotel" in val and 'gonville' not in val:
                            val = val.replace("hotel", "")
                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
                        elif "restaurant" in val:
                            val = val.replace("restaurant", "")
                            dic.append((normalize(val), '[' + domain + '_' + 'name' + ']'))
                    elif key == 'postcode':
                        dic.append((normalize(val), '[' + domain + '_' + 'postcode' + ']'))
                    elif key == 'phone':
                        dic.append((val, '[' + domain + '_' + 'phone' + ']'))
                    elif key == 'trainID':
                        dic.append((normalize(val), '[' + domain + '_' + 'id' + ']'))
                    elif key == 'department':
                        dic.append((normalize(val), '[' + domain + '_' + 'department' + ']'))

                    # NORMAL DELEX
                    elif key == 'area':
                        dic_area.append((normalize(val), '[' + 'value' + '_' + 'area' + ']'))
                    elif key == 'food':
                        dic_food.append((normalize(val), '[' + 'value' + '_' + 'food' + ']'))
                    elif key == 'pricerange':
                        dic_price.append((normalize(val), '[' + 'value' + '_' + 'pricerange' + ']'))
                    else:
                        pass
                    # TODO car type?
        except(Exception):
            pass

        if domain == 'hospital':
            dic.append((normalize('Hills Rd'), '[' + domain + '_' + 'address' + ']'))
            dic.append((normalize('Hills Road'), '[' + domain + '_' + 'address' + ']'))
            dic.append((normalize('CB20QQ'), '[' + domain + '_' + 'postcode' + ']'))
            dic.append(('01223245151', '[' + domain + '_' + 'phone' + ']'))
            dic.append(('1223245151', '[' + domain + '_' + 'phone' + ']'))
            dic.append(('0122324515', '[' + domain + '_' + 'phone' + ']'))
            dic.append((normalize('Addenbrookes Hospital'), '[' + domain + '_' + 'name' + ']'))

        elif domain == 'police':
            dic.append((normalize('Parkside'), '[' + domain + '_' + 'address' + ']'))
            dic.append((normalize('CB11JG'), '[' + domain + '_' + 'postcode' + ']'))
            dic.append(('01223358966', '[' + domain + '_' + 'phone' + ']'))
            dic.append(('1223358966', '[' + domain + '_' + 'phone' + ']'))
            dic.append((normalize('Parkside Police Station'), '[' + domain + '_' + 'name' + ']'))

    # add at the end places from trains
    fin = dbzipf.open(os.path.join('db/' + 'train' + '_db.json'), 'r')
    db_json = json.load(fin)
    fin.close()

    for ent in db_json:
        for key, val in ent.items():
            if key == 'departure' or key == 'destination':
                dic.append((normalize(val), '[' + 'value' + '_' + 'place' + ']'))

    # add specific values:
    for key in ['monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday', 'sunday']:
        dic.append((normalize(key), '[' + 'value' + '_' + 'day' + ']'))

    # more general values add at the end
    dic.extend(dic_area)
    dic.extend(dic_food)
    dic.extend(dic_price)

    return dic


def delexicalise(utt, dictionary):
    for key, val in dictionary:
        utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
        utt = utt[1:-1]  # why this?

    return utt


def delexicaliseDomain(utt, dictionary, domain):
    for key, val in dictionary:
        if key == domain or key == 'value':
            utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
            utt = utt[1:-1]  # why this?

    # go through rest of domain in case we are missing something out?
    for key, val in dictionary:
        utt = (' ' + utt + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
        utt = utt[1:-1]  # why this?
    return utt


def is_ascii(s):
    return all(ord(c) < 128 for c in s)


def domain_not_empty(domain_bs):
    return any(len(val) > 0 and val not in DEFAULT_IGNORE_VALUES for val in domain_bs.values())


class BeliefStateTransformation:
    def _process_domain(self, domain_bs):
        return {self._map_slot(slot): self._clear_value(val) for slot, val in domain_bs.items()
                if (len(val) > 0 and val not in DEFAULT_IGNORE_VALUES)}

    def _map_slot(self, slot):
        if slot == 'arriveBy':
            return 'arrive by'
        if slot == 'leaveAt':
            return 'leave at'
        if slot == 'pricerange':
            slot = 'price range'
        return slot

    def _clear_value(self, value):
        value = value.replace('>', ' ')
        if value == 'el shaddia guesthouse':
            value = 'el shaddai'
        if value == 'concerthall':
            value = 'concert hall'
        if value == 'nightclub':
            value = 'night club'
        # BUG in MW2.0
        value = value.lstrip('`')
        return value

    def __call__(self, belief_state, dialogue_act, active_domain):
        clean_belief = dict()
        for domain, domain_bs in belief_state.items():
            new_domain_bs = {}
            if 'semi' in domain_bs:
                new_domain_bs.update(domain_bs['semi'])
            if 'book' in domain_bs:
                new_domain_bs.update({k: v for k, v in domain_bs['book'].items() if k != 'booked'})
            if 'book' in domain_bs and 'booked' in domain_bs['book'] and len(domain_bs['book']['booked']) > 0:
                new_domain_bs['booked'] = 'true'
            elif not domain_not_empty(domain_bs):
                continue
            new_domain_bs = self._process_domain(new_domain_bs)
            if len(new_domain_bs) == 0:
                continue
            if 'internet' in new_domain_bs and new_domain_bs['internet'] == 'no':
                del new_domain_bs['internet']  # no internet by default
            if 'parking' in new_domain_bs and new_domain_bs['parking'] == 'no':
                del new_domain_bs['parking']  # no parking by default
            clean_belief[domain] = new_domain_bs

        for domain in {'Hospital', 'Police'}:
            if any([da[1] == domain for da in dialogue_act]):
                clean_belief[domain.lower()] = {}

        # Sort belief
        clean_belief = {k: OrderedDict(sorted(v.items(), key=lambda x: x[0])) for k, v in clean_belief.items()}
        active_bs = None
        if active_domain is not None:
            active_domain = active_domain.lower()
            active_bs = clean_belief.pop(active_domain, None)
        items = [(active_domain, active_bs)] if active_bs is not None else []
        items += [(k, v) for k, v in sorted(clean_belief.items(), key=lambda x: x[0])]
        result = OrderedDict(items)
        return result


def fixDelex(delex, act):
    for k in act:
        if 'Attraction' == k[1]:
            if 'restaurant_' in delex:
                delex = delex.replace("restaurant", "attraction")
            if 'hotel_' in delex:
                delex = delex.replace("hotel", "attraction")
        if 'Hotel' == k[1]:
            if 'attraction_' in delex:
                delex = delex.replace("attraction", "hotel")
            if 'restaurant_' in delex:
                delex = delex.replace("restaurant", "hotel")
        if 'Restaurant' == k[1]:
            if 'attraction_' in delex:
                delex = delex.replace("attraction", "restaurant")
            if 'hotel_' in delex:
                delex = delex.replace("hotel", "restaurant")

    return delex


def delexicaliseReferenceNumber(sent, turn):
    """Based on the belief state, we can find reference number that
    during data gathering was created randomly."""
    domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'hospital']  # , 'police']
    if turn['metadata']:
        for domain in domains:
            if turn['metadata'][domain]['book']['booked']:
                for slot in turn['metadata'][domain]['book']['booked'][0]:
                    if slot == 'reference':
                        val = '[' + domain + '_' + slot + ']'
                    else:
                        val = '[' + domain + '_' + slot + ']'
                    key = normalize(turn['metadata'][domain]['book']['booked'][0][slot])
                    sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')

                    # try reference with hashtag
                    key = normalize("#" + turn['metadata'][domain]['book']['booked'][0][slot])
                    sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')

                    # try reference with ref#
                    key = normalize("ref#" + turn['metadata'][domain]['book']['booked'][0][slot])
                    sent = (' ' + sent + ' ').replace(' ' + key + ' ', ' ' + val + ' ')
    return sent


def analyze_dialogue(dialogue, maxlen):
    """Cleaning procedure for all kinds of errors in text and annotation."""
    d = dialogue
    # do all the necessary postprocessing
    if len(d['log']) % 2 != 0:
        # print path
        logger.warning('odd # of turns')
        return None  # odd number of turns, wrong dialogue

    for i in range(len(d['log'])):
        if len(d['log'][i]['text'].split()) > maxlen:
            logger.warning('too long')
            return None  # too long sentence, wrong dialogue
        if i % 2 == 0:  # usr turn
            text = d['log'][i]['text']
            if not is_ascii(text):
                logger.warning('not ascii')
                return None
        else:  # sys turn
            if 'database' not in d['log'][i]:
                logger.warning('no db')
                return None  # no db_pointer, probably 2 usr turns in a row, wrong dialogue
            text = d['log'][i]['text']
            if not is_ascii(text):
                logger.warning('not ascii')
                return None
        d['log'][i]['text'] = clear_whitespaces(d['log'][i]['text'])
    return dialogue


def get_dial(dialogue):
    d_orig = analyze_dialogue(dialogue, MAX_LENGTH)  # max turn len is 50 words
    if d_orig is None:
        return None
    return d_orig


def createDict(word_freqs):
    words = list(word_freqs.keys())
    freqs = list(word_freqs.values())

    sorted_idx = np.argsort(freqs)
    sorted_words = [words[ii] for ii in sorted_idx[::-1]]

    # Extra vocabulary symbols
    _GO = '_GO'
    EOS = '_EOS'
    UNK = '_UNK'
    PAD = '_PAD'
    extra_tokens = [_GO, EOS, UNK, PAD]

    worddict = OrderedDict()
    for ii, ww in enumerate(extra_tokens):
        worddict[ww] = ii
    for ii, ww in enumerate(sorted_words):
        worddict[ww] = ii + len(extra_tokens)

    for key, idx in list(worddict.items()):
        if idx >= DICT_SIZE:
            del worddict[key]

    return worddict


def createDelexData(zipf, path):
    """Main function of the script - loads delexical dictionary,
    goes through each dialogue and does:
    1) data normalization
    2) delexicalization
    3) addition of database pointer
    4) saves the delexicalised data
    """
    transform_belief = BeliefStateTransformation()

    # Load databases
    with zipfile.ZipFile(os.path.join(path, 'database.zip')) as dbzipf:
        db = Database(dbzipf)

        # create dictionary of delexicalied values that then we will search against, order matters here!
        dic = prepareSlotValuesIndependent(dbzipf, path)
    delex_data = OrderedDict()
    with zipfile.ZipFile(os.path.join(path, 'lexicalizer.zip')) as lexzipf:
        lexicalizer = Lexicalizer(lexzipf)

    root = next(iter({n.strip('data.json') for n in zipf.namelist() if n.endswith('data.json')}))
    fin1 = zipf.open(root + 'data.json', 'r')
    data = json.load(fin1)

    fin2 = zipf.open(root + 'dialogue_acts.json', 'r')
    data2 = json.load(fin2)
    ignored_dialogues = 0

    for dialogue_name in tqdm(data):
        dialogue = data[dialogue_name]
        # print dialogue_name

        idx_acts = 1
        active_domain = None
        ignore_dialogue = False

        for idx, turn in enumerate(dialogue['log']):
            try:
                dialogue_act = [tuple(reversed(f.split('-'))) + tuple(x)
                                for f, xs in data2[dialogue_name.strip('.json')][str(idx_acts)].items() for x in xs]
            except(Exception):
                dialogue_act = []
            # normalization, split and delexicalization of the sentence
            sent = normalize(turn['text'])
            text = sent

            words = sent.split()
            sent = delexicalise(' '.join(words), dic)

            # parsing reference number GIVEN belief state
            sent = delexicaliseReferenceNumber(sent, turn)

            # changes to numbers only here
            digitpat = re.compile(r'\d+')
            sent = re.sub(digitpat, '[value_count]', sent)

            dialogue['log'][idx]['dialogue_act'] = dialogue_act
            dialogue['log'][idx]['speaker'] = 'user'

            # delexicalised sentence added to the dialogue
            delex = sent.strip()
            delex = fixDelex(delex, dialogue_act)

            if idx % 2 == 1:  # if it's a system turn
                dialogue['log'][idx]['speaker'] = 'system'
                belief = dialogue['log'][idx]['metadata']
                active_domain, delex = fix_active_domain_and_delex(active_domain, text, delex)
                dialogue['log'][idx]['active_domain'] = active_domain

                belief = transform_belief(belief, dialogue_act, active_domain)
                dialogue['log'][idx]['belief'] = belief
                if 'bus' in belief:
                    # We need to ignore this dialogue
                    # There is no data for the bus domain
                    ignore_dialogue = True
                    break

                dialogue['log'][idx]['database'] = db(belief)

                # Add booked property
                dialogue['log'][idx]['booked_domains'] = sorted(get_booked_domains(dialogue['log'][idx]['metadata']))

                # Test if lexicalizer works
                lexicalizer(delex, db(belief, return_results=True), belief)

            dialogue['log'][idx]['delexicalised_text'] = delex

            idx_acts += 1

        if not ignore_dialogue:
            dialogue['goal'] = parse_goal(dialogue['goal'])
            delex_data[dialogue_name] = dialogue
        else:
            ignored_dialogues += 1
    if ignored_dialogues > 0:
        logger.warning(f'dialogues were ignored {100 * ignored_dialogues / (ignored_dialogues + len(delex_data)):.1f}% due to a missing domain "bus"')  # noqa: E501
    return delex_data


def load_databases(zipf):
    dbs = {}
    sql_dbs = {'attraction', 'hotel', 'restaurant', 'train'}
    for domain in MW_DOMAINS:
        if domain in sql_dbs:
            db = 'db/{}-dbase.db'.format(domain)
            with tempfile.NamedTemporaryFile('rb+') as dbf:
                shutil.copyfileobj(zipf.open(db), dbf)
                dbf.flush()
                fileconn = sqlite3.connect(dbf.name)
                conn = sqlite3.connect(':memory:')
                fileconn.backup(conn)

            def dict_factory(cursor, row):
                d = {}
                for idx, col in enumerate(cursor.description):
                    d[col[0]] = row[idx]
                return d

            conn.row_factory = dict_factory
            c = conn.cursor()
            dbs[domain] = c
        else:
            db = 'db/{}_db.json'.format(domain)
            dbs[domain] = json.load(zipf.open(db))
    return dbs


class Database:
    def __init__(self, zipf, seed=42):
        self.path = zipf.filename
        self.dbs = load_databases(zipf)
        self.ignore_values = ['not mentioned', 'dont care', 'don\'t care', 'dontcare', 'do n\'t care', 'none']
        self.rng = random.Random(seed)

    price_re = re.compile(r'\d+\.\d+')

    @staticmethod
    def translate_to_db_col(s):
        if s == 'leave at':
            return 'leaveAt'
        elif s == 'arrive by':
            return 'arriveBy'
        elif s == 'price range':
            return 'pricerange'
        else:
            return s

    def domain_not_empty(self, domain_bs):
        return any(len(val) > 0 and val not in self.ignore_values for val in domain_bs.values())

    @staticmethod
    def map_database_key(key):
        if key == 'trainID':
            key = 'id'
        key = ''.join([' '+i.lower() if i.isupper()
                       else i for i in key]).lstrip(' ')
        key = key.replace('_', ' ')
        if key == 'pricerange':
            key = 'price range'
        if key == 'taxi phone' or key == 'phone':
            key = 'phone'
        if key == 'taxi colors':
            key = 'color'
        if key == 'taxi types':
            key = 'brand'
        if key == 'ref':
            key = 'reference'
        if key == 'leaveAt':
            key = 'leave at'
        if key == 'arriveBy':
            key = 'arrive by'
        if key == 'entrance fee':
            key = 'fee'
        return key

    @staticmethod
    def map_query_value(value):
        if value == 'concert hall':
            value = 'concerthall'
        if value == 'night club':
            value = 'nightclub'
        return value

    @staticmethod
    def capitalize(val):
        def _mk(v):
            i, v = v
            if i == 0 or v not in {'the', 'an', 'a', 'of', 'in', 'for', 'as', 'these', 'at', 'up', 'on', 'and', 'or'}:
                return v[:1].upper() + v[1:]
            else:
                return v
        return ' '.join(map(_mk, enumerate(val.split())))

    @staticmethod
    def map_database_row(domain, row, query):
        results = dict()
        for k, val in row.items():
            k2 = Database.map_database_key(k)
            if k == 'location':
                continue
            elif k == 'post code' or k == 'postcode':
                val = val.upper()
            elif k == 'name':
                val = Database.capitalize(val)
            elif k == 'type' and val == 'concerthall':
                val = 'concert hall'
            elif k == 'price' and domain == 'hotel' and isinstance(val, dict):
                val = val.get('single', val.get('double', next(iter(val.values()))))
                val = f'{val} pounds'
            if k2 == 'people':
                # BUG in MW2.0
                val = val.lstrip('`')
            results[k2] = val
        if 'color' in results and 'brand' in results:
            results['car'] = f"{results['color']} {results['brand']}"
        if domain == 'train' and 'price' in row and 'people' in query:
            people = int(query['people'])

            def multiply_people(m):
                price = float(m.group(0))
                price *= people
                return format(price, '.2f')
            if people != 1:
                results['price'] = Database.price_re.sub(multiply_people, row['price'])
        return results

    def query_domain(self, domain, query):
        # Handle special domains not in sqlite databases
        # NOTE: this is not a part of multiwoz repo
        # Taken from convlab
        if domain == 'taxi':
            return [{'color': self.rng.choice(self.dbs[domain]['taxi_colors']),
                     'brand': self.rng.choice(self.dbs[domain]['taxi_types']),
                     'phone': ''.join([str(random.randint(1, 9)) for _ in range(11)])}]
        if domain == 'police':
            return deepcopy(self.dbs['police'])
        if domain == 'hospital':
            department = None
            for key, val in query:
                if key == 'department':
                    department = val
            if not department:
                return deepcopy(self.dbs['hospital'])
            else:
                return [deepcopy(x) for x in self.dbs['hospital']
                        if x['department'].lower() == department.strip().lower()]

        sql_query = "select * from {}".format(domain)

        flag = True
        for key, val in query:
            if val == "" or val in self.ignore_values:
                pass
            else:
                if flag:
                    sql_query += " where "
                    val2 = val.replace("'", "''")
                    # change query for trains
                    if key == 'leaveAt':
                        sql_query += r" " + key + " > " + r"'" + val2 + r"'"
                    elif key == 'arriveBy':
                        sql_query += r" " + key + " < " + r"'" + val2 + r"'"
                    else:
                        sql_query += r" " + key + "=" + r"'" + val2 + r"'"
                    flag = False
                else:
                    val2 = val.replace("'", "''")
                    if key == 'leaveAt':
                        sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
                    elif key == 'arriveBy':
                        sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
                    else:
                        sql_query += r" and " + key + "=" + r"'" + val2 + r"'"

        result = self.dbs[domain].execute(sql_query).fetchall()
        return result

    def __call__(self, belief, return_results=False):
        all_results = OrderedDict()
        for domain, domain_bs in belief.items():
            blocked_slots = {'people', 'booked', 'stay'}
            if domain != 'train' and domain != 'bus':
                blocked_slots.add('day')
                blocked_slots.add('time')
            query = [(Database.translate_to_db_col(slot), Database.map_query_value(val))
                     for slot, val in domain_bs.items() if slot not in blocked_slots]

            result = self.query_domain(domain, query)
            result = [Database.map_database_row(domain, k, domain_bs) for k in result]
            if return_results:
                all_results[domain] = (len(result), result)
            else:
                all_results[domain] = len(result)
        return all_results

    def save(self, path):
        shutil.copy(self.path, os.path.join(path, os.path.split(self.path)[-1]))


def is_booked(raw_belief, domain):
    return domain in raw_belief and 'book' in raw_belief[domain] and \
        'booked' in raw_belief[domain]['book'] and \
        any('reference' in x for x in raw_belief[domain]['book']['booked'])


def get_booked_domains(raw_belief):
    for domain in raw_belief.keys():
        if is_booked(raw_belief, domain):
            yield domain


def parse_goal(dialog_goal):
    belief_transformation = BeliefStateTransformation()
    """Parses user goal into dictionary format."""
    goal = {}
    for domain in MW_DOMAINS:
        if not dialog_goal[domain]:
            continue
        goal[domain] = {}
        goal[domain] = {'informable': [], 'requestable': [], 'booking': {}}
        if 'info' in dialog_goal[domain]:
            # if d['goal'][domain].has_key('info'):
            if domain == 'train':
                # we consider dialogues only where train had to be booked!
                if 'book' in dialog_goal[domain]:
                    # if d['goal'][domain].has_key('book'):
                    goal[domain]['requestable'].append('reference')
                if 'reqt' in dialog_goal[domain]:
                    # if d['goal'][domain].has_key('reqt'):
                    if 'trainID' in dialog_goal[domain]['reqt']:
                        goal[domain]['requestable'].append('id')
            else:
                if 'reqt' in dialog_goal[domain]:
                    # if d['goal'][domain].has_key('reqt'):
                    for s in dialog_goal[domain]['reqt']:  # addtional requests:
                        if s in ['phone', 'address', 'postcode', 'reference', 'id']:
                            # ones that can be easily delexicalised
                            goal[domain]['requestable'].append(s)
                if 'book' in dialog_goal[domain]:
                    # if d['goal'][domain].has_key('book'):
                    goal[domain]['requestable'].append("reference")

            goal[domain]["informable"] = dialog_goal[domain]['info']
            if 'book' in dialog_goal[domain]:
                # if d['goal'][domain].has_key('book'):
                goal[domain]["booking"] = dialog_goal[domain]['book']

        if 'invalid' in goal[domain]['booking']:
            del goal[domain]['booking']['invalid']
        if 'pre_invalid' in goal[domain]['booking']:
            del goal[domain]['booking']['pre_invalid']
        belief = {domain: {'semi': goal[domain]['informable'], 'book': goal[domain]['booking']}}
        belief = belief_transformation(belief, [], domain).get(domain, dict())
        goal[domain]['informable'] = belief
        del goal[domain]['booking']
    return goal


def map_dialogue_items(log):
    supported_keys = {'text', 'delexicalised_text', 'speaker', 'belief', 'database',
                      'active_domain', 'dialogue_act', 'booked_domains'}
    for item in log:
        yield {k: v for k, v in item.items() if k in supported_keys}


def divideData(data, zipf, path):
    """Given test and validation sets, divide
    the data for three different sets"""
    testListFile = []
    root = next(iter({n.strip('data.json') for n in zipf.namelist() if n.endswith('data.json')}))
    fin = zipf.open(root + 'testListFile.json', 'r')
    for line in fin:
        testListFile.append(line[:-1].decode('utf-8'))
    fin.close()

    valListFile = []
    fin = zipf.open(root + 'valListFile.json', 'r')
    for line in fin:
        valListFile.append(line[:-1].decode('utf-8'))
    fin.close()

    test_dials = []
    val_dials = []
    train_dials = []

    for dialogue_name in tqdm(data):
        # print dialogue_name
        dial = get_dial(data[dialogue_name])
        if dial:
            dialogue = {}
            dialogue['name'] = dialogue_name
            dialogue['items'] = list(map_dialogue_items(dial['log']))
            dialogue['goal'] = dial['goal']

            if dialogue_name in testListFile:
                test_dials.append(dialogue)
            elif dialogue_name in valListFile:
                val_dials.append(dialogue)
            else:
                train_dials.append(dialogue)

    # save all dialogues
    with open(os.path.join(path, 'val.json'), 'w') as f:
        json.dump(dict(domains=MW_DOMAINS, dialogues=val_dials), f, indent=4)

    with open(os.path.join(path, 'test.json'), 'w') as f:
        json.dump(dict(domains=MW_DOMAINS, dialogues=test_dials), f, indent=4)

    with open(os.path.join(path, 'train.json'), 'w') as f:
        json.dump(dict(domains=MW_DOMAINS, dialogues=train_dials), f, indent=4)


def export_database_source(zipf):
    source_code = f"""import sqlite3
import os
import shutil
import re
import random
import json
import zipfile
import tempfile
from copy import deepcopy
from collections import OrderedDict


MW_DOMAINS = {MW_DOMAINS}


{inspect.getsource(load_databases)}

{inspect.getsource(Database)}"""
    with zipf.open('database.py', 'w') as f:
        f.write(source_code.encode('utf-8'))
        f.flush()


def download_file(source_url, dest):
    response = requests.get(source_url, stream=True, timeout=5)
    response.raise_for_status()
    file_size = int(response.headers.get('content-length', 0))
    zipf = None
    if isinstance(dest, tuple):
        zipf, dest_path = dest
    else:
        dest_path = dest
        if "/" in dest_path:
            dir = "/".join(dest_path.split("/")[0:-1])
            os.makedirs(dir, exist_ok=True)
        if os.path.exists(dest_path):
            return

    pbar = tqdm(
        total=file_size, unit='B', disable=file_size < 1024**2,
        unit_scale=True, desc=source_url.split('/')[-1])

    with tempfile.TemporaryFile('rb+') as file:
        for data in response.iter_content(chunk_size=1024):
            file.write(data)
            pbar.update(1024)
        file.flush()
        file.seek(0)
        pbar.close()
        if zipf is not None:
            with zipf.open(dest_path, 'w') as f:
                shutil.copyfileobj(file, f)
        else:
            with open(dest_path, 'wb+') as f:
                shutil.copyfileobj(file, f)


def export_lexicalizer_source(path):
    source_code = f"""from collections import defaultdict, OrderedDict
import os
import shutil
import re


{inspect.getsource(Lexicalizer)}"""
    with zipfile.ZipFile(os.path.join(path, 'lexicalizer.zip'), 'w') as zipf:
        with zipf.open('lexicalizer.py', 'w') as f:
            f.write(source_code.encode('utf-8'))
            f.flush()


def extract_databases(path, dbzipf, multiwoz_sha):
    with zipfile.ZipFile(os.path.join(path, 'database.zip'), 'w') as dboutf:
        for domain in MW_DOMAINS[:-1]:
            db = f'multiwoz-{multiwoz_sha}/db/{domain}-dbase.db'
            with dbzipf.open(db) as zf, dboutf.open(os.path.join('db', f'{domain}-dbase.db'), 'w') as f:
                shutil.copyfileobj(zf, f)

        # Fix json databases
        # Download from convlab2
        for domain in MW_DOMAINS:
            download_file(
                    f'https://raw.githubusercontent.com/thu-coai/ConvLab-2/b82732eae951b3dc957136f40b992a1904c9cbe5/data/multiwoz/db/{domain}_db.json',  # noqa: E501
                    (dboutf, os.path.join('db', f'{domain}_db.json')))

        # Export database source
        export_database_source(dboutf)


def download(version='2.0'):
    path = os.path.join(DATASETS_PATH, f'multiwoz-{version}')
    multiwoz_sha = 'a24d299fafa00371d03880bce34cb3b0923518fa'
    os.makedirs(path, exist_ok=True)
    download_file(
        f'https://github.com/budzianowski/multiwoz/raw/{multiwoz_sha}/data/MultiWOZ_{version}.zip',
        os.path.join(path, 'original.zip'))
    download_file(
        f'https://github.com/budzianowski/multiwoz/archive/{multiwoz_sha}.zip',
        os.path.join(path, 'repo.zip'))

    with zipfile.ZipFile(os.path.join(path, 'original.zip')) as zipf, \
            zipfile.ZipFile(os.path.join(path, 'repo.zip')) as dbzipf:
        export_lexicalizer_source(path)
        extract_databases(path, dbzipf, multiwoz_sha)
        delex_data = createDelexData(zipf, path)
        divideData(delex_data, zipf, path)

    # Generating blacklist
    logger.info('generating blacklist')
    cwd = os.path.dirname(os.path.abspath(__file__))
    subprocess.run(['python', os.path.join(cwd, 'build_multiwoz_blacklist.py'), '--dataset', 'multiwoz-2.0'], cwd=cwd)


if __name__ == "__main__":
    download()

2.4.2. The hierarchical structure of TOD data and the common data objects

What is the smallest unit of a dialogue system? A single component? No: components are heterogeneous, and only several of them together form one model input. A whole dialogue, then? Not that either, since one dialogue contains multiple turns. The basic unit of a dialogue system is the turn, and dialogues are presented turn by turn. Every training example we feed to the model is therefore one turn, flattened into a single linear input. The DialogueItems class below is responsible for this turn-level indexing:

class DialogueItems:
    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            r.append(e + s)
            s += e
        return r

    def __init__(self, dialogues):

        # lengths[i] is the number of turns in dialogue i
        lengths = [len(x['items']) for x in dialogues]

        # cumulative_sizes[i] is the total number of turns in dialogues 0..i
        self.cumulative_sizes = DialogueItems.cumsum(lengths)
        self.dialogues = dialogues

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx

        # locate the dialogue that this global turn index falls into
        dialogue_idx = bisect.bisect_right(self.cumulative_sizes, idx)

        if dialogue_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dialogue_idx - 1]

        # return the whole dialogue plus its turns up to and including sample_idx
        return self.dialogues[dialogue_idx], self.dialogues[dialogue_idx]['items'][:sample_idx + 1]

    def __len__(self):
        if not self.cumulative_sizes:
            return 0
        return self.cumulative_sizes[-1]

Here self.cumulative_sizes is a list whose elements (computed by cumsum) record how many turns have been accumulated up to and including each dialogue. The class essentially plays the role of a dataset object: given a turn index, it finds the dialogue that turn belongs to, together with the partial dialogue from the first turn up to and including that turn, and returns both.
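As a minimal sketch of this indexing (the toy dialogues below are invented purely for illustration):

# hypothetical toy data: two dialogues with 2 and 3 turns respectively
dialogues = [
    {'items': [{'text': 'hi'}, {'text': 'hello , how can I help ?'}]},
    {'items': [{'text': 'book a table'}, {'text': 'for which day ?'}, {'text': 'friday'}]},
]
items = DialogueItems(dialogues)

print(len(items))      # 5 turns in total (cumulative_sizes == [2, 5])
dialogue, prefix = items[3]
print(len(prefix))     # 2 -> second dialogue, turns 0..1 (the prefix up to this turn)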

So the class clearly works at the turn level. What is still unclear is the dialogues argument of __init__: it is simply the raw dialogue list read from the dataset's json file, which means the class above is not the end of the story. An instance of it is used as the input of the next class, namely its items field.
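For reference, the shape of the json files written by divideData above can be read off directly from that function; the values below are only placeholders:

# schematic of train.json / val.json / test.json as produced by divideData
example = {
    "domains": ["restaurant", "hotel", "..."],       # MW_DOMAINS
    "dialogues": [
        {
            "name": "SNG0073.json",                   # placeholder dialogue name
            "goal": {"...": "..."},
            "items": [                                # one entry per turn, keys limited to supported_keys
                {
                    "text": "...",
                    "delexicalised_text": "...",
                    "speaker": "...",
                    "belief": {"...": "..."},
                    "database": [],
                    "active_domain": "...",
                    "dialogue_act": {"...": "..."},
                    "booked_domains": [],
                },
            ],
        },
    ],
}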

@dataclass
class DialogDataset(torch.utils.data.Dataset):
    items: List[Any]
    database: Any = None
    domains: List[str] = None
    lexicalizer: Any = None
    transform: Callable[[Any], Any] = None
    normalize_input: Callable[[str], str] = None
    ontology: Dict[Tuple[str, str], Set[str]] = None

    @staticmethod
    def build_dataset_without_database(items, *args, **kwargs):
        return DialogDataset(items, FakeDatabase(), *args, **kwargs)

    def __getitem__(self, index):
        item = self.items[index]
        if self.transform is not None:
            item = self.transform(item)
        return item

    def __len__(self):
        return len(self.items)

    def map(self, transformation):
        def trans(x):
            x = self.transform(x)
            x = transformation(x)
            return x
        return dataclasses.replace(self, transform=trans)

    def finish(self, progressbar: Union[str, bool] = False):
        if self.transform is None:
            return self

        ontology = defaultdict(lambda: set())
        domains = set(self.domains) if self.domains else set()

        items = []
        for i in trange(len(self),
                        desc=progressbar if isinstance(progressbar, str) else 'loading dataset',
                        disable=not progressbar):
            item = self[i]
            for k, bs in item.raw_belief.items():
                domains.add(k)
                for k2, val in bs.items():
                    ontology[(k, k2)].add(val)
            items.append(item)
        if self.ontology:
            ontology = merge_ontologies((self.ontology, ontology))
        return dataclasses.replace(self, items=items, transform=None, domains=domains, ontology=ontology)

This brings us to the next class. The first thing to notice is that it subclasses PyTorch's official Dataset; the second is how heavily it relies on transformations, a style that is much more common in CV code. The transform shows up in two places: in __getitem__, where it parses the information belonging to the current turn (i.e., one item) and returns a structured object, and in map, which composes a new transformation on top of the existing one (a toy sketch of this composition is given right after the transform function). The transform function itself looks like this:

def transform(x):
    dialogue, items = x
    context = [s['text'] for s in items[:-1]]
    if context_window_size is not None and context_window_size > 0:
        context = context[-context_window_size:]
    belief = items[-1]['belief']
    database = items[-1]['database']
    dialogue_act = items[-1]["dialogue_act"]

    item = DialogDatasetItem(context,
                             raw_belief=belief,
                             raw_dialogue_act=dialogue_act,
                             database=database,
                             response=items[-1]['delexicalised_text'],
                             raw_response=items[-1]['text'])
    return item
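Here is a self-contained toy of how map stacks transformations, using only the DialogDataset class above; the integers and lambdas are stand-ins rather than real dialogue items, and FakeDatabase is the repository's dummy database:

# toy items instead of real turns, just to show the composition order
toy = DialogDataset.build_dataset_without_database(
    items=[1, 2, 3],
    transform=lambda x: x * 10,               # first transform
)
toy = toy.map(lambda x: x + 1)                # composed on top: x -> x * 10 + 1
print([toy[i] for i in range(len(toy))])      # [11, 21, 31]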

The DialogDatasetItem built by the transform above is a dataclass of the following form:

@dataclass
class DialogDatasetItem:
    context: Union[List[str], str]
    belief: Union[Dict[str, Dict[str, str]], str] = None
    database: Union[List[Tuple[str, int]], List[Tuple[str, int, Any]], None, str] = None
    response: str = None
    positive: bool = True
    raw_belief: Any = None
    raw_response: str = None

    def __getattribute__(self, name):
        val = object.__getattribute__(self, name)
        if name == 'belief' and val is None and self.raw_belief is not None:
            val = format_belief(self.raw_belief)
            self.belief = val

        return val

Note the overridden __getattribute__: belief is computed lazily from raw_belief via format_belief on first access and then cached on the instance. With this, all of the data structures have been introduced.
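The lazy-formatting trick is easy to miss, so here is a self-contained sketch of the same pattern; toy_format_belief is an invented stand-in for the repository's format_belief, whose exact output format may differ:

from dataclasses import dataclass
from typing import Any, Dict


def toy_format_belief(raw: Dict[str, Dict[str, str]]) -> str:
    # invented stand-in for format_belief: "domain {slot = value, ...}"
    parts = []
    for domain, slots in raw.items():
        inner = ', '.join(f'{k} = {v}' for k, v in slots.items())
        parts.append(f'{domain} {{{inner}}}')
    return ' '.join(parts)


@dataclass
class LazyItem:
    raw_belief: Any = None
    belief: str = None

    def __getattribute__(self, name):
        val = object.__getattribute__(self, name)
        if name == 'belief' and val is None and self.raw_belief is not None:
            # format on first access, then cache the string on the instance
            val = toy_format_belief(self.raw_belief)
            self.belief = val
        return val


item = LazyItem(raw_belief={'restaurant': {'food': 'chinese', 'area': 'centre'}})
print(item.belief)   # restaurant {food = chinese, area = centre}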

2.4.3. How is a TOD dataset parsed?

The code loads a dataset directly by its name, as the following function (loader.py -> load_dataset) shows:

def load_dataset(name, restrict_domains=False, augment='disabled', use_blacklist=False, **kwargs):
    if restrict_domains:
        return load_dataset(name, domains=RESTRICTED_DOMAINS, **kwargs)

    if '+' in name:
        # This is a concat dataset
        datasets = name.split('+')
        _load_dataset = functools.partial(load_dataset, **kwargs)
        datasets = list(map(_load_dataset, datasets))
        return ConcatDialogDataset(datasets)

    dataset_name, split = split_name(name)

    from data.dataset import load_dataset as load_custom_dataset
    dataset = load_custom_dataset(name, **kwargs)

    if use_blacklist:
        dataset = add_blacklist(dataset, name)
    return dataset
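Two behaviours of this dispatcher deserve a quick usage note: restrict_domains simply re-enters the function with a fixed domain list, and a '+' in the name concatenates several datasets into one ConcatDialogDataset. The dataset names below are assumptions based on the directories created by download() and the '<dataset>-<split>' convention implied by split_name:

# assuming loader.py's load_dataset is in scope
train_set = load_dataset('multiwoz-2.0-train')
# '+' concatenates two datasets/splits into a single ConcatDialogDataset
combined = load_dataset('multiwoz-2.0-train+multiwoz-2.0-val')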

In fact, nothing is really loaded here! The actual loading is implemented in dataset.py -> load_dataset; the author's reasoning behind this split escapes me:

def load_dataset(name, use_goal=False,have_template=False,
                 context_window_size=15, domains=None,
                 **kwargs) -> DialogDataset:
    name, split = split_name(name)
    path = os.path.join(DATASETS_PATH, name)
    with open(os.path.join(path, f'{split}.json'), 'r') as f:
        data = json.load(f, object_pairs_hook=OrderedDict)
    dialogues = data['dialogues']  # raw dialogue list from {split}.json
    items = DialogueItems(dialogues)  # turn-level indexing (see DialogueItems above)
    items = BlacklistItemsWrapper(items, list(build_blacklist(items, domains)))

    def transform(x):
        dialogue, items = x
        context = [s['text'] for s in items[:-1]]
        if context_window_size is not None and context_window_size > 0:
            context = context[-context_window_size:]
        belief = items[-1]['belief']
        database = items[-1]['database']
        dialogue_act = items[-1]["dialogue_act"]

        item = DialogDatasetItem(context,
                                 raw_belief=belief,
                                 raw_dialogue_act=dialogue_act,
                                 database=database,
                                 response=items[-1]['delexicalised_text'],
                                 raw_response=items[-1]['text'])

        if use_goal:
            setattr(item, 'goal', dialogue['goal'])
            # MultiWOZ evaluation uses booked domains property
            if 'booked_domains' in items[-1]:
                setattr(item, 'booked_domains', items[-1]['booked_domains'])
            setattr(item, 'dialogue_act', items[-1]['dialogue_act'])
        setattr(item, 'active_domain', items[-1]['active_domain'])
        return item

    dataset = DialogDataset(items, transform=transform, domains=data['domains'])
    if os.path.exists(os.path.join(path, 'database.zip')):
        dataset.database = AutoDatabase.load(path)

    if os.path.exists(os.path.join(path, 'lexicalizer.zip')):
        dataset.lexicalizer = AutoLexicalizer.load(path)

    return dataset

This is where the dataset actually gets parsed. The code looks somewhat convoluted, but it boils down to the following steps:

  1. First, the dataset name is turned into a path under DATASETS_PATH and the corresponding {split}.json file is loaded.
  2. Then the loaded data is parsed, passing through three layers of wrapping, which is admittedly a rather roundabout design. The part that deserves the most attention is the transform function: it is applied to every item, and each item is one element of DialogueItems, i.e., the dialogue together with its first n turns. The transform's return value is a DialogDatasetItem, which is exactly the unit that will later be fed to the model as one element of a batch.

In other words, the overall flow through these classes is DialogueItems -> DialogDataset -> DialogDatasetItem. I do not quite see why the names were chosen to be so hard to tell apart, or even why this many classes are needed.
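To make the end result concrete, here is a minimal usage sketch; the import mirrors the one used in loader.py, and the dataset name is an assumption based on the download() code above:

from data.dataset import load_dataset

dataset = load_dataset('multiwoz-2.0-train', context_window_size=15)

item = dataset[0]        # the transform runs here and returns a DialogDatasetItem
print(item.context)      # list of previous utterances (at most 15)
print(item.raw_belief)   # belief state as a nested dict
print(item.belief)       # lazily formatted string form of the belief state
print(item.response)     # delexicalised system response of the current turn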

In short, this gives us a standard PyTorch Dataset whose items are DialogDatasetItem objects. Surprisingly, though, the code then applies yet another layer of mapping, shown below.

Note that our DialogDataset is what appears here as inner. A new transform also shows up; it is the transform mentioned earlier, namely the composition of InsertLabelstransformation and TokenizerTransformation.

class NegativeSamplingDatasetWrapper(torch.utils.data.Dataset):
    def __init__(self, inner, transform=None,num_bs_negative=4):
        ## inner --> datasets
        self.inner = inner
        self.transform = transform
        assert hasattr(self.inner, 'ontology')
        assert self.inner.ontology is not None
        self.ontology = {k: sorted(v) for k, v in self.inner.ontology.items()}
        self.num_bs_negative = num_bs_negative

    def __len__(self):
        # one positive and one negative sample per original item
        return 2 * len(self.inner)

    def __getitem__(self, i):
        item = self.inner[i // 2]

        negative_bs_list = []
        # randomly sample belief states from other turns as additional negatives
        for x in range(self.num_bs_negative):
            negative_sample = random.randrange(len(self.inner))
            neg_sample = self.inner[negative_sample]
            negative_bs_list.append(format_belief(neg_sample.raw_belief))

        # odd indices are turned into corrupted (negative) samples
        negative = i % 2
        if negative:
            negative = False
            belief, response, context = item.belief, item.response, item.context
            raw_belief = item.raw_belief
            negative_type = random.randrange(1, 4)
            use_new_belief = (negative_type // 2) % 2
            use_new_response = negative_type % 2

            # Negative response: replace the belief and/or response with one from a random turn
            negative_sample = random.randrange(len(self.inner))
            neg_sample = self.inner[negative_sample]

            if use_new_belief:
                raw_belief = neg_sample.raw_belief
            if use_new_response:
                response = neg_sample.response
            belief = format_belief(raw_belief)
            item = dataclasses.replace(item, context=context,
                                       belief=belief,
                                       raw_belief=raw_belief,
                                       response=response,
                                       positive=False)

        item = dataclasses.replace(item,
                                   negative_bs_list=negative_bs_list)
        return self.transform(item)

In short, this wrapper adds negative sampling: each original turn yields one positive sample and one negative sample, which is why __len__ is doubled. These negatives are used by the consistency task, which is introduced later.
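A quick sketch of the indexing scheme, using plain integers only: even indices return the original item unchanged, odd indices produce a corrupted copy, so the wrapper is exactly twice as long as the wrapped dataset.

inner_len = 3
wrapped_len = 2 * inner_len          # __len__ doubles the dataset

for i in range(wrapped_len):
    source_item = i // 2             # which item of the inner dataset is used
    is_negative = bool(i % 2)        # odd indices get a corrupted belief and/or response
    print(i, '->', source_item, 'negative' if is_negative else 'positive')

# for a negative sample, negative_type in {1, 2, 3} decides what gets corrupted:
# 1 -> response only, 2 -> belief only, 3 -> both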

With that, the data-processing pipeline is essentially complete. Let us now look at the design of the model.

3. Model design

In the model, the forward pass used during training and the generation used at inference time are two rather different code paths, so they are described separately. We start with training.

3.1. Training: the forward pass

The core code is as follows:

class SoloistConfig(transformers.GPT2Config):
    def __init__(self,
                 summary_label_smoothing=0.1,  # label smoothing for the consistency head, to mitigate overfitting
                 **kwargs):
        super().__init__(**kwargs)
        self.summary_label_smoothing = summary_label_smoothing

class SoloistModel(transformers.GPT2PreTrainedModel):
    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias",
                               r"lm\_head\.weight", r"binary\_head\.\w+"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = transformers.GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.consistency_head = nn.Linear(config.n_embd, 1)  # binary head for the consistency auxiliary task
        self.auxiliary_dropout = nn.Dropout(config.summary_first_dropout)
        self.init_weights()

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self,
                input_ids=None,          # all input sequence tokens;
                past=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                consistency_token_ids=None,  # position of the last (eos) token, whose hidden state feeds the consistency classifier
                consistency_labels=None,  # 1/0 label: consistent or not
                user_intent_token_ids=None,
                user_intent_labels=None,
                user_intent_mask=None,
                belief_labels=None,      # labels over the context + belief-state sequence, used to predict the belief state
                system_action_token_ids=None,
                system_action_labels=None,
                system_action_mask=None,
                response_labels=None,  # only the response span carries labels; all other positions are -100
                back_predict_labels=None,  # targets of the back-predicted responses
                bp_weight=0.3,
                binary_labels=None,
                use_cache=None,
                output_attentions=None,
                output_hidden_states=None,
                **kwargs  # context=context_labels
                ):
        # print(f"shape of bp weight: {bp_weight.shape}")

        transformer_outputs = self.transformer(
            input_ids,
            past=past,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        def gather_auxiliary_features(token_ids):
            if token_ids is None:
                # No positions given: default to the last position of the sequence.
                # hidden_states[..., :1, :] has shape (bsz, 1, hidden_size); filling it with
                # hidden_states.shape[-2] - 1 (the last sequence index) makes the gather
                # below pick the final time step for every example.
                token_ids = torch.full_like(hidden_states[..., :1, :],
                                            hidden_states.shape[-2]-1, dtype=torch.long,)
            else:
                # (bsz,) -> (bsz, 1, 1) -> (bsz, 1, hidden_size)
                token_ids = token_ids.unsqueeze(-1).unsqueeze(-1)
                token_ids = token_ids.expand(
                    (-1,) * (token_ids.dim() - 1) + (hidden_states.size(-1),))

            # shape of token_ids: (bsz, XX, 1, hidden_size), where XX are optional
            # leading dims of hidden_states; after gather + squeeze the logits have
            # shape (bsz, XX, hidden_size)
            logits = hidden_states.gather(-2, token_ids).squeeze(-2)
            logits = self.auxiliary_dropout(logits)
            return logits

        consistency_logits = self.consistency_head(gather_auxiliary_features(consistency_token_ids)).squeeze(-1)
        consistency_loss = None
        if consistency_labels is not None:
            # Auxiliary tasks
            aux_criterion = LabelSmoothingBCEWithLogitsLoss(self.config.summary_label_smoothing)
            consistency_loss = aux_criterion(consistency_logits, consistency_labels)

        belief_loss, response_loss = None, None
        if belief_labels is not None:
            assert response_labels is not None

            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_belief_labels = belief_labels[..., 1:].contiguous()
            shift_response_labels = response_labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            belief_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_belief_labels.view(-1))

            response_ce = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_response_labels.view(-1))
            response_loss = response_ce

            bp_loss = 0.
            # the back-prediction loss is down-weighted by bp_weight
            if back_predict_labels is not None:
                shift_bp_labels = back_predict_labels[..., 1:].contiguous()
                bp_loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)),
                    shift_bp_labels.view(-1))
                # bp_weight is expected to be indexable here (e.g. a one-element tensor)
                bp_loss *= bp_weight[0]

        output = (lm_logits, consistency_logits,) + transformer_outputs[1:]
        if consistency_loss is not None:
            output = (consistency_loss,) + output
        return ((belief_loss, response_loss + bp_loss, response_ce + bp_loss) + output) if belief_loss is not None else output
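To sanity-check the shapes, here is a minimal forward-pass sketch with a deliberately tiny config. It assumes a transformers 3.x environment, consistent with the version discussed earlier; newer versions renamed the past argument that forward passes on to the GPT-2 backbone:

import torch

config = SoloistConfig(n_layer=2, n_head=2, n_embd=64, n_positions=128, n_ctx=128)
model = SoloistModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 20))   # (batch, seq_len)
with torch.no_grad():
    outputs = model(input_ids=input_ids)                   # no labels -> no losses in the output

lm_logits, consistency_logits = outputs[0], outputs[1]
print(lm_logits.shape)            # (2, 20, vocab_size)
print(consistency_logits.shape)   # (2,): one consistency logit per example, taken at the last position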

4. Training procedure

5. Inference procedure


Author: Zi Liang (liangzid@stu.xjtu.edu.cn) Create Date: Sun Jan 16 15:43:35 2022 Last modified: 2024-03-09 Sat 20:56 Creator: Emacs 28.1 (Org mode 9.5.2)