E
Eco-Tech/Step-3.5-Flash-w8a8-mtp
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

Step-3.5-Flash-w8a8-mtp

1. 基本信息

项目信息
原始模型名Step-3.5-Flash
原始模型链接stepfun-ai/Step-3.5-Flash
msmodelslim commit id758efe6c091ad3dcc40eb3e7b22d46c3698d3d46
精度测试机型Atlas 800I A2 1台
精度测试平台docker vllm-ascend
版本vllm-ascend:v0.17.0rc1
链接quay.io/ascend/vllm-ascend:v0.17.0rc1

2 量化步骤:

2.1 量化脚本及配置文件

量化脚本:

msmodelslim quant --trust_remote_code True\
    --model_path {浮点权重路径} \
    --save_path {W8A8量化权重路径} \
    --device cpu \
    --model_type Step-3.5-Flash \
    --config_path {yaml配置文件路径}

配置文件:

# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
apiversion: modelslim_v1
metadata:
  config_id: Step-3.5-Flash_w8a8_mxfp8
  score: 90
  verified_model_types:
    - Step-3.5-Flash
  label:
    w_bit: 8
    a_bit: 8
    is_sparse: False
    kv_cache: False

default_w8a8_dynamic: &default_w8a8_dynamic
  act:
    scope: "per_token"
    dtype: "int8"
    symmetric: True
    method: "minmax"
  weight:
    scope: "per_channel"
    dtype: "int8"
    symmetric: True
    method: "minmax"

spec:
  process:
    - type: "linear_quant"
      qconfig: *default_w8a8_dynamic
      include: 
        - "*moe.experts*"
  save:
    - type: "ascendv1_saver"
      part_file_size: 4
  dataset: "data_list_1.jsonl"  # Short name: auto-searches in lab_calib/
  default_text: "Describe this image in detail."

2.2 复制mtp层权重并修改description文件

copy_mtp_layers.py:

import os
import json
import shutil
from pathlib import Path
from safetensors.torch import load_file, save_file
from collections import defaultdict

# ================= 配置区域 =================
# 原始 BF16 权重目录
ORIG_WEIGHT_DIR = Path("/path/to/Step-3.5-Flash/")
# 原始权重索引文件
ORIG_INDEX_FILE = ORIG_WEIGHT_DIR / "model.safetensors.index.json"

# 量化后 W8A8 权重目录
QUANT_WEIGHT_DIR = Path("/path/to/Step-3.5-Flash-w8a8/")
# 量化权重索引文件
QUANT_INDEX_FILE = QUANT_WEIGHT_DIR / "quant_model_weights.safetensors.index.json"

# 输出目录 (新目录,不会覆盖原量化权重)
OUTPUT_PATH = Path("/path/to/Step-3.5-Flash-w8a8-with-mtp/")

# 需要保留为 BF16 的 MTP 层号
MTP_LAYERS = ["45", "46", "47"]
# ===========================================

def main():
    print(f"--- 开始处理 ---")
    print(f"原始权重目录:{ORIG_WEIGHT_DIR}")
    print(f"量化权重目录:{QUANT_WEIGHT_DIR}")
    print(f"输出目录:{OUTPUT_PATH}")

    # 1. 创建输出目录并复制量化权重文件
    if OUTPUT_PATH.exists():
        print(f"警告:输出目录 {OUTPUT_PATH} 已存在,将覆盖其中的内容。")
        shutil.rmtree(OUTPUT_PATH)
    
    OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
    print(f"正在复制量化权重文件到新目录...")
    # 复制所有文件
    for item in QUANT_WEIGHT_DIR.iterdir():
        if item.is_file():
            shutil.copy2(item, OUTPUT_PATH / item.name)
    print(f"量化文件复制完成。")

    # 2. 解析原始权重索引,找到 MTP 层所在的文件
    if not ORIG_INDEX_FILE.exists():
        raise FileNotFoundError(f"未找到原始权重索引文件:{ORIG_INDEX_FILE}")
    
    with open(ORIG_INDEX_FILE, 'r', encoding='utf-8') as f:
        orig_index = json.load(f)
    
    weight_map = orig_index.get("weight_map", {})
    
    # 收集 MTP 层的权重键值对,并按源文件分组
    # 结构:{ "model-00002.safetensors": ["model.layers.45...", "model.layers.46..."] }
    mtp_keys_by_file = defaultdict(list)
    mtp_all_keys = []
    
    prefix_patterns = [f"model.layers.{i}." for i in MTP_LAYERS]
    
    for key, filename in weight_map.items():
        is_mtp = False
        for prefix in prefix_patterns:
            if key.startswith(prefix):
                is_mtp = True
                break
        
        if is_mtp:
            mtp_keys_by_file[filename].append(key)
            mtp_all_keys.append(key)
            
    if not mtp_all_keys:
        print("错误:未在原始权重中找到指定的 MTP 层 (45, 46, 47)。请检查层号或索引文件。")
        return

    print(f"找到 {len(mtp_all_keys)} 个 MTP 层权重参数。")
    print(f"分布在以下文件中:{list(mtp_keys_by_file.keys())}")

    # 3. 从原始文件中提取权重并合并
    mtp_tensors = {}
    
    # 原始权重目录下的文件路径
    orig_safetensors_dir = ORIG_WEIGHT_DIR 
    
    for src_filename, keys in mtp_keys_by_file.items():
        src_path = orig_safetensors_dir / src_filename
        if not src_path.exists():
            raise FileNotFoundError(f"原始权重文件不存在:{src_path}")
        
        print(f"正在从 {src_filename} 加载 MTP 权重...")
        tensors = load_file(src_path)
        
        for key in keys:
            if key in tensors:
                mtp_tensors[key] = tensors[key]
            else:
                print(f"警告:在文件 {src_filename} 中未找到键 {key}")
    
    # 4. 保存 MTP 权重到新文件
    mtp_output_filename = "mtp_layers.safetensors"
    mtp_output_path = OUTPUT_PATH / mtp_output_filename
    
    print(f"正在保存 MTP 权重到 {mtp_output_filename} (BF16 格式)...")
    save_file(mtp_tensors, mtp_output_path)
    print(f"MTP 权重保存完成。")

    # 5. 更新新目录下的量化索引文件
    new_index_path = OUTPUT_PATH / "quant_model_weights.safetensors.index.json"
    
    with open(new_index_path, 'r', encoding='utf-8') as f:
        new_index = json.load(f)
    
    # 更新 weight_map
    current_weight_map = new_index.get("weight_map", {})
    for key in mtp_all_keys:
        current_weight_map[key] = mtp_output_filename
    
    new_index["weight_map"] = current_weight_map
    
    # 更新 metadata 中的 total_size 
    if "metadata" in new_index and "total_size" in new_index["metadata"]:
        # 由于我们添加了 BF16 权重,总大小会增加。
        # 简单计算新增大小
        added_size = sum(t.element_size() * t.nelement() for t in mtp_tensors.values())
        new_index["metadata"]["total_size"] += added_size
        
    with open(new_index_path, 'w', encoding='utf-8') as f:
        json.dump(new_index, f, indent=2)
        
    print(f"索引文件 {new_index_path.name} 已更新。")

    desc_path = OUTPUT_PATH / "quant_model_description.json"
    if desc_path.exists():
        try:
            with open(desc_path, 'r', encoding='utf-8') as f:
                desc_data = json.load(f)
            
            # 尝试常见字段名,如果存在文件列表则追加
            updated = False
            if isinstance(desc_data, dict):
                for key in ["weight_files", "files", "safetensors_files"]:
                    if key in desc_data and isinstance(desc_data[key], list):
                        if mtp_output_filename not in desc_data[key]:
                            desc_data[key].append(mtp_output_filename)
                            updated = True
                            print(f"已更新描述文件中的 {key} 列表。")
                
            if updated:
                with open(desc_path, 'w', encoding='utf-8') as f:
                    json.dump(desc_data, f, indent=2)
        except Exception as e:
            print(f"警告:无法自动更新 quant_model_description.json ({e})。请手动检查该文件是否需要添加 {mtp_output_filename}。")

    print("\n=== 处理完成 ===")
    print(f"新权重目录:{OUTPUT_PATH}")
    print(f"包含:原有的 W8A8 权重 + 新增的 {mtp_output_filename} (BF16 MTP 层)")

