PyTorch?模型?onnx?文件導(dǎo)出及調(diào)用詳情

目錄

前言

Open Neural Network Exchange (ONNX,開放神經(jīng)網(wǎng)絡(luò)交換) 格式,是一個用于表示深度學(xué)習(xí)模型得標(biāo)準(zhǔn),可使模型在不同框架之間進(jìn)行轉(zhuǎn)移

PyTorch 所定義得模型為動態(tài)圖,其前向傳播是由類方法定義和實(shí)現(xiàn)得

但是 Python 代碼得效率是比較底下得,試想把動態(tài)圖轉(zhuǎn)化為靜態(tài)圖,模型得推理速度應(yīng)當(dāng)有所提升

PyTorch 框架中,torch.onnx.export 可以將父類為 nn.Module 得模型導(dǎo)出到 onnx 文件中,

最重要得有三個參數(shù):

  • model:父類為 nn.Module 得模型
  • args:傳入 model 得 forward 方法得變量列表,類型應(yīng)為
  • tuplef:onnx 文件名稱得字符串
import torchfrom torchvision.models import resnet50 file = 'resnet.onnx'# 聲明模型resnet = resnet50(pretrained=False).eval()image = torch.rand([1, 3, 224, 224])# 導(dǎo)出為 onnx 文件torch.onnx.export(resnet, (image,), file)

onnx 文件可被 Netron 打開,以查看模型結(jié)構(gòu)

基本用法

要在 Python 中運(yùn)行 onnx 模型,需要下載 onnxruntime

# 選其一即可pip install onnxruntime        # CPU 版本pip install onnxruntime-gpu    # GPU 版本

推理時需要借助其中得 InferenceSession,其中較為重要得實(shí)例方法有:

  • get_inputs():得到輸入變量得列表 (變量屬性:name、shape、type)
  • get_outputs():得到輸入變量得列表 (變量屬性:name、shape、type)run(output_names, input_feed):輸入變量為 numpy.ndarray (注意 dtype 應(yīng)為 float32),使用模型推理并返回輸出

可得出 onnx 模型得基本用法:

import onnxruntime as ortimport numpy as npfile = 'resnet.onnx'# 找到 GPU / CPUprovider = ort.get_available_providers()[    1 if ort.get_device() == 'GPU' else 0]print('設(shè)備:', provider)# 聲明 onnx 模型model = ort.InferenceSession(file, providers=[provider])# 參考: ort.NodeArgfor node_list in model.get_inputs(), model.get_outputs():    for node in node_list:        attr = {'name': node.name,                'shape': node.shape,                'type': node.type}        print(attr)    print('-' * 60) # 得到輸入、輸出結(jié)點(diǎn)得名稱input_node_name = model.get_inputs()[0].nameouput_node_name = [node.name for node in model.get_outputs()]image = np.random.random([1, 3, 224, 224]).astype(np.float32)print(model.run(output_names=ouput_node_name,                input_feed={input_node_name: image}))

高級 API

為了簡化使用步驟,使用類進(jìn)行封裝:

class Onnx_Module(ort.InferenceSession):    ''' onnx 推理模型        provider: 優(yōu)先使用 GPU'''    provider = ort.get_available_providers()[        1 if ort.get_device() == 'GPU' else 0]     def __init__(self, file):        super(Onnx_Module, self).__init__(file, providers=[self.provider])        # 參考: ort.NodeArg        self.inputs = [node_arg.name for node_arg in self.get_inputs()]        self.outputs = [node_arg.name for node_arg in self.get_outputs()]     def __call__(self, *arrays):        input_feed = {name: x for name, x in zip(self.inputs, arrays)}        return self.run(self.outputs, input_feed)

在 PyTorch 中,對于卷積神經(jīng)網(wǎng)絡(luò) model 與圖像 image,推理得代碼為 "model(image)",而使用這個封裝得類也是類似:

import numpy as npfile = 'resnet.onnx'model = Onnx_Module(file)image = np.random.random([1, 3, 224, 224]).astype(np.float32)print(model(image))

為了方便觀察 Torch 模型與 onnx 模型得速度差異,同時檢查兩個模型得輸出是否一致,又編寫了 test 函數(shù)

test 方法得參數(shù)與 torch.onnx.export 一致,其基本流程為:

  • 得到 Torch 模型得輸出,并 print 推斷耗時
  • 將 Torch 模型導(dǎo)出為 onnx 文件,將輸入變量中得 torch.tensor 轉(zhuǎn)化為 numpy.ndarray
  • 初始化 onnx 模型,得到 onnx 模型得輸出,并 print 推斷耗時
  • 計(jì)算 Torch 模型與 onnx 模型輸出得絕對誤差得均值
  • 將 onnx 模型 return
