
# Imports here
import torch
import torchvision.transforms as tf
import torchvision.datasets as ds
import torchvision.models as models
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from neural_net import Neural_Network as net


# Define the transforms for the training, validation, and testing sets,
# load the datasets with ImageFolder, and use the datasets and transforms
# to define the dataloaders.

def generate_datasets(params_dict, types_list, resize=300, crop_size=224):

    ''' Build the transforms, datasets, and dataloaders for each dataset type.

    Args:
        params_dict (dict): Nested dictionary with a 'dir', 'batch', and
            'shuffle' entry for each dataset type.
        types_list (list of str): The params_dict keys: 'train', 'validate', 'test'.
        resize (int): Edge length to resize each image to before cropping.
        crop_size (int): Edge length to crop each image to.

    Raises:
        TODO: Add exceptions

    Returns:
        (tuple): The datasets and dataloaders dictionaries, keyed by type.
    '''

    # Define the transforms: training gets random augmentation, while
    # validation and testing use a deterministic center crop. The Normalize
    # values are the standard ImageNet channel means and standard deviations.
    transforms = {}
    for t in types_list:

        if t == 'train':
            transform_list = [tf.Resize(resize),
                              tf.RandomResizedCrop(crop_size),
                              tf.RandomHorizontalFlip()]
        else:
            transform_list = [tf.Resize(resize),
                              tf.CenterCrop(crop_size)]

        transform_list += [tf.ToTensor(),
                           tf.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])]

        transforms[t] = tf.Compose(transform_list)

    # Load the datasets, using a dict comprehension to build one entry per type
    datasets = {i: ds.ImageFolder(params_dict[i]['dir'],
                                  transforms[i]) for i in types_list}
    # Define the loaders from the datasets and the per-type parameters
    dataloaders = {i: torch.utils.data.DataLoader(datasets[i],
                                                  batch_size=params_dict[i]['batch'],
                                                  shuffle=params_dict[i]['shuffle'])
                   for i in types_list}

    return datasets, dataloaders
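
# A minimal helper sketch (not part of the original notebook): inverts the
# Normalize transform so a tensor from one of these datasets can be displayed
# with matplotlib. The mean/std values mirror the ImageNet statistics used in
# generate_datasets; the function name is hypothetical.
def imshow_tensor(image_tensor, ax=None):
    ''' Display a normalized image tensor; assumes the (C, H, W) layout
    produced by ToTensor. '''
    if ax is None:
        _, ax = plt.subplots()
    # Convert from a (C, H, W) tensor to an (H, W, C) numpy array
    image = image_tensor.numpy().transpose((1, 2, 0))
    # Undo the normalization applied in generate_datasets, clipping to [0, 1]
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    image = np.clip(std * image + mean, 0, 1)
    ax.imshow(image)
    return ax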
    

data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'


# generate datasets and loaders
params_dict = {'train': {'dir': train_dir, 'batch': 64, 'shuffle': True},
               'validate': {'dir': valid_dir, 'batch': 64, 'shuffle': True},
               'test': {'dir': test_dir, 'batch': 64, 'shuffle': False}}

datasets, dataloaders = generate_datasets(params_dict, list(params_dict.keys()))
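
# Quick sanity check (illustrative, not from the original notebook): pull one
# batch from the training loader and confirm the shapes match the transforms
# above (a batch of 64 RGB images, 224x224 after RandomResizedCrop).
images, labels = next(iter(dataloaders['train']))
print(images.shape)   # torch.Size([64, 3, 224, 224])
print(labels.shape)   # torch.Size([64])
# ImageFolder builds the class-to-index mapping from the directory names
print(len(datasets['train'].class_to_idx), 'classes')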