在进行多目标联合训练时,除了共享一部分网络结构的输出外,也可以使用拆分张量的方法实现网络中张量流的拆分。相反,在联合训练时,通过合并不同层而共享同一个 loss function也可以为网络结构的搭建带来便利,以下介绍网络中张量拼接 / 拆分操作的几种方法:
按维度拼接 concat
tf.concat([a, b], axis=0)
顾名思义,类似于 np.concatenate
,将[a,b]
两个张量由 axis
方向合并。
Tips: 在进行tensor拼接时,因为数据来源的不同,可能出现因为数据类型不同无法拼接的情况,可以使用:
tensor = tf.convert_to_tensor(data, dtype=tf.float32)
对数据类型进行指定转换,之后再进行拼接操作。
此外,对于不同的图片训练集或其他同类训练集,拼接往往发生在第 $0$ 个维度,即tf.concat([a, b], axis=0)
。
堆叠 stack
tf.stack([a, b], axis=0)
与 tf.concat
的指定维度扩增不同,tf.stack
会在 axis
位置新增一个维度,从而实现两个张量的拼接。
拆分 unstack
tf.unstack(a, axis=0)
与 tf.stack
相反的,tf.unstack
则是将张量 a
按照 axis
方向拆分,返回值大多为多个张量,注意接收。