Ascend-SACT/Rex-Omni
模型介绍文件和版本Pull Requests讨论分析
下载使用量0

Rex-Omni部署指导

引言

Rex-Omni是粤港澳大湾区数字经济研究院(IDEA)计算机视觉与机器人研究中心于2025年10月发布的一款3B参数多模态大语言模型,专注于解决现有MLLM在目标检测中"语言理解强但空间定位弱"的痛点。它通过统一的下一个点预测框架,在多项视觉感知任务上首次实现了对传统回归模型(如Grounding DINO)的超越。本文记录了该模型的开箱适配过程。

一、运行环境准备

1、版本配套表

配套版本环境准备指导
Python3.11.14-
torch2.9.0-
torch_npu2.9.0-
vLLM- Ascend0.14.0rc1-
CANN8.5.0-

2、环境准备

整机:Atlas 800T A2

NPU:910B昇腾

部署方式:单卡部署

操作系统:openEuler 22.03 (LTS-SP2), ARM

二、推理部署


1、镜像下载

docker pull quay.io/ascend/vllm-ascend:v0.14.0rc1

2、权重和软件下载

mkdir -p /opt/data/models/IDEA-Research/Rex-Omni
modelscope download --model IDEA-Research/Rex-Omni --local_dir /opt/data/models/IDEA-Research/Rex-Omni

mkdir -p /home/workDir/00_Software
cd /home/workDir/00_Software
git clone https://github.com/IDEA-Research/Rex-Omni.git

3、启动容器

export IMAGE=quay.io/ascend/vllm-ascend:v0.14.0rc1

docker run --rm \
--name vllm-rex-omni-npu-v14 \
--privileged \
--net=host \
--shm-size=1g \
--device /dev/davinci7 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
-v /etc/ascend_install.info:/etc/ascend_install.info \
-v /root/.cache:/root/.cache \
-v /opt/data/models:/opt/data/models \
-v /home/workDir:/home/workDir \
-it $IMAGE bash

4、部署与验证

4.1 安装依赖包

4.1.1 安装依赖包

使用的镜像已预装相应版本的vllm、transformer等,因此部分依赖无需额外安装,请按以下方式修改:注释掉/home/workDir/00_Software/Rex-Omni/requirements.txt中的如下几行。

安装软件包:

cd /home/workDir/00_Software/Rex-Omni
pip install -r requirements.txt

pip install -v -e .

4.2、离线部署

