Unable to use packed sequence for GridSearchCV #1083

@nafraw

Dear Dev or whoever can help,

I have sequences of different lengths, so I am thinking of using packed sequences to feed them to an RNN (or something similar). What I have figured out so far is that I need a Dataset plus a collate function to pack the sequences, and that I need to unpack them inside the RNN model so that fit() works. But when it comes to applying GridSearchCV (partitioning by groups of sequences), there seems to be no workaround. The code below is what I have for now.

Importing modules

import numpy as np
from imblearn.pipeline import Pipeline

from skorch import NeuralNetClassifier, NeuralNetRegressor, NeuralNet
from skorch.scoring import loss_scoring
from skorch.dataset import ValidSplit
from torch import nn
import torch
from torch.nn.utils.rnn import pack_sequence, unpack_sequence
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GridSearchCV, StratifiedGroupKFold

Defining some classes

class ScoredRegressorNet(NeuralNetRegressor):
    def score(self, X, y=None):
        return -loss_scoring(self, X, y)

# Define a custom dataset
class SequenceDataset(Dataset):
    # sequence and labels are lists
    def __init__(self, sequence, labels):
        self.sequence = sequence
        self.labels = labels

    def __len__(self):
        return len(self.sequence)

    def __getitem__(self, idx):
        return self.sequence[idx], self.labels[idx]

class MyRNN(nn.Module):
    def __init__(self, input_size, **rnn_kwargs):
        super(MyRNN, self).__init__()
        self.rnn = nn.RNN(input_size = input_size, **rnn_kwargs, batch_first=True)
        self.hidden_state = None

    def forward(self, x):
        y, hs = self.rnn(x)
        # w/o unpack and concat, skorch cannot compute loss during the fit_loop
        y = unpack_sequence(y)
        y = torch.concat(y)
        return y

Generate random data

n_seq = 4
nFeatures = 2
seq = [None]*n_seq
gt = [None]*n_seq
high = 5
for i in range(n_seq):
    seq[i] = torch.rand((3*(n_seq+1-i), nFeatures), dtype=torch.float64) # generate sequence with different lengths
    gt[i] = torch.randint_like(seq[i][:,0], high, dtype=torch.float64)
dataset = SequenceDataset(seq, gt)

Define NeuralNet

mdl = MyRNN(input_size=nFeatures, hidden_size=1, bidirectional=False).double()

def my_collate_fn(data):
    # nSeq = len(data)
    seq = pack_sequence([d[0] for d in data])
    lab = [d[1] for d in data]
    lab = torch.concat(lab)
    return (seq, lab)

mynet = ScoredRegressorNet(
    mdl,
    max_epochs=5,
    lr=1e-3,
    batch_size = 2000,
    device="cpu",
    train_split=ValidSplit(2),
    iterator_train__shuffle=False,
    iterator_train__batch_size = 2,
    iterator_train__collate_fn=my_collate_fn,
    iterator_valid__collate_fn=my_collate_fn,
    criterion=nn.SmoothL1Loss,
)

Fitting data

mynet.fit(dataset, y=None) # this runs


nGroups = n_seq
NGroupHoldOut = 1
n_splits=np.ceil(nGroups/NGroupHoldOut).astype(np.int32)
cross_validate = StratifiedGroupKFold(n_splits=n_splits)
pipe = Pipeline([('net', mynet)])
gs = GridSearchCV(estimator=pipe, cv=cross_validate, 
                  param_grid={},
                  refit=False, n_jobs=1
               )
# Neither of the calls below works
gs.fit(dataset, y=None, 
       groups=[x for x in range(n_seq)])
gs.fit(dataset, y=gt, 
       groups=[x for x in range(n_seq)])


If I pass y=None, StratifiedGroupKFold raises
"ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead." because the y handed to StratifiedGroupKFold is None.

If I pass y=gt, StratifiedGroupKFold also fails because gt is a list of tensors. Concatenating gt does not work either, because X and y then have inconsistent lengths: X has one entry per sequence, while the concatenated y has one entry per time step.
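To make the length mismatch concrete, here is a small sketch that only reproduces the arithmetic of the random data above, without torch. The names `seq_lengths`, `n_samples` and `n_labels_concat` are introduced here for illustration and are not part of the original code.

```python
# Sequence lengths as generated in the "Generate random data" step:
# seq[i] has 3 * (n_seq + 1 - i) time steps.
n_seq = 4
seq_lengths = [3 * (n_seq + 1 - i) for i in range(n_seq)]

# The dataset exposes one sample per sequence ...
n_samples = len(seq_lengths)

# ... but concatenating all label tensors gives one label per time step.
n_labels_concat = sum(seq_lengths)

print("sequence lengths:", seq_lengths)
print("len(X) seen by the splitter:", n_samples)
print("len(y) after concatenation:", n_labels_concat)
```

A splitter that checks that X and y have the same number of samples would therefore reject the concatenated labels.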

As another test, this time with a default cv, I tried

gs = GridSearchCV(estimator=pipe, cv=2, 
                  param_grid={},
                  refit=False, n_jobs=1
               )
gs.fit(dataset, y=None, 
       groups=[x for x in range(n_seq)])

It raises
"ValueError: No y-values are given (y=None). You must implement your own DataLoader for training (and your validation) and supply it using the iterator_train and iterator_valid parameters respectively."

At this point, I am not sure what the workaround is, as I do not fully understand skorch. Does skorch work with packed sequences in some other form that would cover what I need?
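One direction I have not tried inside GridSearchCV yet, so treat it as a sketch under assumptions rather than a fix: scikit-learn's GroupKFold splits by group only and does not stratify, so it does not need y at the splitting stage. The snippet below only exercises the splitter on a placeholder index array (`X_index`, a name made up here), with one group per sequence as in my code. It does not address the skorch "No y-values are given" error.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

n_seq = 4

# Placeholder X with one row per sequence; only its length matters for splitting.
X_index = np.arange(n_seq).reshape(-1, 1)

# One group per sequence, as in the GridSearchCV calls above.
groups = list(range(n_seq))

# GroupKFold ignores y, so y=None is accepted here.
cv = GroupKFold(n_splits=n_seq)
folds = [
    (train.tolist(), test.tolist())
    for train, test in cv.split(X_index, y=None, groups=groups)
]

for train, test in folds:
    print("train sequences:", train, "held-out sequence:", test)
```

With one group per sequence and n_splits equal to the number of groups, each fold holds out exactly one sequence.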

Thank you in advance!
