tensorflow2.0在训练数据集的时候,fit和fit_generator的使用

model.fit函数

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None,
    validation_split=0.0, validation_data=None, shuffle=True, class_weight=None,
    sample_weight=None, initial_epoch=0, steps_per_epoch=None,
    validation_steps=None, validation_batch_size=None, validation_freq=1,
    max_queue_size=10, workers=1, use_multiprocessing=False
)

为模型训练固定的批次(数据集上的迭代)。

参数:

参数 作用
x 输入数据.它可能是:Numpy数组(或类似数组的数组)或数组列表(如果模型具有多个输入)。TensorFlow张量或张量列表(如果模型具有多个输入)。如果模型已命名输入,则dict将输入名称映射到相应的数组/张量。一个tf.data数据集。应该返回(inputs, targets)或 的元组(inputs, targets, sample_weights)。生成器或keras.utils.Sequence返回(inputs, targets) 或(inputs, targets, sample_weights)。下面给出了迭代器类型(数据集,生成器,序列)的拆包行为的更详细描述。
y 目标数据。像输入数据一样x,它可以是Numpy数组或TensorFlow张量。它应该与x(您不能有Numpy输入和张量目标,或者相反)保持一致。如果x是keras.utils.Sequence,y则不应该指定,生成器或实例(因为将从中获取目标x)。
batch_size 整数或None。每个梯度更新的样本数。如果未指定,则默认为32。如果数据是以数据集,生成器或实例的形式(因为它们生成批次),则不要指定。 batch_sizebatch_sizekeras.utils.Sequence
epochs 整数。训练模型的时期数。时期是整个x和所y 提供数据的迭代。请注意,在与结合, 应理解为“最后时期”。不会针对给出的多次迭代训练模型,而只是对到达索引的时期进行训练。 initial_epochepochsepochsepochs
verbose 0、1或2。详细模式。0 =静音,1 =进度条,2 =每个时期一行。请注意,进度条在登录到文件时不是特别有用,因此,如果不以交互方式运行(例如,在生产环境中),建议使用verbose = 2。
callbacks keras.callbacks.Callback实例 列表。训练期间要应用的回调列表。请参阅tf.keras.callbacks。
validation_split 在0到1之间浮动。将训练数据的分数用作验证数据。模型将分开训练数据的这一部分,不对其进行训练,并且将在每个时期结束时评估此数据的损失和任何模型度量。在改组之前,从x和中y提供的最后一个样本中选择验证数据。当x是数据集,生成器或keras.utils.Sequence实例时, 不支持此参数。
validation_data 在每个时期结束时用于评估损失的数据和任何模型指标。该模型将不会根据此数据进行训练。因此,请注意以下事实:使用 或不受正则化层(如噪声和压降)影响的数据验证损失。 将覆盖。 可能: validation_splitvalidation_datavalidation_datavalidation_splitvalidation_data 1.(x_val, y_val)Numpy数组或张量的元组. 2(x_val, y_val, val_sample_weights)Numpy数组的元组 .数据集对于前两种情况,batch_size必须提供。对于最后一种情况,validation_steps可以提供。请注意,validation_data它并不支持xdict,generator或中支持的所有数据类型keras.utils.Sequence。
shuffle 布尔值(是否在每个纪元之前改组训练数据)或str(用于“批处理”)。当x是生成器时,将忽略此参数。“批处理”是处理HDF5数据限制的特殊选项;它以批量大小的块洗牌。当没有任何效果是不是。 steps_per_epochNone
class_weight 可选的字典映射类索引(整数)到权重(浮动)值,用于加权损失函数(仅在训练期间)。这可能有助于告诉模型“更多关注”来自代表性不足的类的样本。
sample_weight 训练样本的可选Numpy权重数组,用于加权损失函数(仅在训练过程中)。您可以传递长度与输入样本相同的平坦(1D)Numpy数组(权重和样本之间的1:1映射),或者对于时间数据,可以传递带有shape的2D数组 以应用每个样品的每个时间步均具有不同的权重。如果是数据集,生成器或 实例,而将sample_weights作为的第三个元素,则不支持此参数。 (samples, sequence_length)xkeras.utils.Sequencex
initial_epoch 整数。开始训练的时期(用于恢复以前的训练运行)。
steps_per_epoch 整数或None。声明一个纪元完成并开始下一个纪元之前的总步数(一批样品)。使用输入张量(例如TensorFlow数据张量)进行训练时,默认None值等于数据集中的样本数除以批次大小;如果无法确定,则默认为1。如果x是 tf.data数据集,并且’steps_per_epoch’为None,则该纪元将运行直到输入数据集用尽。传递无限重复的数据集时,必须指定 参数。数组输入不支持此参数。 steps_per_epoch
validation_steps 仅在提供时才相关,并且是数据集。在每个时期结束时执行验证时,在停止之前要绘制的步骤总数(样本批次)。如果“ validation_steps”为“无”,则验证将一直进行到数据集用尽。如果是无限重复的数据集,它将陷入无限循环。如果指定了“ validation_steps”,并且仅消耗了一部分数据集,则评估将在每个时期从数据集的开头开始。这样可以确保每次都使用相同的验证样本。 validation_datatf.datavalidation_data
validation_batch_size 整数或None。每个验证批次的样品数量。如果未指定,则默认为。不要指定数据是数据集,生成器还是实例的形式(因为它们生成批处理)。 batch_sizevalidation_batch_sizekeras.utils.Sequence
validation_freq 仅在提供验证数据时才相关。整数或实例(例如列表,元组等)。如果为整数,则指定在执行新的验证运行之前要运行多少个训练时期,例如,每2个时期运行一次验证。如果是容器,则指定要运行验证的时期,例如,在第一个,第二个和第十个时期的末尾运行验证。 collections_abc.Containervalidation_freq=2validation_freq=[1, 2, 10]
max_queue_size 整数。keras.utils.Sequence 仅用于生成器或输入。生成器队列的最大大小。如果未指定,则默认为10。 max_queue_size
workers 整数。keras.utils.Sequence仅用于生成器或输入。使用基于进程的线程时,要启动的最大进程数。如果未指定,workers 则默认为1。如果为0,将在主线程上执行生成器。
use_multiprocessing 布尔值。keras.utils.Sequence仅用于生成器或 输入。如果为True,请使用基于进程的线程。如果未指定,则默认为 。请注意,由于此实现依赖于多处理,因此不应将不可拾取的参数传递给生成器,因为它们无法轻易传递给子进程。 use_multiprocessingFalse