保存下面脚本到validate_model.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Rex-Omni Model Validation Script
Focus on verifying model correctness with vLLM backend
"""

import os
import logging
from PIL import Image
from typing import Dict, Any

import torch
from vllm import LLM, SamplingParams
from transformers import AutoProcessor
from qwen_vl_utils import process_vision_info

from rex_omni.parser import parse_prediction
from rex_omni import RexOmniVisualize

os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


class RexOmniValidator:
    """Validator for Rex-Omni model"""

    def __init__(
        self,
        model_path: str,
        max_model_len: int = 4096,
        gpu_memory_utilization: float = 0.8,
    ):
        self.model_path = model_path
        self.max_model_len = max_model_len
        self.gpu_memory_utilization = gpu_memory_utilization
        self.model = None
        self.processor = None
        self.sampling_params = None

    def initialize(self) -> bool:
        """Initialize model and processor"""
        logger.info(f"Initializing model from: {self.model_path}")

        try:
            # 初始化vLLM模型,使用半精度浮点数以节省内存
            self.model = LLM(
                model=self.model_path,
                tokenizer=self.model_path,
                trust_remote_code=True,
                max_model_len=self.max_model_len,
                gpu_memory_utilization=self.gpu_memory_utilization,
                tensor_parallel_size=1,
                enforce_eager=True,
                dtype=torch.float16,
                # 重要:根据Rex-Omni提供的模型权重文件列表,这里需要配置tokenizer_mode为slow,否则会造成推理异常。
                tokenizer_mode="slow",
            )
            logger.info("vLLM model loaded successfully")

            # 初始化处理器,设置图像像素范围
            self.processor = AutoProcessor.from_pretrained(
                self.model_path,
                min_pixels=16 * 28 * 28,
                max_pixels=256 * 28 * 28,
                trust_remote_code=True
            )
            self.processor.tokenizer.padding_side = "left"
            logger.info("Processor loaded successfully")

            # 配置采样参数,使用确定性生成
            self.sampling_params = SamplingParams(
                max_tokens=2048,
                temperature=0.0,  # 0表示确定性输出
                top_p=0.05,
                top_k=1,
                repetition_penalty=1.05,
                skip_special_tokens=False,
            )

            return True

        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            return False

    def run_inference(
        self,
        image_path: str,
        task: str = "detection",
        categories: str = "person"
    ) -> Dict[str, Any]:
        """Run inference on a single image"""
        logger.info(f"Loading image: {image_path}")

        # 加载并预处理图像
        image = Image.open(image_path).convert("RGB")
        w, h = image.size
        logger.info(f"Image size: {w}x{h}")

        # 构建提示词,指定检测类别和输出格式
        prompt = f"Detect {categories}. Output the bounding box coordinates in [x0, y0, x1, y1] format."

        # 构建多模态消息格式
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                        "min_pixels": 16 * 28 * 28,
                        "max_pixels": 256 * 28 * 28,
                    },
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        # 应用聊天模板
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        logger.info(f"Prompt generated, length: {len(text)} chars")

        # 处理视觉信息
        image_inputs, _ = process_vision_info(messages)

        # 构建LLM输入
        llm_inputs = {
            "prompt": text,
            "multi_modal_data": {"image": image_inputs}
        }

        # 执行推理
        logger.info("Starting inference...")
        outputs = self.model.generate([llm_inputs], sampling_params=self.sampling_params)

        # 获取生成结果
        generated_text = outputs[0].outputs[0].text
        logger.info(f"Generation completed, output length: {len(generated_text)} chars")

        # 解析预测结果
        extracted_predictions = parse_prediction(
            text=generated_text,
            w=w,
            h=h,
            task_type=task
        )
        logger.info(f"Parsed predictions: {extracted_predictions}")

        return {
            "image_size": (w, h),
            "prompt": prompt,
            "raw_output": generated_text,
            "predictions": extracted_predictions,
            "image": image
        }

    def visualize_results(
        self,
        result: Dict[str, Any],
        output_path: str
    ) -> bool:
        """Visualize and save detection results"""
        try:
            # 初始化可视化工具并保存结果
            vis = RexOmniVisualize(
                image=result["image"],
                predictions=result["predictions"],
                font_size=20,
                draw_width=5,
                show_labels=True,
            )
            vis.save(output_path)
            logger.info(f"Visualization saved to: {output_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to visualize: {e}")
            return False


def main():
    """Main validation entry point"""
    # 模型和测试图像路径
    model_path = "/opt/data/models/IDEA-Research/Rex-Omni"
    image_path = "/home/workDir/00_Software/Rex-Omni/tutorials/detection_example/test_images/boys.jpg"
    output_path = os.path.join(os.path.dirname(__file__), "validation_output.jpg")

    logger.info("Starting Rex-Omni model validation")

    # 初始化验证器并加载模型
    validator = RexOmniValidator(model_path=model_path)

    if not validator.initialize():
        logger.error("Model initialization failed")
        return

    # 运行推理
    result = validator.run_inference(
        image_path=image_path,
        task="detection",
        categories="person"
    )

    # 打印原始输出
    logger.info(f"Raw output: {result['raw_output'][:500]}...")

    # 可视化结果
    validator.visualize_results(result, output_path)

    logger.info("Validation completed successfully")


if __name__ == "__main__":
    main()

然后执行

python validate_model.py

输出如下: 会输出一个validation_output.jpg的图片,图片内容如下:

4.3 服务化部署

export ASCEND_RT_VISIBLE_DEVICES=2
# 然后直接运行你的 vllm serve 命令
vllm serve /opt/data/models/IDEA-Research/Rex-Omni \
    --max-model-len 40960 \
    --gpu-memory-utilization 0.8 \
    --dtype float16 \
    --tokenizer-mode slow \
    --trust-remote-code \
    --host 0.0.0.0 \
    --limit-mm-per-prompt '{"image": 10}' \
    --port 8000
4.3.1 文本测试:
curl http://localhost:8000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/opt/data/models/IDEA-Research/Rex-Omni",
        "messages": [{"role": "user", "content": "请介绍一下成都"}],
        "max_tokens": 1000,
        "temperature": 0.0,
        "skip_special_tokens": false
    }'        

4.3.2 图片测试:
import base64
import requests
import os
import logging
from PIL import Image
from typing import Dict, Any

from rex_omni.parser import parse_prediction
from rex_omni import RexOmniVisualize

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def run_inference(image_path: str) -> Dict[str, Any]:
    """Run inference via vLLM server"""
    logger.info(f"Loading image: {image_path}")

    # 读取并编码图片为base64格式,用于HTTP传输
    with open(image_path, "rb") as f:
        img = base64.b64encode(f.read()).decode()

    # 加载图片获取尺寸,用于后续解析预测结果
    image = Image.open(image_path).convert("RGB")
    w, h = image.size
    logger.info(f"Image size: {w}x{h}")

    # 发送请求到 vLLM 服务器
    logger.info("Sending request to vLLM server...")
    res = requests.post("http://127.0.0.1:8000/v1/chat/completions", json={
        "model": "/opt/data/models/IDEA-Research/Rex-Omni",
        "messages": [{
            "role": "user",
            "content": [
                # 以base64格式发送图像
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img}"}},
                # 发送检测任务的提示词
                {"type": "text", "text": "Detect person. Output the bounding box coordinates in [x0, y0, x1, y1] format."}
            ]
        }],
        "temperature": 0,  # 确定性输出
        "max_tokens": 40000,  # 最大生成token数
        # 重要:这里必须设置skip_special_tokens为False,否则在解析输出时会跳过特殊字符对应的Token,导致处理异常。
        "skip_special_tokens": False
    })

    # 解析服务器响应
    response = res.json()
    logger.info("Received response from server")

    # 提取生成的文本
    if "choices" in response and response["choices"]:
        generated_text = response["choices"][0]["message"]["content"]
        logger.info(f"Generation completed, output length: {len(generated_text)} chars")

        # 解析预测结果,将模型输出转换为标准格式
        extracted_predictions = parse_prediction(
            text=generated_text,
            w=w,  # 图像宽度
            h=h,  # 图像高度
            task_type="detection"  # 任务类型
        )
        logger.info(f"Parsed predictions: {extracted_predictions}")

        return {
            "image_size": (w, h),
            "raw_output": generated_text,
            "predictions": extracted_predictions,
            "image": image
        }
    else:
        logger.error(f"Failed to get response: {response}")
        return None

def visualize_results(result: Dict[str, Any], output_path: str) -> bool:
    """Visualize and save detection results"""
    try:
        # 初始化可视化工具并保存结果
        vis = RexOmniVisualize(
            image=result["image"],
            predictions=result["predictions"],
            font_size=20,
            draw_width=5,
            show_labels=True,
        )
        vis.save(output_path)
        logger.info(f"Visualization saved to: {output_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to visualize: {e}")
        return False

def main():
    """Main entry point"""
    # 设置测试图像和输出路径
    image_path = "/home/workDir/00_Software/Rex-Omni/tutorials/detection_example/test_images/boys.jpg"
    output_path = os.path.join(os.path.dirname(__file__), "vllm_server_output.jpg")

    logger.info("Starting Rex-Omni model validation via vLLM server")

    # 运行推理
    result = run_inference(image_path)

    if result:
        # 打印原始输出(前500字符)
        logger.info(f"Raw output: {result['raw_output'][:500]}...")

        # 可视化结果
        visualize_results(result, output_path)

        logger.info("Validation completed successfully")
    else:
        logger.error("Inference failed")

if __name__ == "__main__":
    main()

运行后,会输出一张处理后的照片,照片处理的结果和离线处理的一样