Keras ImageDataGenerator is slow

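I'm looking for the best approach to handling larger-than-memory data in Keras, and I've noticed that the plain ImageDataGenerator is slower than I'd like. I have two networks training on the Kaggle cats vs. dogs dataset (25,000 images):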
1) Exactly the code from: http://www.pyimagesearch.com/2016/09/26/a-simple-neural-network-with-python-and-keras/
2) The same as (1), but using ImageDataGenerator instead of loading the data into memory.
Note: "preprocessing" below means resizing, scaling, and flattening.
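On my GTX 970 I get the following results:
For network 1, an epoch takes ~0 s.
For network 2, an epoch takes ~36 s if the preprocessing is done inside the data generator.
For network 2, an epoch takes ~13 s if the preprocessing is done in a first pass, outside the data generator.
Is this likely the speed limit of ImageDataGenerator (13 s looks like the usual 10-100x gap between disk and RAM)? Is there a better-suited approach or mechanism for training on larger-than-memory data with Keras? For example, is there a way to make Keras's ImageDataGenerator save its processed images after the first epoch?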
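One way to get the "preprocess once, reuse afterwards" behaviour without relying on ImageDataGenerator is to run the preprocessing a single time and cache the result on disk as a memory-mapped .npy file. The sketch below assumes the images are already decoded into NumPy arrays; `cache_preprocessed` and `scale_and_flatten` are hypothetical helpers (not Keras API), and resizing is omitted for brevity.

```python
import os
import tempfile

import numpy as np


def cache_preprocessed(images, preprocess, out_path):
    """Apply `preprocess` once to every image and store the results in a .npy file.

    Returns a read-only memory-mapped view, so later epochs read preprocessed
    arrays from disk (or the OS page cache) instead of redoing the work.
    """
    first = np.asarray(preprocess(images[0]), dtype=np.float32)
    # open_memmap writes a proper .npy header, so np.load can reopen the file later
    cache = np.lib.format.open_memmap(
        out_path, mode="w+", dtype=np.float32, shape=(len(images),) + first.shape
    )
    cache[0] = first
    for i in range(1, len(images)):
        cache[i] = preprocess(images[i])
    cache.flush()
    del cache  # close the writable map before reopening it read-only
    return np.load(out_path, mmap_mode="r")


def scale_and_flatten(img):
    """Stand-in for the question's preprocessing (resize omitted): scale to [0, 1] and flatten."""
    return (img.astype(np.float32) / 255.0).ravel()


# Usage with synthetic 8x8 RGB arrays standing in for decoded images
rng = np.random.default_rng(0)
raw_images = [rng.integers(0, 256, size=(8, 8, 3), dtype=np.uint8) for _ in range(5)]
cache_path = os.path.join(tempfile.mkdtemp(), "train_cache.npy")
cached = cache_preprocessed(raw_images, scale_and_flatten, cache_path)
```

The returned array can then be sliced batch by batch (for example from a `keras.utils.Sequence`), so only the batches being used have to be paged into RAM.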
Thanks!

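Although this post is a bit old, it is still relevant: Slow image data generator. It points out that Keras (at least at some point in the past) applied several sequential transformations where a single combined one would have sufficed. - user3731622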
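To illustrate the point of this comment: affine augmentations such as zoom, rotation, and shift can be multiplied into one 3x3 homogeneous matrix, so each image is interpolated once instead of once per transform. This is a minimal NumPy sketch, not Keras's implementation; rotation is about the origin rather than the image center, and nearest-neighbour sampling is used for brevity.

```python
import numpy as np


def rotation(theta):
    """3x3 homogeneous rotation about the origin (radians)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def shift(tx, ty):
    """3x3 homogeneous translation by (tx, ty) pixels."""
    return np.array([[1.0, 0.0, tx], [0.0, 1.0, ty], [0.0, 0.0, 1.0]])


def zoom(z):
    """3x3 homogeneous isotropic scaling."""
    return np.array([[z, 0.0, 0.0], [0.0, z, 0.0], [0.0, 0.0, 1.0]])


def warp_nearest(image, matrix):
    """Resample `image` a single time under `matrix` (maps input (x, y) to output (x, y)).

    Each output pixel looks up inv(matrix) @ (x, y, 1) in the input
    (nearest neighbour); pixels that land outside the input are set to 0.
    """
    h, w = image.shape[:2]
    inv = np.linalg.inv(matrix)
    ys, xs = np.mgrid[0:h, 0:w]
    coords = np.stack([xs.ravel(), ys.ravel(), np.ones(h * w)])
    src = inv @ coords
    sx = np.rint(src[0]).astype(int)
    sy = np.rint(src[1]).astype(int)
    valid = (sx >= 0) & (sx < w) & (sy >= 0) & (sy < h)
    out = np.zeros(image.shape, dtype=image.dtype)  # C-contiguous, so reshape below is a view
    flat = out.reshape((h * w,) + image.shape[2:])
    flat[valid] = image[sy[valid], sx[valid]]
    return out


# Zoom, then rotate, then shift, folded into ONE matrix: the image is
# interpolated once instead of three times.
combined = shift(2.0, -1.0) @ rotation(np.pi / 12) @ zoom(1.1)
```

Calling `warp_nearest(image, combined)` gives the composite augmentation in one resampling pass, which both saves CPU time and avoids accumulating interpolation blur.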
Take a look at this link: https://github.com/stratospark/keras-multiprocess-image-data-generator/blob/master/Accelerating%20Deep%20Learning%20with%20Multiprocess%20Image%20Augmentation%20in%20Keras.md - Amir Saniyan
2 Answers

I assume you may have solved this already, but anyway...
Keras image preprocessing has an option to save its results: set the save_to_dir parameter of flow() or flow_from_directory(): https://keras.io/preprocessing/image/

(In visual programming) e.g. flow > parameters > save_to_dir; this feature is very useful. - M at

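As I understand it, the problem is that each augmented image is used only once, in a single training epoch, and never reused across epochs. So the GPU sits idle while the CPU struggles to keep up. I came up with the following solution:
  1. I generate as many augmented images as fit in RAM.
  2. I train on them for several epochs, 10 to 30, until there is noticeable convergence.
  3. After that I generate a new set of augmented images (by implementing on_epoch_end), and training continues.
Most of the time this keeps the GPU busy while still benefiting from data augmentation. I use a custom Sequence subclass that generates the augmentations and fixes class imbalance at the same time.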
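Stripped of the balancing and control-file logic, the "augment into RAM, train on it for a few epochs, then regenerate" idea can be sketched as a small class with the `keras.utils.Sequence` methods. This is a simplified illustration, not the author's class; `FrameCachedAugmenter` and `random_hflip` are hypothetical names, and it subclasses `object` only so it runs without TensorFlow (in training code it would subclass `tf.keras.utils.Sequence`).

```python
import numpy as np


class FrameCachedAugmenter:
    """Holds one "frame" of augmented samples in RAM and rebuilds it every
    `frame_length` epochs.

    Exposes the keras.utils.Sequence protocol (__len__, __getitem__,
    on_epoch_end); in a real training script it would subclass
    tf.keras.utils.Sequence instead of object.
    """

    def __init__(self, x, y, augment, batch_size=32, frame_length=10, seed=0):
        self.x = x
        self.y = np.asarray(y)
        self.augment = augment  # callable(image, rng) -> augmented image
        self.batch_size = batch_size
        self.frame_length = frame_length
        self.rng = np.random.default_rng(seed)
        self.epoch = 0
        self.generations = 0
        self._regenerate()

    def _regenerate(self):
        # The CPU-heavy augmentation runs only here, once per frame, not once per epoch
        order = self.rng.permutation(len(self.x))
        self.images = np.stack([self.augment(self.x[i], self.rng) for i in order])
        self.labels = self.y[order]
        self.generations += 1

    def __len__(self):
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, idx):
        a = idx * self.batch_size
        return self.images[a:a + self.batch_size], self.labels[a:a + self.batch_size]

    def on_epoch_end(self):
        self.epoch += 1
        if self.epoch % self.frame_length == 0:
            self._regenerate()


def random_hflip(image, rng):
    """Example augmentation: horizontal flip with probability 0.5."""
    return image[:, ::-1] if rng.random() < 0.5 else image
```

Between regenerations every batch is a cheap slice of an in-memory array, which is what keeps the GPU fed; `frame_length` trades augmentation diversity against CPU time.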
Edit: adding code to clarify the idea
from pyutilz.string import read_config_file
from tqdm.notebook import tqdm
from pathlib import Path
from gc import collect
import numpy as np
import tensorflow
import logging
import random
import json
import cv2

class StoppingFromFile(tensorflow.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
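        # assumption: read_config_file (pyutilz) stores the [ML] stop value as `stop` in the globals() dict it receives
        if read_config_file('control.ini','ML','stop',globals()):
            stop = globals().get('stop')
            if stop is not None:
                if stop==True or stop=='True':
                    logging.warning('Model should be stopped according to the control file')
                    self.model.stop_training = True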

class AugmentedBalancedSequence(tensorflow.keras.utils.Sequence):
    def __init__(self, images_and_classes:dict,input_size:tuple,class_sizes:dict, augmentations_fn:object, preprocessing_fn:object, batch_size:int=10,
                 num_class_samples=100, frame_length:int=5, aug_p:float=0.1,aug_pipe_p:float=0.2,is_validation:bool=False,
                disk_saving_prob:float=.01,disk_example_nfiles:int=50):
        """
            From a dict mapping class label -> list of images (numpy arrays), creates a new augmented, class-balanced training set every frame_length epochs.
            If a class is too scarce, ensures the current frame contains no duplicate final (augmented) images.
            If a class is rich enough, ensures the current frame contains no duplicate source images.
        
        """
        logging.info(f'Got {len(images_and_classes)} classes.')
        self.disk_example_nfiles=disk_example_nfiles;self.disk_saving_prob=disk_saving_prob;self.cur_example_file=0
        
        self.images_and_classes=images_and_classes        
        self.num_class_samples=num_class_samples
        self.augmentations_fn=augmentations_fn
        self.preprocessing_fn=preprocessing_fn
        
        self.is_validation=is_validation
        self.frame_length=frame_length                    
        self.batch_size = batch_size      
        self.class_sizes=class_sizes
        self.input_size=input_size        
        self.aug_pipe_p=aug_pipe_p
        self.aug_p=aug_p        
        self.images=None
        self.epoch = 0
        #print(f'got frame_length={self.frame_length}')
        self._generate_data()
        

    def __len__(self):
        return int(np.ceil(len(self.images)/ float(self.batch_size)))

    def __getitem__(self, idx):
        a=idx * self.batch_size;b=a+self.batch_size
        return self.images[a:b],self.labels[a:b]
    
    def on_epoch_end(self):
        import ast
        self.epoch += 1    
        mydict={}

        import pathlib
        fname='control.json'
        p = pathlib.Path(fname)
        if p.is_file():
            try:
                with open (fname) as f:
                    mydict=json.load(f)
                for var,val in mydict.items():
                    if hasattr(self,var):
                        converted = val #ast.literal_eval(val)
                        if converted is not None:
                            if getattr(self, var)!=converted:
                                setattr(self, var, converted)                                        
                                print(f'{var} became {val}')
            except Exception as e:
                logging.error(str(e))
        if self.epoch % self.frame_length == 0:
            #print('generating data...')
            self._generate_data()
            
    def _add_sample(self,image,label):
        from random import random
        idx=self.indices[self.img_sent]
        
        if self.disk_saving_prob>0:
            if random()<self.disk_saving_prob:
                self.cur_example_file+=1
                if self.cur_example_file>self.disk_example_nfiles:
                    self.cur_example_file=1
                Path(r'example_images/').mkdir(parents=True, exist_ok=True)
                cv2.imwrite(f'example_images/test{self.cur_example_file}.jpg',cv2.cvtColor(image,cv2.COLOR_RGB2BGR))
        
        if self.preprocessing_fn: 
            self.images[idx]=self.preprocessing_fn(image)
        else:
            self.images[idx]=image
        
        self.labels[idx]=label
        self.img_sent+=1        
        
    def _generate_data(self):
        logging.info('Generating new set of augmented data...')
        
        collect()
        #del self.images
        #del self.labels        
        #collect()
        
        if self.num_class_samples:
            expected_length=len(self.images_and_classes)*self.num_class_samples
        else:
            expected_length=sum(self.class_sizes.values())        
            
        if self.images is None:
            self.images=np.empty((expected_length,)+(self.input_size[1],)+(self.input_size[0],)+(3,))
            self.labels=np.empty((expected_length),np.int32)
        
        self.indices=np.random.choice(expected_length, expected_length, replace=False)
        self.img_sent=0
        
        
        collect()
        
        relaxed_augmentation_pipeline=self.augmentations_fn(p=self.aug_p,pipe_p=self.aug_pipe_p)
        maxed_out_augmentation_pipeline=self.augmentations_fn(p=self.aug_p,pipe_p=1.0)
        
        #for each class
        x,y=[],[]
        nartificial=0
        for label,images in tqdm(self.images_and_classes.items()):
            if self.num_class_samples is None:
                #Just all native samples without augmentations
                for image in images:
                    self._add_sample(image,label)                        
            else:
                #if there are enough native samples
                if len(images)>=self.num_class_samples:
                    #randomly select samples of this class which will participate in this frame of epochs                
                    indices=np.random.choice(len(images), self.num_class_samples, replace=False)
                    #apply albumentations pipeline to selected samples

                    for idx in indices:
                        if not self.is_validation:
                            self._add_sample(relaxed_augmentation_pipeline(image=images[idx])['image'],label)
                        else:
                            self._add_sample(images[idx],label)
                                                    
                else:
                    #------------------------------------------------------------------------------------------------------------------------------------------------------------------
                    # Randomly pick next image from existing. try applying augmentation pipeline (with maxed out probability) till we get num_class_samples DIFFERENT images
                    #------------------------------------------------------------------------------------------------------------------------------------------------------------------
                    hashes=set()
                    norig=0
                    while len(hashes)<self.num_class_samples:
                        if self.is_validation and norig<len(images):
                            #just include all originals first
                            image=images[norig]
                        else:
                            image=maxed_out_augmentation_pipeline(image=random.choice(images))['image']                                                      
                        next_hash=np.sum(image)
                        if next_hash not in hashes or (self.is_validation and norig<=len(images)):                        
                            
                            #print(f'Adding orig {norig} out of {self.num_class_samples}, hashes={hashes}')
                            
                            self._add_sample(image,label)
                            if next_hash in hashes:
                                norig+=1
                                hashes.add(norig)
                            else:
                                hashes.add(next_hash)
                                nartificial+=1  
                                
        
        #self.images=self.images[indices];self.labels=self.labels[indices]                              
        
        logging.info(f'Generated {self.img_sent} samples ({nartificial} artificial)')
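The scarce-class branch of `_generate_data` uses `np.sum(image)` as a "hash", which treats any flip or pixel permutation of an image as a duplicate and can also collide for unrelated images. A hedged sketch of the same per-class balancing with an exact content hash (`hashlib.sha1` over the raw bytes) follows; `balanced_class_samples`, `content_hash`, and `brightness_jitter` are hypothetical names, and `brightness_jitter` only stands in for the albumentations pipeline.

```python
import hashlib

import numpy as np


def content_hash(image):
    """Exact-content fingerprint. Unlike np.sum(image), it does not collide for
    a flipped or otherwise permuted copy of the same pixels."""
    return hashlib.sha1(np.ascontiguousarray(image).tobytes()).hexdigest()


def balanced_class_samples(images, per_class, augment, rng, max_tries=10_000):
    """Return `per_class` samples for one class.

    Rich class: pick distinct source images without replacement, augment each.
    Scarce class: keep augmenting random source images until `per_class`
    distinct (by content hash) results have been collected.
    """
    if len(images) >= per_class:
        chosen = rng.choice(len(images), per_class, replace=False)
        return [augment(images[i], rng) for i in chosen]
    seen, out = set(), []
    for _ in range(max_tries):
        candidate = augment(images[rng.integers(len(images))], rng)
        key = content_hash(candidate)
        if key not in seen:
            seen.add(key)
            out.append(candidate)
            if len(out) == per_class:
                return out
    raise RuntimeError("augmentation produced too few distinct images")


def brightness_jitter(image, rng):
    """Example augmentation: add a random integer offset (stand-in for a real pipeline)."""
    return image + int(rng.integers(0, 100))
```

The `max_tries` cap turns a potential infinite loop (an augmentation pipeline that cannot produce enough distinct images) into an explicit error.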

Once I have loaded the images and classes:
train_datagen = AugmentedBalancedSequence(images_and_classes=images_and_classes_train,
                                          input_size=INPUT_SIZE, class_sizes=class_sizes_train, num_class_samples=UPSCALE_SAMPLES,
                                          augmentations_fn=get_albumentations_pipeline, aug_p=AUG_P, aug_pipe_p=AUG_PIPE_P,
                                          preprocessing_fn=preprocess_input, batch_size=BATCH_SIZE, frame_length=FRAME_LENGTH,
                                          disk_saving_prob=0.05)

val_datagen = AugmentedBalancedSequence(images_and_classes=images_and_classes_val,
                                        input_size=INPUT_SIZE, class_sizes=class_sizes_val, num_class_samples=None,
                                        augmentations_fn=get_albumentations_pipeline, preprocessing_fn=preprocess_input,
                                        batch_size=BATCH_SIZE, frame_length=FRAME_LENGTH, is_validation=True)

After instantiating the model, I run:

model.fit(train_datagen,epochs=600,verbose=1,
          validation_data=(val_datagen.images,val_datagen.labels),validation_batch_size=BATCH_SIZE,
          callbacks=[checkpointer,StoppingFromFile()],validation_freq=1)
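The `StoppingFromFile` callback relies on pyutilz's `read_config_file` putting a `stop` variable into `globals()`. A sketch of the same "stop training by editing a file" control using only the standard-library `configparser`; `stop_requested` is a hypothetical helper, and the file layout (`[ML]` section, `stop` key in `control.ini`) simply mirrors the arguments passed to `read_config_file`.

```python
import configparser
import os


def stop_requested(path="control.ini", section="ML", key="stop"):
    """True if `path` exists and its [section] has `key` set to a true value
    (True/true/yes/on/1); False if the file, section, or key is missing."""
    if not os.path.isfile(path):
        return False
    parser = configparser.ConfigParser()
    parser.read(path)
    return parser.getboolean(section, key, fallback=False)


# Possible use inside a Keras callback (requires TensorFlow, not run here):
# class StoppingFromFile(tf.keras.callbacks.Callback):
#     def on_epoch_end(self, epoch, logs=None):
#         if stop_requested():
#             self.model.stop_training = True
```

Setting `stop = True` under `[ML]` in the file then ends training cleanly at the next epoch boundary, without touching module globals.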

Great solution, do you have any code you can share? - TomSelleck
Thanks, I added some code; I hope it helps you or gives you some useful ideas. - Anatoly Alekseev

Content provided by Stack Overflow.