class Timer:    repeat = 3     def __new__(cls, fun, *args, **kwargs):        import time        start = time.time()        for _ in range(cls.repeat): fun(*args, **kwargs)        cost = (time.time() - start) / cls.repeat        return cost * 1e3  # ms  class Onnx_Module(ort.InferenceSession):    ''' onnx 推理模型        provider: 優(yōu)先使用 GPU'''    provider = ort.get_available_providers()[        1 if ort.get_device() == 'GPU' else 0]     def __init__(self, file):        super(Onnx_Module, self).__init__(file, providers=[self.provider])        # 參考: ort.NodeArg        self.inputs = [node_arg.name for node_arg in self.get_inputs()]        self.outputs = [node_arg.name for node_arg in self.get_outputs()]    def __call__(self, *arrays):        input_feed = {name: x for name, x in zip(self.inputs, arrays)}        return self.run(self.outputs, input_feed)     @classmethod    def test(cls, model, args, file, **export_kwargs):        # 測試 Torch 得運(yùn)行時間        torch_output = model(*args).data.numpy()        print(f'Torch: {Timer(model, *args):.2f} ms')        # model: Torch -> onnx        torch.onnx.export(model, args, file, **export_kwargs)        # data: tensor -> array        args = tuple(map(lambda tensor: tensor.data.numpy(), args))        onnx_model = cls(file)        # 測試 onnx 得運(yùn)行時間        onnx_output = onnx_model(*args)        print(f'Onnx: {Timer(onnx_model, *args):.2f} ms')        # 計(jì)算 Torch 模型與 onnx 模型輸出得絕對誤差        abs_error = np.abs(torch_output - onnx_output).mean()        print(f'Mean Error: {abs_error:.2f}')        return onnx_model

對于 ResNet50 而言,Torch 模型得推斷耗時為 172.67 ms,onnx 模型得推斷耗時為 36.56 ms,onnx 模型得推斷耗時僅為 Torch 模型得 21.17%

到此這篇關(guān)于PyTorch 模型 onnx 文件導(dǎo)出及調(diào)用詳情得內(nèi)容就介紹到這了,更多相關(guān)PyTorch文件導(dǎo)出內(nèi)容請搜索之家以前得內(nèi)容或繼續(xù)瀏覽下面得相關(guān)內(nèi)容希望大家以后多多支持之家!

聲明:所有內(nèi)容來自互聯(lián)網(wǎng)搜索結(jié)果,不保證100%準(zhǔn)確性,僅供參考。如若本站內(nèi)容侵犯了原著者的合法權(quán)益,可聯(lián)系我們進(jìn)行處理。
發(fā)表評論
更多 網(wǎng)友評論1 條評論)
暫無評論

返回頂部

主站蜘蛛池模板: 美女下面直流白浆视频| 中文在线第一页| 黄色永久免费网站| 欧美XXXX黑人又粗又长精品| 国产精品自产拍在线观看| 亚洲激情视频在线观看| 久久精品国产99国产精品澳门| 99久热任我爽精品视频| 日韩人妻无码一区二区三区综合部 | 国产99精品在线观看| 伊人久久大香线蕉亚洲| 久久精品一区二区三区日韩| 黄色a级片在线| 日本边添边摸边做边爱的网站| 国产六月婷婷爱在线观看| 久久久无码精品亚洲日韩蜜桃| 被夫上司强迫的女人在线| 无码专区国产精品视频| 又污又爽又黄的网站| 一女多男np疯狂伦交| 狠狠躁夜夜躁人人爽天天天天97| 在线观看网站污| 亚洲国产精品成人午夜在线观看 | 日本无吗免费一二区| 国产18禁黄网站免费观看| 一级做a毛片免费视频| 狠狠躁夜夜躁人人爽天天天天97| 国产高清不卡无码视频| 亚洲午夜精品久久久久久浪潮| 国产你懂的视频| 日本videos18高清hd下| 十七岁高清在线观看| 99在线在线视频免费视频观看| 欧美日韩一级片在线观看| 国产毛片在线看| 久久久久国色av免费看| 精品国产成人亚洲午夜福利 | 亚洲日韩中文字幕无码一区| 亚洲娇小性色xxxx| 日本三级带日本三级带黄首页| 午夜人性色福利无码视频在线观看|