如何将 numpy 转换为 tfrecords 然后生成批次?

2025-03-21 09:06:00
admin
原创
68
摘要:问题描述:我的问题是如何从多个(或分片)tfrecords 获取批量输入。我已阅读示例https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L410。基本流程是,以训练集为例,(1) 首...

问题描述:

我的问题是如何从多个(或分片)tfrecords 获取批量输入。我已阅读示例https://github.com/tensorflow/models/blob/master/inception/inception/image_processing.py#L410。基本流程是,以训练集为例,(1) 首先生成一系列 tfrecords(例如,,,train-000-of-005... train-001-of-005),(2) 从这些文件名生成一个列表并将它们输入到tf.train.string_input_producer队列中,(3) 同时生成一个tf.RandomShuffleQueue来做其他事情,(4) 用来tf.train.batch_join生成批量输入。

我认为这很复杂,我不确定这个过程的逻辑。就我而言,我有一个.npy文件列表,我想生成分片的 tfrecords(多个独立的 tfrecords,而不仅仅是一个大文件)。每个文件都.npy包含不同数量的正样本和负样本(2 个类)。一个基本方法是生成一个单独的大型 tfrecord 文件。但文件太大(~20Gb)。所以我求助于分片的 tfrecords。有没有更简单的方法可以做到这一点?


解决方案 1:

整个过程使用 来简化Dataset API。以下是两个部分:(1): Convert numpy array to tfrecords(2): read the tfrecords to generate batches

1.从 numpy 数组创建 tfrecords:

Example arrays:
inputs = np.random.normal(size=(5, 32, 32, 3))
labels = np.random.randint(0,2,size=(5,))

def npy_to_tfrecords(inputs, labels, filename):
  with tf.io.TFRecordWriter(filename) as writer:
    for X, y in zip(inputs, labels):
        # Feature contains a map of string to feature proto objects
        feature = {}
        feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=X.flatten()))
        feature['y'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))

        # Construct the Example proto object
        example = tf.train.Example(features=tf.train.Features(feature=feature))

        # Serialize the example to a string
        serialized = example.SerializeToString()

        # write the serialized objec to the disk
        writer.write(serialized)


npy_to_tfrecords(inputs, labels, 'numpy.tfrecord')

2.使用 Dataset API 读取 tfrecords:

filenames = ['numpy.tfrecord']
dataset = tf.data.TFRecordDataset(filenames)
# for version 1.5 and above use tf.data.TFRecordDataset

# example proto decode
def _parse_function(example_proto):
    keys_to_features = {'X':tf.io.FixedLenFeature(shape=(32, 32, 3), dtype=tf.float32),
                      'y': tf.io.FixedLenFeature((), tf.int64, default_value=0)}
    parsed_features = tf.io.parse_single_example(example_proto, keys_to_features)
    return parsed_features['X'], parsed_features['y']

# Parse the record into tensors.
dataset = dataset.map(_parse_function)  
  
# Generate batches
dataset = dataset.batch(5)

检查生成的批次是否正确:

for data in dataset:
    break
np.testing.assert_allclose(inputs[0] ,data[0][0])
np.testing.assert_allclose(labels[0] ,data[1][0])
相关推荐
  政府信创国产化的10大政策解读一、信创国产化的背景与意义信创国产化,即信息技术应用创新国产化,是当前中国信息技术领域的一个重要发展方向。其核心在于通过自主研发和创新,实现信息技术应用的自主可控,减少对外部技术的依赖,并规避潜在的技术制裁和风险。随着全球信息技术竞争的加剧,以及某些国家对中国在科技领域的打压,信创国产化显...
工程项目管理   2482  
  为什么项目管理通常仍然耗时且低效?您是否还在反复更新电子表格、淹没在便利贴中并参加每周更新会议?这确实是耗费时间和精力。借助软件工具的帮助,您可以一目了然地全面了解您的项目。如今,国内外有足够多优秀的项目管理软件可以帮助您掌控每个项目。什么是项目管理软件?项目管理软件是广泛行业用于项目规划、资源分配和调度的软件。它使项...
项目管理软件   1533  
  PLM(产品生命周期管理)项目对于企业优化产品研发流程、提升产品质量以及增强市场竞争力具有至关重要的意义。然而,在项目推进过程中,范围蔓延是一个常见且棘手的问题,它可能导致项目进度延迟、成本超支以及质量下降等一系列不良后果。因此,有效避免PLM项目范围蔓延成为项目成功的关键因素之一。以下将详细阐述三大管控策略,助力企业...
plm系统   0  
  PLM(产品生命周期管理)项目管理在企业产品研发与管理过程中扮演着至关重要的角色。随着市场竞争的加剧和产品复杂度的提升,PLM项目面临着诸多风险。准确量化风险优先级并采取有效措施应对,是确保项目成功的关键。五维评估矩阵作为一种有效的风险评估工具,能帮助项目管理者全面、系统地评估风险,为决策提供有力支持。五维评估矩阵概述...
免费plm软件   0  
  引言PLM(产品生命周期管理)开发流程对于企业产品的全生命周期管控至关重要。它涵盖了从产品概念设计到退役的各个阶段,直接影响着产品质量、开发周期以及企业的市场竞争力。在当今快速发展的科技环境下,客户对产品质量的要求日益提高,市场竞争也愈发激烈,这就使得优化PLM开发流程成为企业的必然选择。缺陷管理工具和六西格玛方法作为...
plm产品全生命周期管理   0  
热门文章
项目管理软件有哪些?
曾咪二维码

扫码咨询,免费领取项目管理大礼包!

云禅道AD
禅道项目管理软件

云端的项目管理软件

尊享禅道项目软件收费版功能

无需维护,随时随地协同办公

内置subversion和git源码管理

每天备份,随时转为私有部署

免费试用