**更新:** Keras Core 是对 Keras 代码库的完全重写,它基于模块化的后端架构重新构建,使 Keras 工作流可以运行在任意框架之上,包括 TensorFlow、JAX 和 PyTorch(Keras Core 后来已作为 Keras 3 正式发布)。
……例如,您可以用 PyTorch `DataLoader` 来训练 Keras Core + TensorFlow 模型,或者用 `tf.data.Dataset` 来训练 Keras Core + PyTorch 模型。
这是一个迟来的回答,但可能对以后的读者有所帮助。虽然下面的方法并不是一个理想的解决方案,但目前社区中正在进行一些相关讨论,可能会有所帮助,例如链接1、链接2。
为了演示,我们首先构建一个 `tf.data` 加载器,然后将其转换为 `torch.utils.data` 加载器。
**构建 `tf.data` 加载器**
# Number of examples per batch in the tf.data pipeline.
BATCH_SIZE = 64

# Keep only the CIFAR-10 training split; the test split is not used here.
x_train, y_train = keras.datasets.cifar10.load_data()[0]
def normalize(image, label, denorm=False):
    """Standardize (or un-standardize) an image with fixed per-channel statistics.

    Forward direction (``denorm=False``): multiply pixel values by 1/255, then
    subtract the per-channel mean and divide by the per-channel standard
    deviation (Keras ``Normalization`` divides by ``sqrt(variance)``).

    Inverse direction (``denorm=True``): only the mean/std step is inverted;
    the 1/255 rescale is NOT undone, so the result is on the 0-1 scale rather
    than the original 0-255 scale.

    Both Keras layers are constructed on every call. When used through
    ``Dataset.map`` the function is presumably traced once, so this cost is
    paid at trace time — confirm if calling eagerly in a loop.

    Args:
        image: image tensor/array whose last axis holds the 3 colour channels.
        label: returned unchanged, so the function can be mapped directly over
            ``(image, label)`` pairs.
        denorm: when True, apply the inverse of the standardization step.

    Returns:
        ``(image, label)`` with the transformed image.
    """
    # NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied CIFAR-10 std;
    # other sources quote roughly (0.2470, 0.2435, 0.2616) — confirm intent.
    channel_std = (0.2023, 0.1994, 0.2010)
    to_unit_range = keras.layers.Rescaling(scale=1. / 255.)
    standardize = keras.layers.Normalization(
        mean=[0.4914, 0.4822, 0.4465],
        variance=[np.square(s) for s in channel_std],
        invert=denorm,
        axis=-1,
    )
    scaled = image if denorm else to_unit_range(image)
    return standardize(scaled), label
# Input pipeline: shuffle single examples, batch them, then normalize whole
# batches at once ("vectorized mapping"). `normalize` only applies an
# elementwise Rescaling and a last-axis (channel) Normalization, so running it
# on a (batch, H, W, C) tensor gives the same values as running it per example,
# with far fewer function invocations. Shuffling before the map also means the
# shuffle buffer holds the raw uint8 images instead of float copies.
# `num_parallel_calls` lets tf.data overlap the map work; element order stays
# deterministic by default.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=8 * BATCH_SIZE)
    .batch(BATCH_SIZE)
    .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
x, y = next(iter(train_ds))
x.shape, y.shape
(TensorShape([64, 32, 32, 3]), TensorShape([64, 1]))
**构建 `torch.utils.data` 加载器**
class TFDatasetWrapper(Dataset):
    """Expose a ``tf.data.Dataset`` through PyTorch's map-style ``Dataset`` API.

    Each item is one element of the wrapped dataset. For a batched dataset
    that is a whole batch, not a single example.

    Performance caveat: every ``__getitem__`` builds a fresh iterator and
    skips ``idx`` elements. Random access is therefore typically O(idx), and
    a full pass is quadratic in the dataset length. If the wrapped dataset
    reshuffles on each iteration, repeated reads of the same index will
    generally return different elements. For training, converting once with
    ``tf_dataset_to_pytorch_dataloader`` is usually preferable.
    """

    def __init__(self, tf_dataset):
        self.tf_dataset = tf_dataset

    def __len__(self):
        """Return the number of elements in the wrapped dataset.

        Previously this returned the length of an unrelated global array. It
        now reports the wrapped dataset's own cardinality.

        Raises:
            TypeError: if tf.data reports an infinite or unknown cardinality
                (negative sentinel values), since ``len()`` must be a
                non-negative int.
        """
        n = int(self.tf_dataset.cardinality())
        if n < 0:
            raise TypeError(
                "TFDatasetWrapper: the wrapped tf.data.Dataset has infinite "
                "or unknown cardinality, so it has no len()"
            )
        return n

    def __getitem__(self, idx):
        """Return element ``idx`` of the wrapped dataset.

        Raises:
            IndexError: if ``idx`` is negative or past the end of the dataset.
                Letting ``StopIteration`` escape from ``__getitem__`` would
                violate the sequence protocol.
        """
        if idx < 0:
            raise IndexError(f"TFDatasetWrapper index out of range: {idx}")
        try:
            return next(iter(self.tf_dataset.skip(idx).take(1)))
        except StopIteration:
            raise IndexError(f"TFDatasetWrapper index out of range: {idx}") from None
