HuggingFace镜像/Llama-3.1-8B-Medusa-FP8
模型介绍文件和版本分析
下载使用量0

模型概述

描述:

NVIDIA Llama 3.1 8B Medusa FP8 模型是 Meta Llama 3.1 8B Instruct 模型的量化版和 Medusa 增强版,后者是一款采用优化 Transformer 架构的自回归语言模型。它是一款经过指令微调的生成式模型(文本输入/文本输出)。欲了解更多信息,请查看此处。

NVIDIA Llama 3.1 8B Medusa FP8 模型通过 Medusa 投机解码进行了增强,并使用 TensorRT Model Optimizer 进行了量化。

本模型可用于商业和非商业用途。

第三方社区考量:

本模型并非由 NVIDIA 拥有或开发。本模型是应第三方的特定应用和使用场景要求而开发构建的;请参见非 NVIDIA(Meta-Llama-3.1-8B-Instruct)模型卡片链接 (Meta-Llama-3.1-8B-Instruct) Model Card。

许可/使用条款:

管辖条款:使用本模型受 NVIDIA Open Models License 约束。补充信息:Llama 3.1 Community License Agreement。基于 Meta Llama 3.1 构建。

模型架构:

架构类型: Transformer
网络架构: Llama3.1

输入:

输入类型: 文本
输入格式: 字符串
输入参数: 1D;序列
与输入相关的其他属性: 上下文长度可达 128K

输出:

输出类型: 文本
输出格式: 字符串
输出参数: 1D;序列

软件集成

支持的运行时引擎:

  • Tensor(RT)-LLM

支持的硬件微架构兼容性:

  • NVIDIA Blackwell
  • NVIDIA Hopper
  • NVIDIA Lovelace

[首选/支持的] 操作系统:

  • Linux

模型版本:

v0.23.0

训练与评估数据集:

训练数据集:

链接:Daring-Anteater,用于数据合成,进而训练 Medusa 头。有关该数据集的更多信息,请参见此处。
数据集的数据收集方法

  • [自动化]
    数据集的标注方法
  • 合成
    属性:人工合成数据集,100K 行数据。

链接:cnn_dailymail,用于校准。有关该数据集的更多信息,请参见此处。
数据集的数据收集方法

  • 未知
    数据集的标注方法
  • 人工

评估数据集:

链接:MMLU,更多详情请参见此处
数据集的数据收集方法

  • [人工]
    数据集的标注方法
  • [人工]

Medusa 推测式解码与训练后量化

合成数据来源于 Meta-Llama-3.1-8B-Instruct 的 FP8 量化版本,用于微调 Medusa 头。本模型通过将 Meta-Llama-3.1-8B-Instruct 与 Medusa 头的权重和激活量化为 FP8 数据类型而获得,可在 Medusa 推测式解码模式下通过 TensorRT-LLM 进行推理。仅对 transformer 块和 Medusa 头内线性算子的权重和激活进行量化。此优化将每个参数的位数从 16 位减少到 8 位,磁盘大小和 GPU 内存需求降低约 50%。

Medusa 头用于预测下一个 token 之外的候选 token。在生成步骤中,每个 Medusa 头会基于前序 token 生成后续 token 的分布。然后,基于树的注意力机制会采样部分候选序列供原始模型验证。选择最长的被接受候选序列,以便在生成步骤中返回多个 token。每步生成的 token 数量称为接受率。

使用方法

若要通过 TensorRT-LLM(自 v0.17 版本起支持)运行推理,建议使用 LLM API,可参考 此示例,执行命令 python llm_medusa_decoding.py --use_modelopt_ckpt 或以下命令。LLM API 已将 checkpoint 转换、引擎构建和推理等步骤进行了抽象封装。

### Generate Text Using Medusa Decoding

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (LLM, BuildConfig,
                                 MedusaDecodingConfig, SamplingParams)
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode


def main():
    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # The end user can customize the sampling configuration with the SamplingParams class
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # The end user can customize the build configuration with the BuildConfig class
    build_config = BuildConfig(
        max_batch_size=1,
        max_seq_len=1024,
        max_draft_len=63,
        speculative_decoding_mode=SpeculativeDecodingMode.MEDUSA)

    # The end user can customize the medusa decoding configuration by specifying the
    # medusa heads num and medusa choices with the MedusaDecodingConfig class
    speculative_config = MedusaDecodingConfig(num_medusa_heads=3,
                            medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], \
                                [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], \
                                    [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], \
                                        [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], \
                                            [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]]
      )
    llm = LLM(model="nvidia/Llama-3.1-8B-Medusa-FP8",
              build_config=build_config,
              speculative_config=speculative_config)

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == '__main__':
    main()

或者,您可以参考 TensorRT-LLM GitHub 仓库中的 Medusa 解码示例 CLI。 TensorRT-LLM 基准测试 对 trtllm-bench 的支持即将推出。

评估

下表展示了准确率(MMLU,5-shot)和 Medusa 接受率的基准测试结果:

精度MMLUMT Bench 接受率
FP868.32.07

推理:

引擎: Tensor(RT)-LLM
测试硬件: H100

伦理考量

NVIDIA 认为可信 AI 是一项共同责任,我们已制定相关政策和实践,以支持广泛的 AI 应用开发。当开发者按照我们的服务条款下载或使用本模型时,应与内部模型团队合作,确保该模型满足相关行业和用例的要求,并应对未预见的产品误用问题。

请通过 此处 报告安全漏洞或 NVIDIA AI 相关问题。