Custom Dataset class on Chapter_1 #6

Open
LucianoBatista opened this issue Apr 20, 2022 · 9 comments

@LucianoBatista

LucianoBatista commented Apr 20, 2022

Hi, I was getting the following error when executing this code:

import torch
from torch.utils.data import Dataset
from sklearn.datasets import fetch_openml

X, y = fetch_openml("mnist_784", version=1, return_X_y=True)

class SimpleDataset(Dataset):
    def __init__(self, X, y):
        super(SimpleDataset, self).__init__()
        self.X = X
        self.y = y
    
    def __getitem__(self, index):
        inputs = torch.tensor(self.X[index, :], dtype=torch.float32)
        targets = torch.tensor(int(self.y[index]), dtype=torch.int64)
        return inputs, targets

    def __len__(self):
        return self.X.shape[0]

dataset = SimpleDataset(X, y)
example, label = dataset[0]

InvalidIndexError: (tensor(0), slice(None, None, None))

This was fixed when I changed the fetch_openml call to:

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

The problem is that without as_frame=False, scikit-learn now returns the data as a pandas DataFrame rather than a NumPy array.
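
A minimal sketch of why the indexing fails, using a tiny synthetic DataFrame as a stand-in for the MNIST data (the column names and values are made up for illustration). The exact exception type may differ between pandas versions, so the sketch only records that one is raised:

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the MNIST feature table
# (assumption: 3 rows and 4 "pixel" columns, not the real data).
X_frame = pd.DataFrame(
    np.arange(12, dtype=np.float64).reshape(3, 4),
    columns=["pixel1", "pixel2", "pixel3", "pixel4"],
)

# NumPy-style indexing on a DataFrame is treated as a column lookup, so it fails.
try:
    X_frame[0, :]
    frame_error = None
except Exception as exc:  # e.g. InvalidIndexError; the type depends on the pandas version
    frame_error = type(exc).__name__

# Converting to a NumPy array restores the row slicing that __getitem__ relies on.
X_array = X_frame.to_numpy()
first_row = X_array[0, :]
```

With the array form, X_array[0, :] selects the first row, which is what the Dataset class expects.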

@EdwardRaff
Owner

Do you know if this was a recent change in scikit-learn?

@LucianoBatista
Author

LucianoBatista commented Apr 20, 2022

I don't know. I'm currently using the latest version (1.0.2) while working through your book:
https://pypi.org/project/scikit-learn/

@EdwardRaff
Owner

Changed in version 0.24: The default value of as_frame changed from False to 'auto' in 0.24.

Yup, they changed it. I'll fix this soon - thank you for the catch!

@ChowZH

ChowZH commented Jul 23, 2022

Changing the __getitem__ function by adding .loc after the DataFrame seems to solve the issue. I'm still working my way past this point, but at least I now get the same output as the book.

def __getitem__(self, index):
    inputs = torch.tensor(self.X.loc[index, :], dtype=torch.float32)
    targets = torch.tensor(int(self.y.loc[index]), dtype=torch.int64)
    return inputs, targets
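
One thing to keep in mind with this approach: .loc looks rows up by index label, while .iloc looks them up by position. A small sketch with a made-up one-column DataFrame (not the MNIST data) shows that the two agree on a default index but can diverge once rows have been filtered or shuffled:

```python
import pandas as pd

# Default RangeIndex: labels 0, 1, 2 coincide with positions 0, 1, 2.
default_index = pd.DataFrame({"pixel1": [1.0, 2.0, 3.0]})
same_row = default_index.loc[0].equals(default_index.iloc[0])

# After filtering, the remaining rows keep their old labels (1 and 2).
filtered = default_index[default_index["pixel1"] > 1.0]

# Positional lookup still returns the first remaining row.
first_by_position = filtered.iloc[0]["pixel1"]

# Label lookup fails because no row is labelled 0 any more.
try:
    filtered.loc[0]
    loc_error = None
except KeyError as exc:
    loc_error = type(exc).__name__
```

For data loaded straight from fetch_openml the index is the default one, so .loc behaves as expected here; .iloc may be the safer choice if the DataFrame is filtered or reordered first.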

@chuymtz

chuymtz commented Jul 27, 2022

I encountered this today and left a comment in the Manning liveBook. I think this solution is better than what I suggested there.

If I change it to self.X = X.to_numpy(), does that cause problems down the road with being able to benefit from PyTorch? Great book, I'm really enjoying it.

@murphyk

murphyk commented Jan 5, 2023

I used dataset = SimpleDataset(X.to_numpy(), y.to_numpy()) and it solved the problem.
This could go into the __init__ function as @chuymtz suggested, but
if we want to make a dataset from a torch tensor, we would have to use
dataset = SimpleDataset(X.numpy(), y.numpy()) instead (not to_numpy - argh).


XX = torch.tensor(X.to_numpy())
yy = torch.tensor(y.to_numpy(dtype=int), dtype=torch.int64)
dataset2 = SimpleDataset(XX.numpy(), yy.numpy())
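
A possible way around the to_numpy / numpy naming mismatch is a small conversion helper called once in __init__. This is only a sketch: as_numpy is a hypothetical name, not part of the book's code, and the tensor branch assumes a CPU tensor that does not require gradients.

```python
import numpy as np
import pandas as pd


def as_numpy(data):
    """Hypothetical helper: return a NumPy array for DataFrame, Series,
    tensor-like, or array-like input."""
    if isinstance(data, (pd.DataFrame, pd.Series)):
        return data.to_numpy()
    if hasattr(data, "numpy"):
        # e.g. a torch.Tensor (assumption: on the CPU, no gradients attached)
        return data.numpy()
    # ndarrays pass through unchanged; lists and tuples are converted
    return np.asarray(data)
```

Inside the class this would read self.X = as_numpy(X) and self.y = as_numpy(y), so the caller can pass any of these types without converting first.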

@yebangyu

yebangyu commented May 27, 2023

murphyk's solution seems a little tedious to me, since it converts X to a tensor and then back to NumPy again.

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

This solution is very good. Thanks, LucianoBatista.

@yebangyu

yebangyu commented May 29, 2023

Another solution that seems good, just FYI:

X, y = fetch_openml("mnist_784", version=1, return_X_y=True)  # no need to change here

class SimpleDataset(Dataset):
    def __init__(self, X, y):
        super(SimpleDataset, self).__init__()
        self.X = X.values  # get the NumPy data via .values
        self.y = y.values  # get the NumPy data via .values

    # __getitem__ and __len__ stay the same as in the original class

dataset = SimpleDataset(X, y)
example, label = dataset[0]

@benjwolff

@LucianoBatista: Thanks for your solution! I ran into the same problem. I created a pull request.

7 participants