
E5-base-4k

LongEmbed: Extending Embedding Models for Long Context Retrieval. Dawei Zhu, Liang Wang, Nan Yang, Yifan Song, Wenhao Wu, Furu Wei, Sujian Li, arXiv 2024. GitHub repository for LongEmbed: https://github.com/dwzhu-pku/LongEmbed.

This model has 12 layers and an embedding size of 768.

Usage

Below is an example of encoding queries and passages from the MS-MARCO passage ranking dataset.

from openmind import AutoTokenizer, AutoModel, is_torch_npu_available
import torch.nn.functional as F
import torch
from torch import Tensor
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to model",
        default="ChongqingAscend/e5-base-4k",
    )
    args = parser.parse_args()
    return args

def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token embeddings, ignoring padded positions (attention_mask == 0)."""
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

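def get_position_ids(input_ids: Tensor, max_original_positions: int=512, encode_max_length: int=4096) -> Tensor:
    """Multiply position IDs by encode_max_length // max_original_positions (= 8) when the
    padded batch length is at most 512 tokens; otherwise keep them consecutive."""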
    position_ids = list(range(input_ids.size(1)))
    factor = max(encode_max_length // max_original_positions, 1)
    if input_ids.size(1) <= max_original_positions:
        position_ids = [(pid * factor) for pid in position_ids]
        
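    # Build the tensor on the same device as input_ids (e.g. NPU) to avoid a device mismatch in the model
    position_ids = torch.tensor(position_ids, dtype=torch.long, device=input_ids.device)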
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    
    return position_ids

def main():
    args = parse_args()
    model_path = args.model_name_or_path

    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"
        
    # Each input text should start with "query: " or "passage: ".
    # For tasks other than retrieval, you can simply use the "query: " prefix.
    input_texts = ['query: how much protein should a female eat',
                   'query: summit define',
                   "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
                   "passage: Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."]

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path).to(device)
    model = model.eval()

    # Tokenize the input texts
    batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt').to(device)
    batch_dict['position_ids'] = get_position_ids(batch_dict['input_ids'], max_original_positions=512, encode_max_length=4096)

    # Inference only: disable gradient tracking to save memory
    with torch.no_grad():
        outputs = model(**batch_dict)
    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    # L2-normalize so that dot products equal cosine similarities (scaled by 100 below)
    embeddings = F.normalize(embeddings, p=2, dim=1)
    scores = (embeddings[:2] @ embeddings[2:].T) * 100
    print(scores.tolist())
    
if __name__ == "__main__":
    main()
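To make the pooling and scoring steps in the example concrete without PyTorch, here is a minimal pure-Python sketch of what `average_pool` and `F.normalize` compute for a single sequence: token vectors at padded positions (attention mask 0) are excluded, the remaining ones are averaged, and each embedding is scaled to unit L2 norm so that the dot product equals cosine similarity. The toy 2-dimensional vectors and the helper names `masked_mean`, `l2_normalize` and `score` are hypothetical and only serve as an illustration.

```python
import math

def masked_mean(hidden, mask):
    """Average the token vectors whose attention-mask entry is 1 (mirrors average_pool for one sequence)."""
    kept = [vec for vec, m in zip(hidden, mask) if m]
    dim = len(hidden[0])
    return [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]

def l2_normalize(vec):
    """Scale a vector to unit Euclidean length (like F.normalize with p=2)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def score(query_vec, passage_vec):
    """Dot product of unit vectors (= cosine similarity), scaled by 100 as in the example."""
    return 100 * sum(q * p for q, p in zip(query_vec, passage_vec))

# Toy hidden states: the last query token is padding (mask 0) and must not affect the mean.
query_hidden = [[1.0, 0.0], [3.0, 0.0], [9.0, 9.0]]
query_mask = [1, 1, 0]
passage_hidden = [[0.0, 2.0], [2.0, 0.0]]
passage_mask = [1, 1]

q = l2_normalize(masked_mean(query_hidden, query_mask))      # mean [2, 0] -> [1, 0]
p = l2_normalize(masked_mean(passage_hidden, passage_mask))  # mean [1, 1] -> [0.7071, 0.7071]
print(round(score(q, p), 2))  # 70.71
```

If the padded token were not masked out, the query mean would be pulled towards [9, 9] and the score would change, which is why `average_pool` zeroes padded positions before summing.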

Training Details

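For details, please refer to our paper: https://arxiv.org/abs/2404.12096.pdf. Note that E5-Base-4k only extends the position embedding matrix to support 4,096 position IDs. The embedding vectors of the original position IDs {0, 1, 2, ..., 511} are reassigned to represent positions {0, 8, 16, ..., 4088}, while the embedding vectors of all other position IDs are learned during training. Therefore, for inputs of no more than 512 tokens, multiply the position IDs by 8 (= 4096 / 512) to preserve the original behavior, as get_position_ids does in the code above.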
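The position-ID rule can be checked without loading the model. The sketch below restates the logic of `get_position_ids` from the usage example for one sequence in plain Python, assuming its default arguments (512 original positions, 4,096 encode length); the helper name `e5_4k_position_ids` is hypothetical. Note that `get_position_ids` decides on the padded batch length `input_ids.size(1)`, so within one batch either all sequences get scaled IDs or none do.

```python
def e5_4k_position_ids(seq_len, max_original_positions=512, encode_max_length=4096):
    """Hypothetical pure-Python mirror of get_position_ids for a single sequence."""
    factor = max(encode_max_length // max_original_positions, 1)  # 4096 // 512 = 8
    if seq_len <= max_original_positions:
        # Short input: reuse the original 512 embeddings, now stored at 0, 8, 16, ..., 4088.
        return [pid * factor for pid in range(seq_len)]
    # Longer input: plain consecutive IDs, which also reach the newly trained embeddings.
    return list(range(seq_len))

short_ids = e5_4k_position_ids(512)
print(short_ids[:4], short_ids[-1])  # [0, 8, 16, 24] 4088
long_ids = e5_4k_position_ids(1000)
print(long_ids[:4], long_ids[-1])    # [0, 1, 2, 3] 999
```

Since the usage example tokenizes with max_length=512 and truncation=True, it always takes the scaled branch; feeding longer inputs would presumably require raising max_length (up to 4,096).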

Benchmark Evaluation

Check out unilm/e5 to reproduce the evaluation results on the BEIR and MTEB benchmarks.

Citation

If you find our paper or models helpful, please consider citing them as follows:

@article{zhu2024longembed,
  title={LongEmbed: Extending Embedding Models for Long Context Retrieval},
  author={Zhu, Dawei and Wang, Liang and Yang, Nan and Song, Yifan and Wu, Wenhao and Wei, Furu and Li, Sujian},
  journal={arXiv preprint arXiv:2404.12096},
  year={2024}
}