Toccata in Nowhere.

TensorFlow Tensor张量堆叠 stack / 拼接 concat / 拆分 unstack

2020.07.18

在进行多目标联合训练时,除了共享一部分网络结构的输出外,也可以使用拆分张量的方法实现网络中张量流的拆分。相反,在联合训练时,通过合并不同层而共享同一个 loss function也可以为网络结构的搭建带来便利,以下介绍网络中张量拼接 / 拆分操作的几种方法:

按维度拼接 concat

tf.concat([a, b], axis=0)

顾名思义,类似于 np.concatenate,将[a,b]两个张量由 axis 方向合并。

Tips: 在进行tensor拼接时,因为数据来源的不同,可能出现因为数据类型不同无法拼接的情况,可以使用:

tensor = tf.convert_to_tensor(data, dtype=tf.float32)

对数据类型进行指定转换,之后再进行拼接操作。

此外,对于不同的图片训练集或其他同类训练集,拼接往往发生在第 $0$ 个维度,即tf.concat([a, b], axis=0)

堆叠 stack

tf.stack([a, b], axis=0)

tf.concat 的指定维度扩增不同,tf.stack会在 axis 位置新增一个维度,从而实现两个张量的拼接。

拆分 unstack

tf.unstack(a, axis=0)

tf.stack 相反的,tf.unstack 则是将张量 a 按照 axis 方向拆分,返回值大多为多个张量,注意接收。