[TensorFlow] 关于tf.keras.layers.Embedding中参数input_length的作用

本文基于SO的帖子:Link: https://*.com/questions/61848825/why-is-input-length-needed-in-layers-embedding-in-keras-tensorflow

在翻文档的时候,发现了input_length这个参数,不知道有什么用。文档里的注释是:

input_length : Length of input sequences, when it is constant. This argument is required if you are going to connect Flatten then Dense layers upstream (without it, the shape of the dense outputs cannot be computed).

SO原提问是说,为什么我们在连接Flatten and Dense的时候不指定input_length时就会计算不出shape呢?

答案大致如下:

input_length这个参数是用来指定输入的长度,相当于在Embedding层之前加了一个Input层。 我们在连接Flatten以及Dense之前,我们需要这么一个input的长度,是因为Dense这一步是一个全连接,因此所有的长度必须指定,不然的话Dense与Flatten之间的W矩阵形状无法确定。

概念dim & length

列表: dim = 1, 矩阵: dim = 2。 length可以理解为batch的大小,即有多少个样本。

上一篇:CSPNET: A NEW BACKBONE THAT CAN ENHANCE LEARNING CAPABILITY OF CNN


下一篇:网络容器