Disable warning when using with_format format on tensors #7088

Open
Haislich opened this issue Aug 5, 2024 · 0 comments
Labels
enhancement New feature or request

Haislich commented Aug 5, 2024

Feature request

If we write this code:

"""Get data and define datasets."""

from enum import StrEnum
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms


class Split(StrEnum):
    """Describes what type of split to use in the dataloader"""

    TRAIN = "train"
    TEST = "test"
    VAL = "validation"


class ImageNetDataLoader(DataLoader):
    """Create an ImageNetDataloader"""

    _preprocess_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
        ]
    )

    def __init__(self, batch_size: int = 4, split: Split = Split.TRAIN):
        dataset = (
            load_dataset(
                "imagenet-1k",
                split=split,
                trust_remote_code=True,
                streaming=True,
            )
            .with_format("torch")
            .map(self._preprocess)
        )

        super().__init__(dataset=dataset, batch_size=batch_size)

    def _preprocess(self, data):
        if data["image"].shape[0] < 3:
            data["image"] = data["image"].repeat(3, 1, 1)
        data["image"] = self._preprocess_transform(data["image"].float())
        return data


if __name__ == "__main__":

    dataloader = ImageNetDataLoader(batch_size=2)
    for batch in dataloader:
        print(batch["image"])
        break

This triggers a user warning:

datasets\formatting\torch_formatter.py:85: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})
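Until the formatter changes, one possible stopgap is to filter this specific message with the standard `warnings` module. This is only a sketch: the helper name `silence_copy_construct_warning` is hypothetical, and the pattern assumes the message keeps starting with the text shown above.

```python
import warnings


def silence_copy_construct_warning():
    """Ignore only the copy-construct UserWarning (hypothetical helper).

    The message argument is a regex matched against the start of the
    warning text, so other UserWarnings are still reported.
    """
    warnings.filterwarnings(
        "ignore",
        message="To copy construct from a tensor",
        category=UserWarning,
    )


# Call once, before iterating over the formatted dataset.
silence_copy_construct_warning()
```

This hides the symptom only; the formatter still copies the tensor through `torch.tensor`.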

Motivation

This happens because of the way the formatted tensor is returned in TorchFormatter._tensorize.
This function handles values of different types; from some tests, the possible value types appear to be int, numpy.ndarray and torch.Tensor.
In particular, the warning is triggered when the value is already a torch.Tensor, because passing an existing tensor to torch.tensor is not the way PyTorch recommends copying it; the warning suggests sourceTensor.clone().detach() instead.
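The behaviour can be reproduced without `datasets` at all. The following minimal sketch assumes only that PyTorch is installed; the variable names are illustrative. It compares copy-constructing with `torch.tensor` against the `clone().detach()` form that the warning recommends.

```python
import warnings

import torch

source = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    copied = torch.tensor(source)     # copy-constructs from a tensor: warns
    cloned = source.clone().detach()  # recommended form: no warning

# Keep only the warning discussed in this issue.
copy_warnings = [w for w in caught if "copy construct" in str(w.message)]
print(len(copy_warnings))
print(torch.equal(copied, cloned))
```

Both calls produce an independent copy with the same values; only the first one emits the warning.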

Your contribution

A solution that I found to work is to change the current return statement:

return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})

To:

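# Copy an existing tensor the way PyTorch recommends, which avoids the warning.
if isinstance(value, torch.Tensor):
    tensor = value.clone().detach()
    if self.torch_tensor_kwargs.get("requires_grad", False):
        tensor.requires_grad_()
    return tensor
# Other inputs (e.g. int, numpy.ndarray) keep the existing code path.
return torch.tensor(value, **{**default_dtype, **self.torch_tensor_kwargs})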
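One thing the proposed branch does not do is apply the other entries of `torch_tensor_kwargs` (for example `dtype` or `device`) to tensor inputs. The standalone sketch below shows one way that could be handled. It is not the library's code: `tensorize_value` is a hypothetical stand-in for the final return of `TorchFormatter._tensorize`, and forwarding only `dtype` and `device` is an assumption made to keep the example small.

```python
import warnings

import numpy as np
import torch


def tensorize_value(value, default_dtype=None, torch_tensor_kwargs=None):
    """Hypothetical stand-in for the final return of TorchFormatter._tensorize."""
    kwargs = {**(default_dtype or {}), **(torch_tensor_kwargs or {})}
    if isinstance(value, torch.Tensor):
        requires_grad = kwargs.pop("requires_grad", False)
        # Assumption: only dtype and device are forwarded for tensor inputs.
        to_kwargs = {k: v for k, v in kwargs.items() if k in ("dtype", "device")}
        tensor = value.clone().detach()
        if to_kwargs:
            tensor = tensor.to(**to_kwargs)
        if requires_grad:
            tensor.requires_grad_()
        return tensor
    # Non-tensor inputs (int, numpy.ndarray, ...) keep the original path.
    return torch.tensor(value, **kwargs)


source = torch.ones(2, 3)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = tensorize_value(source, torch_tensor_kwargs={"dtype": torch.float64})
print(result.dtype, len(caught))
```

The tensor branch never calls `torch.tensor` on a tensor, so the copy-construct warning is not raised there, while integers and NumPy arrays are converted exactly as before.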

Haislich added the enhancement label on Aug 5, 2024