if __name__ == "__main__":
    main()

update_description.py:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
仅更新 quant_model_description.json,添加 MTP 层权重条目
无需复制原始权重文件
"""

import json
from pathlib import Path

# ================= 配置区域 =================
# 原始权重索引文件(用于获取 MTP 层的键名)
ORIG_INDEX_FILE = Path("/path/to/Step-3.5-Flash/model.safetensors.index.json")

# 合并后的量化描述文件
DESC_FILE = Path("/path/to/Step-3.5-Flash-w8a8-with-mtp/quant_model_description.json")

# MTP 层号
MTP_LAYERS = ["45", "46", "47"]
# ===========================================

def main():
    print(f"--- 更新 quant_model_description.json ---")
    
    # 1. 从原始索引中获取 MTP 层的权重键名
    if not ORIG_INDEX_FILE.exists():
        raise FileNotFoundError(f"未找到原始权重索引文件:{ORIG_INDEX_FILE}")
    
    with open(ORIG_INDEX_FILE, 'r', encoding='utf-8') as f:
        orig_index = json.load(f)
    
    weight_map = orig_index.get("weight_map", {})
    mtp_keys = []
    prefix_patterns = [f"model.layers.{i}." for i in MTP_LAYERS]
    
    for key in weight_map.keys():
        for prefix in prefix_patterns:
            if key.startswith(prefix):
                mtp_keys.append(key)
                break
    
    if not mtp_keys:
        print("错误:未找到 MTP 层权重键名")
        return
    
    print(f"找到 {len(mtp_keys)} 个 MTP 层权重键名")
    
    # 2. 读取并更新 quant_model_description.json
    if not DESC_FILE.exists():
        raise FileNotFoundError(f"未找到描述文件:{DESC_FILE}")
    
    with open(DESC_FILE, 'r', encoding='utf-8') as f:
        desc_data = json.load(f)
    
    # 3. 添加 MTP 层权重,标记为 FLOAT
    updated_count = 0
    for key in mtp_keys:
        if key not in desc_data:
            desc_data[key] = "FLOAT"
            updated_count += 1
        elif desc_data[key] != "FLOAT":
            print(f"警告:{key} 已存在,类型为 {desc_data[key]},将覆盖为 FLOAT")
            desc_data[key] = "FLOAT"
            updated_count += 1
    
    # 4. 保存更新后的文件
    with open(DESC_FILE, 'w', encoding='utf-8') as f:
        json.dump(desc_data, f, indent=2, ensure_ascii=False)
    
    print(f"\n=== 更新完成 ===")
    print(f"文件:{DESC_FILE}")
    print(f"新增/修改条目数:{updated_count}")
    print(f"\n示例(前 5 个 MTP 权重):")
    for key in mtp_keys[:5]:
        print(f"  \"{key}\": \"{desc_data[key]}\"")

if __name__ == "__main__":
    main()

3 精度测试结果

模型名量化格式数据集测试精度 %官方精度 %
Step-3.5-Flash-w8a8-mtpw8a8aime2597.597.3
Step-3.5-Flash-w8a8-mtpw8a8aime2695.8396.7

表中数据为断图mtp下,pass@4的结果