HuggingFace镜像/colbertv2.0
模型介绍文件和版本分析

ColBERT (v2)

ColBERT 是一款 快速 且 精准 的检索模型,能够在数十毫秒内对大型文本集合进行基于 BERT 的可扩展搜索。

图 1:ColBERT 的晚期交互机制,可高效计算查询与段落之间细粒度的相似度。

如图 1 所示,ColBERT 依赖于细粒度的 上下文晚期交互:它将每个段落编码为一个 矩阵 形式的 token 级嵌入(上图中蓝色部分所示)。然后在搜索时,它将每个查询嵌入为另一个矩阵(绿色部分所示),并使用可扩展的向量相似度(MaxSim)运算符高效地找到与查询在上下文上匹配的段落。

这种丰富的交互机制使 ColBERT 在质量上超越了 单向量 表示模型,同时能够高效扩展到大型语料库。您可以在我们的论文中了解更多信息:

  • ColBERT:通过 BERT 上的上下文晚期交互实现高效且有效的段落搜索(SIGIR'20)。
  • 基于 ColBERT 的开放域问答相关性引导监督(TACL'21)。
  • Baleen:通过压缩检索实现大规模鲁棒多跳推理(NeurIPS'21)。
  • ColBERTv2:通过轻量级晚期交互实现高效且有效的检索(NAACL'22)。
  • PLAID:晚期交互检索的高效引擎(CIKM'22)。

ColBERTv1

SIGIR'20 论文中的 ColBERTv1 代码位于 [colbertv1 分支]。有关其他分支的更多信息,请参见此处。

概述

在数据集上使用 ColBERT 通常包括以下步骤。

步骤 0:预处理您的集合。 简单来说,ColBERT 适用于制表符分隔(TSV)文件:一个文件(例如 collection.tsv)将包含所有段落,另一个文件(例如 queries.tsv)将包含一组用于搜索该集合的查询。

步骤 1:下载 [预训练的 ColBERTv2 检查点]。 此检查点已在 MS MARCO 段落排序任务上进行了训练。您也可以选择训练自己的 ColBERT 模型。

步骤 2:为您的集合建立索引。 拥有训练好的 ColBERT 模型后,您需要为集合建立索引以实现快速检索。此步骤将所有段落编码为矩阵,存储在磁盘上,并构建用于高效搜索的数据结构。

步骤 3:使用您的查询搜索集合。 有了模型和索引,您可以在集合上发出查询,为每个查询检索前 k 个段落。

下面,我们通过在 MS MARCO 段落排序任务上的示例运行来说明这些步骤。

如何使用

import argparse
from openmind import AutoModel, AutoTokenizer
from openmind import is_torch_npu_available

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path",type=str,help="Path to model",default=None,)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"
    args = parse_args()
    model_path = args.model_name_or_path

    # Note: CodeSage requires adding eos token at the end of
    # each tokenized sequence to ensure good performance
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, add_eos_token=True)

    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(device)

    inputs = tokenizer.encode("def print_hello_world():\tprint('Hello World!')", return_tensors="pt").to(device)
    embedding = model(inputs)[0]
    print(f'Dimension of the embedding: {embedding[0].size()}')
    print(embedding)
下载使用量0