
GLM 10B CN inference acceleration latency #479

Open · vicwer opened this issue Mar 19, 2023 · 1 comment
vicwer commented Mar 19, 2023

Hello, I am using libai to accelerate glm-10b-chinese inference. Current observation: libai inference on 2 GPUs takes twice as long as huggingface on a single GPU (0.6s vs 0.3s). Please help analyze the cause, thanks.
libai inference code, launched with: python3 -m oneflow.distributed.launch --nproc_per_node 2 demo.py

# model parallel + pipeline parallel demo

import oneflow as flow
from projects.GLM.tokenizer.glm_tokenizer import GLMChineseTokenzier
from libai.utils import distributed as dist
from projects.GLM.configs.glm_inference import cfg
from projects.GLM.modeling_glm import GLMForConditionalGeneration
from projects.GLM.utils.glm_loader import GLMLoaderHuggerFace
from omegaconf import DictConfig
import time

# Only a simple parallelism configuration is needed
# parallel_config = DictConfig(
#     dict(
#         data_parallel_size=1,
#         tensor_parallel_size=2,
#         pipeline_parallel_size=2,
#         pipeline_num_layers=2 * 24
#     )
# )
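# Note: the commented-out 2x2 scheme above (tensor_parallel_size=2,
# pipeline_parallel_size=2) would need 1 x 2 x 2 = 4 devices in total,
# while this run is launched with --nproc_per_node 2, so the 2-GPU
# tensor-parallel-only config below is used instead.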
parallel_config = DictConfig(
    dict(
        data_parallel_size=1,
        tensor_parallel_size=2,
        pipeline_parallel_size=1   
    )
)
dist.setup_dist_util(parallel_config)

tokenizer = GLMChineseTokenzier.from_pretrained("./models/glm/glm_10b_cn")
input_ids = tokenizer.encode(
    [
        "橘子的颜色是[MASK]。"
    ],
    return_tensors="of",
)
inputs = {"input_ids": input_ids, "attention_mask": flow.ones(input_ids.size())}
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)

sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = dist.get_layer_placement(0)

loader = GLMLoaderHuggerFace(
    GLMForConditionalGeneration, 
    cfg, 
    "./models/glm/glm_10b_cn",
    embedding_dropout_prob=0,
    attention_dropout_prob=0,
    output_dropout_prob=0,
)
model = loader.load()

while True:
    t0 = time.time()
    outputs = model.generate(
        inputs=inputs['input_ids'].to_global(sbp=sbp, placement=placement), 
        position_ids=inputs['position_ids'].to_global(sbp=sbp, placement=placement), 
        generation_attention_mask=inputs['generation_attention_mask'].to_global(sbp=sbp, placement=placement), 
        max_length=512
    )
    flow.cuda.synchronize()  # fix: torch is not imported in this script; use OneFlow's CUDA synchronize
    if dist.is_main_process():
        print("cost time", time.time() - t0)

    res = tokenizer.decode(outputs[0])
    if dist.is_main_process():
        print(res)
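Side note on measurement: the loop above times every iteration, including the first, which can carry one-off costs (moving weights to device, kernel/graph warm-up). A steady-state averaging sketch (my own addition, reusing the variables defined in the script above):

# Warm-up then average over several iterations (a sketch, not from the
# original post; run_once simply wraps the generate call above).
def run_once():
    return model.generate(
        inputs=inputs["input_ids"].to_global(sbp=sbp, placement=placement),
        position_ids=inputs["position_ids"].to_global(sbp=sbp, placement=placement),
        generation_attention_mask=inputs["generation_attention_mask"].to_global(sbp=sbp, placement=placement),
        max_length=512,
    )

for _ in range(3):  # warm-up iterations, excluded from timing
    run_once()

flow.cuda.synchronize()
t0 = time.time()
n_iters = 10
for _ in range(n_iters):
    outputs = run_once()
flow.cuda.synchronize()
if dist.is_main_process():
    print("avg cost time", (time.time() - t0) / n_iters)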

huggingface inference code:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import time
import torch
ckpt_path = "./models/glm/glm_10b_cn"
#tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
#model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt_path, trust_remote_code=True)
model = model.half().cuda()
model.eval()

# Inference
#inputs = tokenizer("Ng is an adjunct professor at [MASK] (formerly associate professor and Director of its Stanford AI Lab or SAIL ). Also a pioneer in online education, Ng co-founded Coursera and deeplearning.ai.", return_tensors="pt")

while True:
    t0 = time.time()
    inputs = tokenizer("橘子的颜色是[MASK]。", return_tensors="pt")
    inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
    inputs = inputs.to('cuda')
    outputs = model.generate(**inputs, max_length=512, eos_token_id=tokenizer.eop_token_id)
    torch.cuda.synchronize()
    print("cost time", time.time() - t0)
    print(tokenizer.decode(outputs[0].tolist()))
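One thing to check when comparing the two scripts: the huggingface call passes eos_token_id=tokenizer.eop_token_id while the libai call does not set an explicit eos token, so the two runs may generate different numbers of tokens for the same prompt. A quick sanity check (a sketch; place after the generate call in either script):

# Assumption: a fair latency comparison requires both runs to emit roughly
# the same number of tokens; otherwise per-token latency is the better metric.
n_new_tokens = outputs[0].shape[-1] - inputs["input_ids"].shape[-1]
print("generated tokens:", n_new_tokens)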

Inference hardware:

GPU:
2 x A100 80GB
CPU:
processor       : 27
cpu family      : 6
model           : 106
model name      : Intel(R) Xeon(R) Platinum 8336C CPU @ 2.30GHz
cpu MHz         : 2294.608
cache size      : 55296 KB

Relevant libraries:

LiBai                  0.2.0      /root/code/env/libai
omegaconf              2.1.0
oneflow                0.9.0
torch                  1.12.1
torchaudio             0.12.1
torchvision            0.13.1
CUDA Driver Version: 470.129.06, cu113
xiezipeng-ML commented Mar 29, 2023

You could first try the single-GPU inference demo from this PR: #484 (comment).
In my tests it is currently faster than torch.
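For reference, a single-GPU parallel config would look something like the sketch below (my reading of the config in the original post; the PR demo is the authoritative version):

from omegaconf import DictConfig
from libai.utils import distributed as dist

# All parallel degrees set to 1: the model runs on one device.
parallel_config = DictConfig(
    dict(
        data_parallel_size=1,
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
    )
)
dist.setup_dist_util(parallel_config)
# Then run directly, without oneflow.distributed.launch:
#   python3 demo.py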