Accelerating PyTorch Data Preprocessing with NVIDIA DALI

For deep learning tasks, training speed determines how quickly you can iterate on a model, and training speed in turn depends on data preprocessing plus the network's forward and backward passes.
For recognition tasks, the batch size is usually large and data augmentation is required, so the training bottleneck often sits in data loading and preprocessing, especially for small networks.
For data-loading time, a crude but effective fix is to use an SSD, or to copy the data straight into /tmp (trading memory for time on systems where /tmp is RAM-backed).
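As a concrete example, here is a minimal staging sketch; it assumes /tmp is a RAM-backed tmpfs mount, and the paths are hypothetical:

```python
import shutil
from pathlib import Path

src = Path('/data/frames')  # dataset on slow storage (hypothetical path)
dst = Path('/tmp/frames')   # RAM-backed staging area

# copy once before training so subsequent reads come from memory
if not dst.exists():
    shutil.copytree(src, dst)
```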
For preprocessing time, you can use DALI, NVIDIA's official data-loading and preprocessing acceleration library, which runs the preprocessing pipeline on the CPU or GPU. Note that DALI is not bundled with stock PyTorch (including 1.6): it comes preinstalled in NVIDIA's NGC PyTorch containers, but otherwise must be installed separately.
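If DALI is missing from your environment, NVIDIA distributes wheels through its own package index, e.g. `pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda110` (the exact package name tracks your CUDA version).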

The official DALI tutorials are fairly basic; real training usually calls for a task-specific DataLoader used together with distributed training. Below is an example of building a DataLoader with DALI that returns image sequences and applies the usual preprocessing uniformly across each sequence.
```python

from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from sklearn.utils import shuffle
import numpy as np
import random
from pathlib import Path

class TRAIN_INPUT_ITER(object):
    """External-source iterator for training: yields (batch, labels) of raw JPEG bytes."""
    def __init__(self, batch_size, num_class, seq_len, sample_rate, num_shards=1,
                 shard_id=0, root_dir=Path(''), list_file='', is_training=True):
        self.batch_size = batch_size
        self.num_class = num_class
        self.seq_len = seq_len
        self.sample_rate = sample_rate
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.train = is_training
        self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
        self.root_dir = root_dir
        with open(list_file, 'r') as f:
            self.ori_lines = f.readlines()

    def __iter__(self):
        self.i = 0
        # each shard iterates over an equal-sized bucket of the annotation list
        self.n = len(self.ori_lines) // self.num_shards
        return self

    def __next__(self):
        batch = [[] for _ in range(self.seq_len)]
        labels = []
        # each loop iteration emits TWO sequences (a labeled clip plus a clip
        # straddling its boundary), so the pipeline passes batch_size//2 here
        for _ in range(self.batch_size):
            if self.train and self.i % self.n == 0:
                # reshuffle at epoch boundaries; the fixed random_state makes the
                # order deterministic -- drop it if you want a true reshuffle
                bucket = len(self.ori_lines) // self.num_shards
                self.ori_lines = shuffle(self.ori_lines, random_state=0)
                self.lines = self.ori_lines[self.shard_id * bucket:(self.shard_id + 1) * bucket]
            line = self.lines[self.i].strip()
            dir_name, start_f, end_f, label = line.split(' ')
            start_f = int(start_f)
            end_f = int(end_f)
            label = int(label)
            # clip 1: a random window of seq_len frames inside [start_f, end_f]
            begin_frame = random.randint(start_f, max(end_f - self.sample_rate * self.seq_len, start_f))
            begin_frame = max(1, begin_frame)
            last_frame = None
            for k in range(self.seq_len):
                filename = self.root_dir / dir_name / self.image_name_formatter(begin_frame + self.sample_rate * k)
                if filename.exists():
                    last_frame = filename
                elif last_frame is None:
                    raise IOError('{} does not exist'.format(filename))
                # missing frames fall back to the last frame that was found
                with open(last_frame, 'rb') as f:
                    batch[k].append(np.frombuffer(f.read(), dtype=np.uint8))
            # clip 2: a window straddling either the start or the end of the action
            if random.randint(0, 1) == 0:
                end_frame = start_f + random.randint(0, self.sample_rate * self.seq_len // 2)
                begin_frame = max(1, end_frame - self.sample_rate * self.seq_len)
            else:
                begin_frame = end_f - random.randint(0, self.sample_rate * self.seq_len // 2)
                begin_frame = max(1, begin_frame)
            last_frame = None
            for k in range(self.seq_len):
                filename = self.root_dir / dir_name / self.image_name_formatter(begin_frame + self.sample_rate * k)
                if filename.exists():
                    last_frame = filename
                elif last_frame is None:
                    raise IOError('{} does not exist'.format(filename))
                with open(last_frame, 'rb') as f:
                    batch[k].append(np.frombuffer(f.read(), dtype=np.uint8))

            labels.append(np.array([label], dtype=np.uint8))
            # the boundary clip keeps its label only for classes 8 and 9;
            # otherwise it is relabeled as background (num_class - 1)
            if label == 8 or label == 9:
                labels.append(np.array([label], dtype=np.uint8))
            else:
                labels.append(np.array([self.num_class - 1], dtype=np.uint8))

            self.i = (self.i + 1) % self.n
        return (batch, labels)

    next = __next__


class VAL_INPUT_ITER(object):
    """External-source iterator for validation: yields (batch, labels) of raw JPEG bytes."""
    def __init__(self, batch_size, num_class, seq_len, sample_rate, num_shards=1,
                 shard_id=0, root_dir=Path(''), list_file='', is_training=False):
        self.batch_size = batch_size
        self.num_class = num_class
        self.seq_len = seq_len
        self.sample_rate = sample_rate
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.train = is_training
        self.image_name_formatter = lambda x: f'image_{x:05d}.jpg'
        self.root_dir = root_dir
        with open(list_file, 'r') as f:
            self.ori_lines = f.readlines()
        # shuffle once with a fixed seed so the validation order is deterministic
        self.ori_lines = shuffle(self.ori_lines, random_state=0)

    def __iter__(self):
        self.i = 0
        # each shard iterates over an equal-sized bucket of the annotation list
        self.n = len(self.ori_lines) // self.num_shards
        return self

    def __next__(self):
        batch = [[] for _ in range(self.seq_len)]
        labels = []
        for _ in range(self.batch_size):
            # is_training is False here, so no reshuffling: just slice this shard's bucket
            if self.i % self.n == 0:
                bucket = len(self.ori_lines) // self.num_shards
                self.lines = self.ori_lines[self.shard_id * bucket:(self.shard_id + 1) * bucket]
            line = self.lines[self.i].strip()
            dir_name, start_f, end_f, label = line.split(' ')
            start_f = int(start_f)
            end_f = int(end_f)
            label = int(label)
            begin_frame = random.randint(start_f, max(end_f - self.sample_rate * self.seq_len, start_f))
            begin_frame = max(1, begin_frame)
            last_frame = None
            for k in range(self.seq_len):
                filename = self.root_dir / dir_name / self.image_name_formatter(begin_frame + self.sample_rate * k)
                if filename.exists():
                    last_frame = filename
                elif last_frame is None:
                    raise IOError('{} does not exist'.format(filename))
                # missing frames fall back to the last frame that was found
                with open(last_frame, 'rb') as f:
                    batch[k].append(np.frombuffer(f.read(), dtype=np.uint8))
            labels.append(np.array([label], dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return (batch, labels)

    next = __next__

class HybridPipe(Pipeline):
    def __init__(self, batch_size, num_class, seq_len, sample_rate, num_shards, shard_id,
                 root_dir, list_file, num_threads, device_id=0, dali_cpu=True,
                 size=(224, 224), is_gray=True, is_training=True):
        super(HybridPipe, self).__init__(batch_size, num_threads, device_id, seed=12 + device_id)
        if is_training:
            # the training iterator emits two clips per sample, hence batch_size//2
            self.external_data = TRAIN_INPUT_ITER(batch_size // 2, num_class, seq_len, sample_rate,
                                                  num_shards, shard_id, root_dir, list_file, is_training)
        else:
            self.external_data = VAL_INPUT_ITER(batch_size, num_class, seq_len, sample_rate,
                                                num_shards, shard_id, root_dir, list_file, is_training)
        self.seq_len = seq_len
        self.training = is_training
        self.iterator = iter(self.external_data)
        # one ExternalSource per frame position in the sequence, plus one for labels
        self.inputs = [ops.ExternalSource() for _ in range(seq_len)]
        self.input_labels = ops.ExternalSource()
        self.is_gray = is_gray

        # 'mixed' decodes JPEGs on the GPU (nvJPEG); 'cpu' keeps decoding on the host
        decoder_device = 'cpu' if dali_cpu else 'mixed'
        self.decode = ops.ImageDecoder(device=decoder_device, output_type=types.RGB)
        if self.is_gray:
            self.space_converter = ops.ColorSpaceConversion(device='gpu', image_type=types.RGB,
                                                            output_type=types.GRAY)
        self.resize = ops.Resize(device='gpu', size=size)
        self.cast_fp32 = ops.Cast(device='gpu', dtype=types.FLOAT)
        if self.training:
            # random crop / flip parameters fed into CropMirrorNormalize
            self.crop_coin = ops.CoinFlip(probability=0.5)
            self.crop_pos_x = ops.Uniform(range=(0., 1.))
            self.crop_pos_y = ops.Uniform(range=(0., 1.))
            self.crop_h = ops.Uniform(range=(256 * 0.85, 256))
            self.crop_w = ops.Uniform(range=(256 * 0.85, 256))
            self.crmn = ops.CropMirrorNormalize(device='gpu', output_layout=types.NHWC)

            # small random rotation
            self.u_rotate = ops.Uniform(range=(-8, 8))
            self.rotate = ops.Rotate(device='gpu', keep_size=True)

            # color-jitter parameters
            self.brightness = ops.Uniform(range=(0.9, 1.1))
            self.contrast = ops.Uniform(range=(0.9, 1.1))
            self.saturation = ops.Uniform(range=(0.9, 1.1))
            self.hue = ops.Uniform(range=(-0.3, 0.3))
            self.color_jitter = ops.ColorTwist(device='gpu')
        else:
            self.crmn = ops.CropMirrorNormalize(device='gpu', crop=(224, 224), output_layout=types.NHWC)

    def define_graph(self):
        self.batch_data = [i() for i in self.inputs]
        self.labels = self.input_labels()
        # decode every frame position; the result is a list of image batches
        out = self.decode(self.batch_data)
        out = [out_elem.gpu() for out_elem in out]
        if self.training:
            # the original call passed only brightness/contrast; saturation and hue
            # are wired in here as well since they are already defined in __init__
            out = self.color_jitter(out, brightness=self.brightness(), contrast=self.contrast(),
                                    saturation=self.saturation(), hue=self.hue())
        if self.is_gray:
            out = self.space_converter(out)
        if self.training:
            out = self.rotate(out, angle=self.u_rotate())
            out = self.crmn(out, crop_h=self.crop_h(), crop_w=self.crop_w(),
                            crop_pos_x=self.crop_pos_x(), crop_pos_y=self.crop_pos_y(),
                            mirror=self.crop_coin())
        else:
            out = self.crmn(out)
        # resize last so every clip reaches the target size regardless of crop shape
        out = self.resize(out)
        if not self.training:
            out = self.cast_fp32(out)
        return (*out, self.labels)
    
    def iter_setup(self):
        try:
            batch_data, labels = next(self.iterator)
            # feed each frame position and the label tensor to its ExternalSource
            for i in range(self.seq_len):
                self.feed_input(self.batch_data[i], batch_data[i])
            self.feed_input(self.labels, labels)
        except StopIteration:
            # re-arm the Python iterator for the next epoch before signalling DALI
            self.iterator = iter(self.external_data)
            raise StopIteration

def dali_loader(batch_size,
                num_class,
                seq_len,
                sample_rate,
                num_shards,
                shard_id,
                root_dir,
                list_file,
                num_workers,
                device_id,
                dali_cpu=True,
                size=(224, 224),
                is_gray=True,
                is_training=True):
    pipe = HybridPipe(batch_size, num_class, seq_len, sample_rate, num_shards, shard_id, root_dir,
                      list_file, num_workers, device_id=device_id,
                      dali_cpu=dali_cpu, size=size, is_gray=is_gray, is_training=is_training)
    # DALIGenericIterator builds the pipeline itself, so an explicit pipe.build() is unnecessary
    names = [f'data{i}' for i in range(seq_len)] + ['label']
    loader = DALIGenericIterator(pipe, names, pipe.external_data.n,
                                 last_batch_padded=True, fill_last_batch=True)
    return loader
```
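For reference, below is a minimal sketch of how this loader might be driven in a training loop. The paths and hyperparameters are hypothetical; each line of the annotation file is assumed to read `dir_name start_f end_f label`, matching the parsing in `__next__`. Since `fill_last_batch=True` and `last_batch_padded=True`, the final batch of an epoch is padded rather than dropped.

```python
from pathlib import Path
import torch

seq_len = 8
# hypothetical arguments; adjust to your dataset layout
loader = dali_loader(batch_size=32, num_class=10, seq_len=seq_len, sample_rate=2,
                     num_shards=1, shard_id=0,
                     root_dir=Path('/tmp/frames'), list_file='train_list.txt',
                     num_workers=4, device_id=0, dali_cpu=False,
                     size=(224, 224), is_gray=True, is_training=True)

for data in loader:
    # DALIGenericIterator yields one dict per GPU, keyed by the names defined above
    frames = torch.stack([data[0][f'data{i}'] for i in range(seq_len)], dim=1)  # N,T,H,W,C
    labels = data[0]['label'].squeeze(-1).long()
    # ... forward / backward ...
loader.reset()  # re-arm the DALI pipeline at the end of each epoch
```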