类似于迭代器的输入的拆包行为:一种常见的模式是将tf.data.Dataset,generator或tf.keras.utils.Sequence传递给fitx参数,这实际上不仅会产生特征(x),而且会产生可选结果目标(y)和样本权重。Keras要求此类类似迭代器的输出必须明确。迭代器应返回长度为1、2或3的元组,其中可选的第二和第三元素将分别用于y和sample_weight。提供的任何其他类型将被包裹在一个元组的长度中,从而将所有内容有效地视为“ x”。发出命令时,它们仍应遵循*元组结构。例如({“x0”: x0, “x1”: x1}, y)。Keras不会尝试从单个字典的键中分离特征,目标和权重。值得注意的不受支持的数据类型是namedtuple。原因是它的行为类似于有序数据类型(元组)和映射数据类型(dict)。因此,给定形式的namedtuple: namedtuple(“example_tuple”, [“y”, “x”]) 在解释值时是否反转元素的顺序是不明确的。更糟糕的是以下形式的元组: namedtuple(“other_tuple”, [“x”, “y”, “z”]) 尚不清楚该元组是否打算解包为x,y和sample_weight或作为单个元素传递给x。结果,如果数据处理代码遇到一个命名元组,它将仅引发ValueError。(以及纠正该问题的说明。)

函数返回:
历史记录对象。它的History.history属性记录了连续时期的训练损失值和度量值,以及验证损失值和验证度量值(如果适用)。

fit_generator函数

fit_generator(
    generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None,
    validation_data=None, validation_steps=None, validation_freq=1,
    class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
    shuffle=True, initial_epoch=0
)

使模型适合Python生成器逐批生成的数据。

在使用tensorflow2.0.0版本的时候的,去运行了tensorflow2.2.0版本的yoloV4代码,导致报错出现:
tensorflow2.0在训练数据集的时候,fit和fit_generator的使用
tensorflow2.0在训练数据集的时候,fit和fit_generator的使用

TypeError: int() argument must be a string, a bytes-like object or a number, not ‘tuple’

原来是因为在2.0.0版本的时候Model.fit不支持生成器创建的数据集,因此会出现错误。
把Model.fit函数换成Model.fit_generator函数,则成功解决这个问题。

如果帮助到您,点赞论评走一波!!!!!
谢谢各位大佬!!!!

上一篇:OpenCV-Utils学习日志:GUI模块要点总结


下一篇:解决django.db.utils.OperationalError: (1045, “Access denied for user ‘Administrator‘@‘localhost‘ (us