由于TensorFlow中tensor数据类型的特殊性,对它的处理往往是一件比较头疼的事情。有些情况需要将其转换为numpy array进行计算,这时有一个很有效的函数py_func,这里举一个使用例子,函数本身的用法可见参考资料。

augment = lambda x: dict(image=augment_shift(augment_mirror(x['image']), 4), label=x['label'])

augment是一个数据增强函数,augment_mirror和augment_shift都是使用TensorFlow函数实现的数据增强操作,此函数是用lambda表示的,为了便于后续修改,转换成常规的函数表示。

def augment(x):
    return dict(image=augment_shift(augment_mirror(x['image']), 4), label=x['label'])

引入tf.py_func函数,输入三个参数,第一个_aug是一个处理函数,处理numpy array类型数据。第二个参数是Tensor数据列表,这里只有一个Tensor数据,但也要以列表形式作为输入,第三个参数是经过_aug输出再转换成Tensor数据的类型。augment_shift_np和augment_mirror_np分别是augment_shift和augment_mirror的函数的numpy实现。_aug函数输入的image即为py_func第二个输入参数列表中的x[‘image’],前者为numpy array类型,后者为Tensor类型。

def augment(x):
    def _aug(image):
        return augment_shift_np(augment_mirror_np(image), 4)
    
    aug_image = tf.py_func(_aug, [x['image']], tf.float32)
    return dict(image=aug_image, label=x['label'])

参考:

tf.py_func官方文档

经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性