"""Get data and define datasets."""fromenumimportStrEnumfromdatasetsimportload_datasetfromtorch.utils.dataimportDataLoaderfromtorchvisionimporttransformsclassSplit(StrEnum):
"""Describes what type of split to use in the dataloader"""TRAIN="train"TEST="test"VAL="validation"classImageNetDataLoader(DataLoader):
"""Create an ImageNetDataloader"""_preprocess_transform=transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
]
)
def__init__(self, batch_size: int=4, split: Split=Split.TRAIN):
dataset= (
load_dataset(
"imagenet-1k",
split=split,
trust_remote_code=True,
streaming=True,
)
.with_format("torch")
.map(self._preprocess)
)
super().__init__(dataset=dataset, batch_size=batch_size)
def_preprocess(self, data):
ifdata["image"].shape[0] <3:
data["image"] =data["image"].repeat(3, 1, 1)
data["image"] =self._preprocess_transform(data["image"].float())
returndataif__name__=="__main__":
dataloader=ImageNetDataLoader(batch_size=2)
forbatchindataloader:
print(batch["image"])
break
This will trigger a user warning:
datasets\formatting\torch_formatter.py:85: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
Motivation
This happens because of the way the formatted tensor is returned in TorchFormatter._tensorize.
This function handles values of different types; according to some tests, the possible value types seem to be int, numpy.ndarray and torch.Tensor.
In particular, this warning is triggered when the value type is torch.Tensor, because calling torch.tensor on an existing tensor is not the way PyTorch suggests copying it (the warning recommends sourceTensor.clone().detach() instead).
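To see the behaviour in isolation, without datasets involved, here is a minimal sketch that copies a tensor both ways and records any warnings raised. The helper name `collect_warnings` is made up for this example; only `torch.tensor`, `clone()` and `detach()` come from the warning text above.

```python
import warnings

import torch


def collect_warnings(fn):
    """Run fn() and return (result, list of warning messages it raised)."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, even ones that were already shown once.
        warnings.simplefilter("always")
        result = fn()
    return result, [str(w.message) for w in caught]


source = torch.ones(3)

# Copy-constructing from an existing tensor: the pattern the warning is about.
copied, msgs_tensor = collect_warnings(lambda: torch.tensor(source))

# The copy pattern recommended in the warning text.
cloned, msgs_clone = collect_warnings(lambda: source.clone().detach())

print("torch.tensor(source):", msgs_tensor)
print("source.clone().detach():", msgs_clone)
```

Both calls should produce a tensor with the same values; the difference is only in which one emits the UserWarning.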
Your contribution
A solution that I found to be working is to change the current way of doing it:
To:
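As a rough sketch of the kind of change described here (not necessarily the exact patch), the idea is to branch on the value type and use clone().detach() for tensors, as the warning recommends. The function `tensorize_value` below is a hypothetical standalone stand-in for TorchFormatter._tensorize, and its single `dtype` parameter is a simplification of the keyword arguments the real method forwards.

```python
import torch


def tensorize_value(value, dtype=None):
    """Hypothetical stand-in for TorchFormatter._tensorize (illustration only)."""
    if isinstance(value, torch.Tensor):
        # Already a tensor: copy it the way the warning recommends.
        out = value.clone().detach()
        return out if dtype is None else out.to(dtype)
    # int, list, numpy.ndarray, ...: torch.tensor is fine here.
    return torch.tensor(value, dtype=dtype)


existing = torch.tensor([1, 2, 3])
print(tensorize_value(existing))                     # copied without the warning
print(tensorize_value(5))                            # plain int
print(tensorize_value([1.0, 2.0], dtype=torch.float64))
```

The tensor branch returns an independent copy, so the result does not share memory with the input, which matches what torch.tensor would have produced.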