HuggingFace镜像/codeparrot-small-openmind
模型介绍文件和版本分析

CodeParrot 🦜 (small)

CodeParrot 🦜 是一个 GPT-2 模型(1.1 亿参数),专门训练用于生成 Python 代码。

如何在 openmind 中使用

from openmind import AutoTokenizer, AutoModelForCausalLM, is_torch_npu_available
from openmind_hub import snapshot_download
import torch.nn.functional as F
from torch import Tensor
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to model",
        default="jeffding/codeparrot-small-openmind",
    )
    args = parser.parse_args()
    return args

def main():
    args = parse_args()
    model_path = args.model_name_or_path

    if is_torch_npu_available():
        device = "npu:0"
    else:
        device = "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model = model.to(device)
    prompt = "def hello_world():"
    inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(device)
    out = model.generate(**inputs, max_new_tokens=80).ravel()
    out = tokenizer.decode(out)
    print(out)
    
if __name__ == "__main__":
    main()

使用方法

您可以直接在 transformers 中加载 CodeParrot 模型和分词器:

from transformers import AutoTokenizer, AutoModelWithLMHead
  
tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot-small")
model = AutoModelWithLMHead.from_pretrained("codeparrot/codeparrot-small")

inputs = tokenizer("def hello_world():", return_tensors="pt")
outputs = model(**inputs)

或使用 pipeline:

from transformers import pipeline

pipe = pipeline("text-generation", model="codeparrot/codeparrot-small")
outputs = pipe("def hello_world():")

训练

该模型在经过清洗的 CodeParrot 🦜 数据集 上进行训练,训练设置如下:

配置值
批处理大小192
上下文长度1024
训练步数150,000
梯度累积1
梯度检查点False
学习率5e-4
权重衰减0.1
预热步数2000
学习率调度Cosine

训练在 16 块 A100 (40GB) GPU 上进行。此设置下,模型大约处理了 290 亿个 tokens。

性能

我们在 OpenAI 的 HumanEval 基准测试集上对模型进行了评估,该基准包含一系列编程挑战:

指标值
pass@13.80%
pass@106.57%
pass@10012.78%

pass@k 指标 表示在 k 次生成结果中至少有一个通过测试的概率。

资源

  • 数据集:完整、训练集、验证集
  • 代码:仓库
  • Spaces:生成、高亮
下载使用量0