NVIDIA Llama 3.1 8B Medusa FP8 模型是 Meta Llama 3.1 8B Instruct 模型的量化版和 Medusa 增强版,后者是一款采用优化 Transformer 架构的自回归语言模型。它是一款经过指令微调的生成式模型(文本输入/文本输出)。欲了解更多信息,请查看此处。
NVIDIA Llama 3.1 8B Medusa FP8 模型通过 Medusa 投机解码进行了增强,并使用 TensorRT Model Optimizer 进行了量化。
本模型可用于商业和非商业用途。
本模型并非由 NVIDIA 拥有或开发。本模型是应第三方的特定应用和使用场景要求而开发构建的;请参见非 NVIDIA(Meta-Llama-3.1-8B-Instruct)模型卡片链接 (Meta-Llama-3.1-8B-Instruct) Model Card。
管辖条款:使用本模型受 NVIDIA Open Models License 约束。补充信息:Llama 3.1 Community License Agreement。基于 Meta Llama 3.1 构建。
架构类型: Transformer
网络架构: Llama3.1
输入类型: 文本
输入格式: 字符串
输入参数: 1D;序列
与输入相关的其他属性: 上下文长度可达 128K
输出类型: 文本
输出格式: 字符串
输出参数: 1D;序列
支持的运行时引擎:
支持的硬件微架构兼容性:
[首选/支持的] 操作系统:
v0.23.0
链接:Daring-Anteater,用于数据合成,进而训练 Medusa 头。有关该数据集的更多信息,请参见此处。
数据集的数据收集方法
链接:cnn_dailymail,用于校准。有关该数据集的更多信息,请参见此处。
数据集的数据收集方法
链接:MMLU,更多详情请参见此处
数据集的数据收集方法
合成数据来源于 Meta-Llama-3.1-8B-Instruct 的 FP8 量化版本,用于微调 Medusa 头。本模型通过将 Meta-Llama-3.1-8B-Instruct 与 Medusa 头的权重和激活量化为 FP8 数据类型而获得,可在 Medusa 推测式解码模式下通过 TensorRT-LLM 进行推理。仅对 transformer 块和 Medusa 头内线性算子的权重和激活进行量化。此优化将每个参数的位数从 16 位减少到 8 位,磁盘大小和 GPU 内存需求降低约 50%。
Medusa 头用于预测下一个 token 之外的候选 token。在生成步骤中,每个 Medusa 头会基于前序 token 生成后续 token 的分布。然后,基于树的注意力机制会采样部分候选序列供原始模型验证。选择最长的被接受候选序列,以便在生成步骤中返回多个 token。每步生成的 token 数量称为接受率。
若要通过 TensorRT-LLM(自 v0.17 版本起支持)运行推理,建议使用 LLM API,可参考 此示例,执行命令 python llm_medusa_decoding.py --use_modelopt_ckpt 或以下命令。LLM API 已将 checkpoint 转换、引擎构建和推理等步骤进行了抽象封装。
### Generate Text Using Medusa Decoding
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (LLM, BuildConfig,
MedusaDecodingConfig, SamplingParams)
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
def main():
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# The end user can customize the sampling configuration with the SamplingParams class
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# The end user can customize the build configuration with the BuildConfig class
build_config = BuildConfig(
max_batch_size=1,
max_seq_len=1024,
max_draft_len=63,
speculative_decoding_mode=SpeculativeDecodingMode.MEDUSA)
# The end user can customize the medusa decoding configuration by specifying the
# medusa heads num and medusa choices with the MedusaDecodingConfig class
speculative_config = MedusaDecodingConfig(num_medusa_heads=3,
medusa_choices=[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], \
[4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], \
[7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], \
[4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], \
[0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [1, 6], [0, 7, 0]]
)
llm = LLM(model="nvidia/Llama-3.1-8B-Medusa-FP8",
build_config=build_config,
speculative_config=speculative_config)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == '__main__':
main()
或者,您可以参考 TensorRT-LLM GitHub 仓库中的 Medusa 解码示例 CLI。
TensorRT-LLM 基准测试 对 trtllm-bench 的支持即将推出。
下表展示了准确率(MMLU,5-shot)和 Medusa 接受率的基准测试结果:
引擎: Tensor(RT)-LLM
测试硬件: H100
NVIDIA 认为可信 AI 是一项共同责任,我们已制定相关政策和实践,以支持广泛的 AI 应用开发。当开发者按照我们的服务条款下载或使用本模型时,应与内部模型团队合作,确保该模型满足相关行业和用例的要求,并应对未预见的产品误用问题。
请通过 此处 报告安全漏洞或 NVIDIA AI 相关问题。