Octopod Ensemble

The ensemble aspects of Octopod are housed here. This includes sample model architectures, dataset classes, and helper functions.

Model Architectures

class octopod.ensemble.models.multi_task_ensemble.BertResnetEnsembleForMultiTaskClassification(image_task_dict=None, dropout=0.1)

PyTorch ensemble class for multi-task learning, consisting of a text model and image models

This model is made up of multiple component models:
  • for text: Google's BERT model

  • for images: multiple ResNet50s (the exact number depends on how the image model tasks were split up)

You will likely need to train the component image and text models before combining them into an ensemble model in order to get good results.

Note: For explicitness, vanilla refers to the transformers BERT or PyTorch ResNet50 weights while pretrained refers to previously trained Octopod weights.

Examples

The ensemble model should be used with pretrained BERT and ResNet50 component models. To initialize a model in this way:

image_task_dict = {
    'color_pattern': {
        'color': color_train_df['labels'].nunique(),
        'pattern': pattern_train_df['labels'].nunique()
    },
    'dress_sleeve': {
        'dress_length': dl_train_df['labels'].nunique(),
        'sleeve_length': sl_train_df['labels'].nunique()
    },
    'season': {
        'season': season_train_df['labels'].nunique()
    }
}
model = BertResnetEnsembleForMultiTaskClassification(
    image_task_dict=image_task_dict
)

resnet_model_id_dict = {
    'color_pattern': 'SOME_RESNET_MODEL_ID1',
    'dress_sleeve': 'SOME_RESNET_MODEL_ID2',
    'season': 'SOME_RESNET_MODEL_ID3'
}

model.load_core_models(
    folder='SOME_FOLDER',
    bert_model_id='SOME_BERT_MODEL_ID',
    resnet_model_id_dict=resnet_model_id_dict
)

# DO SOME TRAINING

model.save(SOME_FOLDER, SOME_MODEL_ID)

# OR

model.export(SOME_FOLDER, SOME_MODEL_ID)
Parameters
  • image_task_dict (dict) – dictionary mapping each pretrained ResNet50 model to a dictionary of the tasks it was trained on and the number of labels for each task

  • dropout (float) – dropout probability for the Dropout layer

static create_text_dict(image_task_dict)

Create a task dict for the text model from the image task dict
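
As a minimal sketch of the intended usage (the label counts below are illustrative, and the flat output dict shown in the comment is inferred from the description above rather than verified output):

image_task_dict = {
    'color_pattern': {'color': 8, 'pattern': 12},
    'season': {'season': 4}
}

# The text model is a single BERT shared across all tasks, so the nested
# image task dict is expected to be flattened into one task -> number-of-labels
# mapping, e.g. {'color': 8, 'pattern': 12, 'season': 4}
text_task_dict = BertResnetEnsembleForMultiTaskClassification.create_text_dict(
    image_task_dict
)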

export(folder, model_id, model_name=None)

Exports the entire model state dict to a specific folder, along with the image_task_dict, which is needed to reinstantiate the model.

Parameters
  • folder (str or Path) – place to store state dictionaries

  • model_id (int) – unique id for this model

  • model_name (str (defaults to None)) – name to store the model under; if None, defaults to multi_task_ensemble_{model_id}.pth

Side Effects

saves two files:
  • folder / f'multi_task_ensemble_{model_id}.pth'

  • folder / f'image_task_dict_{model_id}.pickle'

forward(x)

Defines forward pass for ensemble model

Parameters

x (dict) –

dictionary of torch tensors with keys:
  • bert_text: integers mapping to BERT vocabulary

  • full_img: tensor of full image

  • crop_img: tensor of cropped image

Returns

A dictionary mapping each task to its logits
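
A minimal sketch of a forward pass, assuming model is the ensemble built in the example above; the batch size, sequence length, and 3 x 224 x 224 image shapes are illustrative assumptions, not requirements stated here:

import torch

# Hypothetical batch of two examples
batch = {
    'bert_text': torch.randint(0, 30522, (2, 128)),  # token ids into the BERT vocabulary
    'full_img': torch.randn(2, 3, 224, 224),         # full image tensor
    'crop_img': torch.randn(2, 3, 224, 224)          # center-cropped image tensor
}

model.eval()
with torch.no_grad():
    logit_dict = model(batch)  # e.g. {'color': <2 x n_color_labels tensor>, ...}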

freeze_bert()

Freeze all core BERT layers

freeze_classifiers_and_core()

Freeze pretrained classifier layers and core BERT/ResNet layers

freeze_ensemble_layers()

Freeze all final ensemble layers

freeze_resnets()

Freeze all core ResNet model layers
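
A sketch of how the freeze methods are typically used when a newly assembled ensemble is first trained (the training loop itself is omitted):

# Freeze the pretrained classifiers and the core BERT/ResNet weights so that
# only the newly initialized ensemble layers are updated at first.
model.freeze_classifiers_and_core()

# ... train the ensemble layers ...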

load(folder, model_id)

Loads the model state dicts for ensemble model from a specific folder. This will load all the model components including the final ensemble and existing pretrained classifiers.

Parameters
  • folder (str or Path) – place where state dictionaries are stored

  • model_id (int) – unique id for this model

Side Effects

loads from six files:
  • folder / f'bert_dict_{model_id}.pth'

  • folder / f'dropout_dict_{model_id}.pth'

  • folder / f'image_resnets_dict_{model_id}.pth'

  • folder / f'image_dense_layers_dict_{model_id}.pth'

  • folder / f'ensemble_layers_dict_{model_id}.pth'

  • folder / f'classifiers_dict_{model_id}.pth'
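
A sketch of reloading a saved ensemble; 'SOME_FOLDER' and SOME_MODEL_ID are placeholders, and image_task_dict must match the dictionary used when the model was saved:

model = BertResnetEnsembleForMultiTaskClassification(
    image_task_dict=image_task_dict
)
model.load(folder='SOME_FOLDER', model_id=SOME_MODEL_ID)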

load_core_models(folder, bert_model_id, resnet_model_id_dict)

Loads the weights from pretrained BERT and ResNet50 Octopod models

Does not load weights from the final ensemble and classifier layers. The use case is loading previously trained (pretrained) component BERT and image model weights into a new ensemble model.

Parameters
  • folder (str or Path) – place where state dictionaries are stored

  • bert_model_id (int) – unique id for pretrained BERT text model

  • resnet_model_id_dict (dict) – dict with unique ids for the pretrained image models, e.g.

    resnet_model_id_dict = {
        'task1_task2': 'model_id1',
        'task3_task4': 'model_id2',
        'task5': 'model_id3'
    }

Side Effects

loads from the following files:
  • folder / f'bert_dict_{bert_model_id}.pth'

  • folder / f'dropout_dict_{bert_model_id}.pth'

  • folder / f'resnet_dict_{resnet_model_id}.pth' for each resnet_model_id in resnet_model_id_dict

  • folder / f'dense_layers_dict_{resnet_model_id}.pth' for each resnet_model_id in resnet_model_id_dict

save(folder, model_id)

Saves the model state dicts to a specific folder. Each part of the model is saved separately, along with the image_task_dict, which is needed to reinstantiate the model.

Parameters
  • folder (str or Path) – place to store state dictionaries

  • model_id (int) – unique id for this model

Side Effects

saves six files:
  • folder / f'bert_dict_{model_id}.pth'

  • folder / f'dropout_dict_{model_id}.pth'

  • folder / f'image_resnets_dict_{model_id}.pth'

  • folder / f'image_dense_layers_dict_{model_id}.pth'

  • folder / f'ensemble_layers_dict_{model_id}.pth'

  • folder / f'classifiers_dict_{model_id}.pth'

unfreeze_classifiers()

Unfreeze pretrained classifier layers

unfreeze_classifiers_and_core()

Unfreeze pretrained classifiers and core BERT/ResNet layers
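
For example, once the new ensemble layers have been trained with the core frozen, the rest of the network can be unfrozen for end-to-end fine-tuning:

model.unfreeze_classifiers_and_core()

# ... continue training end-to-end, typically with a lower learning rate ...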

Dataset

class octopod.ensemble.dataset.OctopodEnsembleDataset(text_inputs, img_inputs, y, tokenizer, max_seq_length=128, transform='train', crop_transform='train')

Load image and text data specifically for an ensemble model

Parameters
  • text_inputs (pandas Series) – the text to be used

  • img_inputs (pandas Series) – the paths to images to be used

  • y (list) – a list of dummy-encoded categories or strings, which will be encoded using a sklearn label encoder

  • tokenizer (pretrained BERT Tokenizer) – a pretrained BERT tokenizer, typically from the transformers library

  • max_seq_length (int (defaults to 128)) – Maximum number of tokens to allow

  • transform (str or list of PyTorch transforms) – specifies how to preprocess the full image for an Octopod image model. To use the built-in Octopod image transforms, pass the string 'train' or 'val'. To use custom transformations, supply a list of PyTorch transforms.

  • crop_transform (str or list of PyTorch transforms) – specifies how to preprocess the center-cropped image for an Octopod image model. To use the built-in Octopod image transforms, pass the string 'train' or 'val'. To use custom transformations, supply a list of PyTorch transforms.
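
A minimal sketch of constructing the dataset; the dataframe, its column names, and the image paths are illustrative assumptions:

import pandas as pd
from transformers import BertTokenizer

from octopod.ensemble.dataset import OctopodEnsembleDataset

# Illustrative data; the column names and image paths are placeholders
color_train_df = pd.DataFrame({
    'description': ['red floral dress', 'blue striped shirt'],
    'image_loc': ['images/0001.jpg', 'images/0002.jpg'],
    'labels': ['red', 'blue']
})

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

color_train_dataset = OctopodEnsembleDataset(
    text_inputs=color_train_df['description'],
    img_inputs=color_train_df['image_loc'],
    y=color_train_df['labels'].tolist(),
    tokenizer=tokenizer,
    max_seq_length=128,
    transform='train',       # built-in Octopod transforms for the full image
    crop_transform='train'   # built-in transforms for the center-cropped image
)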

class octopod.ensemble.dataset.OctopodEnsembleDatasetMultiLabel(text_inputs, img_inputs, y, tokenizer, max_seq_length=128, transform='train', crop_transform='train')

Multi-label subclass of OctopodEnsembleDataset

Parameters
  • text_inputs (pandas Series) – the text to be used

  • img_inputs (pandas Series) – the paths to images to be used

  • y (list) – a list of lists of binary-encoded categories or strings, with length equal to the number of classes in the multi-label task. For a 4-class multi-label task, a sample binary-encoded entry would be [1, 0, 0, 1]. A string example would be ['cat', 'dog'] (if the classes were ['cat', 'frog', 'rabbit', 'dog']), which will be encoded by a sklearn label encoder to [1, 0, 0, 1].

  • tokenizer (pretrained BERT Tokenizer) – a pretrained BERT tokenizer, typically from the transformers library

  • max_seq_length (int (defaults to 128)) – Maximum number of tokens to allow

  • transform (str or list of PyTorch transforms) – specifies how to preprocess the full image for an Octopod image model. To use the built-in Octopod image transforms, pass the string 'train' or 'val'. To use custom transformations, supply a list of PyTorch transforms.

  • crop_transform (str or list of PyTorch transforms) – specifies how to preprocess the center-cropped image for an Octopod image model. To use the built-in Octopod image transforms, pass the string 'train' or 'val'. To use custom transformations, supply a list of PyTorch transforms.
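
Continuing the sketch above (reusing pandas and the tokenizer), a multi-label version where each row's label is the list of classes present; the dataframe and its columns are again illustrative:

from octopod.ensemble.dataset import OctopodEnsembleDatasetMultiLabel

# Each entry in 'labels' lists the classes present for that example; string labels
# are binary-encoded internally, e.g. ['floral', 'maxi'] -> a 0/1 vector over all classes.
attribute_train_df = pd.DataFrame({
    'description': ['red floral maxi dress', 'blue striped crop top'],
    'image_loc': ['images/0001.jpg', 'images/0002.jpg'],
    'labels': [['floral', 'maxi'], ['striped', 'crop']]
})

multi_label_dataset = OctopodEnsembleDatasetMultiLabel(
    text_inputs=attribute_train_df['description'],
    img_inputs=attribute_train_df['image_loc'],
    y=attribute_train_df['labels'].tolist(),
    tokenizer=tokenizer,
    max_seq_length=128,
    transform='train',
    crop_transform='train'
)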