将tf.dataset转换为PyTorch Dataset?

5
我正在处理一个项目,其中所有的数据都经过预处理,并以 TensorFlow 数据集的形式准备好,看起来像这样:

<MapDataset shapes: {input_ids: (128,), input_mask: (128,), label_ids: (), segment_ids: (128,)}, types: {input_ids: tf.int64, input_mask: tf.int64, label_ids: tf.int64, segment_ids: tf.int64}>

我手头上的脚本是用 PyTorch 编写的,它接受一个 Dataset 对象,该对象的格式如下:
Dataset({
    features: [
        'attention_mask', 
        'input_ids', 
        'label', 
        'sentence', 
        'token_type_ids'
    ],
    num_rows: 12
})

有没有办法将一个转换为另一个?我对这两个API都很陌生,不太确定它们的工作原理。我能否潜在地将一个转换为另一个?
2个回答

1

我使用tfds.as_numpy(dataset)作为我的模型训练的数据加载器。在我的模型的前向函数中,我使用torch.as_tensor(data, device=<device>)将传递给模型的数据转换为张量。

import tensorflow_datasets as tfds
import torch.nn as nn

def train_dataloader(batch_size):
    return tfds.as_numpy(tfds.load('mnist').batch(batch_size))

class Model(nn.Module):
    def forward(self, x):
        x = torch.as_tensor(x, device='cuda')
        ...

0

更新

Keras-Core - Keras Core是Keras代码库的完全重写,它基于模块化后端架构进行了重新定位。它使得可以在任意框架上运行Keras工作流程,包括TensorFlow、JAX和PyTorch。

... 您可以在Keras Core + TensorFlow模型上使用PyTorch DataLoader进行训练,或者在Keras Core + PyTorch模型上使用tf.data.Dataset进行训练。


这是一个晚回答,但可能对未来的读者有所帮助。虽然这不是一个很好的解决方案,但目前正在进行一些讨论,可能会有所帮助,例如链接1链接2
为了演示,我们首先将构建一个tf.data加载器,并将其转换为torch.utils.data加载器。
构建tf.data
BATCH_SIZE = 64
(x_train, y_train), _ = keras.datasets.cifar10.load_data()

def normalize(image, label, denorm=False):
    rescale = keras.layers.Rescaling(scale=1./255.)
    norms = keras.layers.Normalization(
        mean=[0.4914, 0.4822, 0.4465], 
        variance=[np.square(0.2023), np.square(0.1994), np.square(0.2010)], 
        invert=denorm,
        axis=-1,
    )
    if not denorm:
        image = rescale(image)
    return norms(image), label

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.map(normalize)
train_ds = train_ds.shuffle(buffer_size=8*BATCH_SIZE)
train_ds = train_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

x, y = next(iter(train_ds))
x.shape, y.shape
(TensorShape([64, 32, 32, 3]), TensorShape([64, 1]))

构建torch.utils.data
# Define a custom PyTorch dataset class
class TFDatasetWrapper(Dataset):
    def __init__(self, tf_dataset):
        self.tf_dataset = tf_dataset

    def __len__(self):
        return len(x_train)

    def __getitem__(self, idx):
        return next(iter(self.tf_dataset.skip(idx).take(1)))

def tf_collate_fn(batch):
    x, y = zip(*batch)
    x = torch.stack(x).permute(0, 3, 1, 2).type(torch.FloatTensor)
    y = torch.stack(y)
    return x, y

def iter_tf_data(train_ds):
    x_list = []
    y_list = []
    for data in train_ds.as_numpy_iterator():
        x, y = data
        x_list += [torch.from_numpy(x)] 
        y_list += [torch.from_numpy(y)]
    x_list_cat = torch.cat(x_list, axis=0)
    y_list_cat = torch.cat(y_list, axis=0)
    return [x_list_cat, y_list_cat]

def tf_dataset_to_pytorch_dataloader(
    tf_dataset, batch_size, shuffle=True, num_workers=0
):
    """Converts a TensorFlow Dataset to a PyTorch DataLoader."""
    data_list = iter_tf_data(tf_dataset)
    pytorch_dataset = TensorDataset(*data_list)
    pytorch_dataloader = DataLoader(
        pytorch_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=tf_collate_fn
    )
    return pytorch_dataloader

train_ds_torch = tf_dataset_to_pytorch_dataloader(
    train_ds, batch_size=BATCH_SIZE // 2, shuffle=True
)
x, y = next(iter(train_ds_torch))
x.shape, y.shape
(torch.Size([32, 3, 32, 32]), torch.Size([32, 1]))

最后,让我们从torch加载器中可视化一些示例。
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

fig, ax = plt.subplots(figsize=(12, 6))
plt.title("CIFAR10 dataset")
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(make_grid(x, nrow=8).permute(1, 2, 0))
plt.show()

download


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接