def tf_collate_fn(batch):
    """Collate ``(image, label)`` samples into a channels-first float batch.

    Args:
        batch: sequence of ``(x, y)`` tensor pairs, e.g. as produced by a
            ``TensorDataset``. Each ``x`` is assumed to be a channels-last
            ``(H, W, C)`` image, since the stacked tensor is permuted as 4-D.

    Returns:
        ``(x, y)``: ``x`` is stacked, permuted to ``(N, C, H, W)`` and cast
        with ``.type(torch.FloatTensor)``; ``y`` is stacked unchanged.
    """
    images, labels = zip(*batch)
    image_batch = torch.stack(images)
    # TF/Keras images are channels-last; PyTorch layers expect channels-first.
    image_batch = image_batch.permute(0, 3, 1, 2)
    image_batch = image_batch.type(torch.FloatTensor)
    label_batch = torch.stack(labels)
    return image_batch, label_batch
def iter_tf_data(train_ds):
    """Materialize an ``(x, y)`` tf.data pipeline as two concatenated torch tensors.

    Every element of ``train_ds`` is read once via ``as_numpy_iterator()``,
    converted with ``torch.from_numpy`` and concatenated along axis 0.

    Each element is therefore expected to be an ``(x, y)`` pair of *batched*
    arrays with a leading batch axis. Unbatched elements would be concatenated
    along their first data axis instead of stacked — TODO confirm callers
    always pass a batched dataset.

    Because the data is read exactly once, any randomness in the pipeline
    (e.g. ``shuffle``) is frozen into the returned tensors.

    Args:
        train_ds: dataset whose elements are ``(x, y)`` pairs.

    Returns:
        ``[x_all, y_all]``: a list of two tensors (not an iterator, despite
        the function name).

    Raises:
        ValueError: if the dataset yields no elements. Without this check,
            ``torch.cat`` would fail on an empty list with an opaque error.
    """
    x_parts = []
    y_parts = []
    for x, y in train_ds.as_numpy_iterator():
        x_parts.append(torch.from_numpy(x))
        y_parts.append(torch.from_numpy(y))
    if not x_parts:
        raise ValueError("iter_tf_data: the dataset yielded no elements")
    return [torch.cat(x_parts, dim=0), torch.cat(y_parts, dim=0)]
def tf_dataset_to_pytorch_dataloader(
    tf_dataset, batch_size, shuffle=True, num_workers=0
):
    """Convert a TensorFlow ``tf.data.Dataset`` into a PyTorch ``DataLoader``.

    The whole dataset is read once with ``iter_tf_data`` and held in memory
    as a ``TensorDataset``. The tf.data batching is thereby flattened, and the
    returned loader re-batches (and optionally re-shuffles) the tensors,
    collating each batch with ``tf_collate_fn``.

    Args:
        tf_dataset: dataset of ``(x, y)`` pairs accepted by ``iter_tf_data``.
        batch_size: batch size of the returned loader.
        shuffle: whether the loader reshuffles each epoch.
        num_workers: worker processes used by the loader.

    Returns:
        A ``torch.utils.data.DataLoader`` over the materialized tensors.
    """
    tensors = iter_tf_data(tf_dataset)
    return DataLoader(
        TensorDataset(*tensors),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=tf_collate_fn,
    )
# The DataLoader re-batches independently of the tf.data pipeline; here it
# uses half the tf.data batch size.
torch_batch_size = BATCH_SIZE // 2
train_ds_torch = tf_dataset_to_pytorch_dataloader(
    train_ds,
    batch_size=torch_batch_size,
    shuffle=True,
)
x, y = next(iter(train_ds_torch))
x.shape, y.shape
(torch.Size([32, 3, 32, 32]), torch.Size([32, 1]))
最后,让我们从 `torch` 加载器中取出一批样本进行可视化。
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# `x` holds images standardized with per-channel mean/std by `normalize`, so
# its values fall well outside [0, 1]. imshow clips float RGB data to [0, 1],
# which would wash the samples out. Undo the standardization first. These
# constants mirror the ones used in `normalize` — keep them in sync.
cifar10_mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
cifar10_std = torch.tensor([0.2023, 0.1994, 0.2010]).view(1, 3, 1, 1)
x_vis = (x * cifar10_std + cifar10_mean).clamp(0.0, 1.0)

fig, ax = plt.subplots(figsize=(12, 6))
ax.set_title("CIFAR10 dataset")
ax.set_xticks([])
ax.set_yticks([])
# make_grid returns (C, H, W); matplotlib expects (H, W, C).
ax.imshow(make_grid(x_vis, nrow=8).permute(1, 2, 0))
plt.show()
![download](https://github.com/innat/DIP-In-Python/assets/17668390/7e4a157d-c230-4a8a-b937-bdcf2062ac98)