# Imports here
import torch
import torchvision.transforms as tf
import torchvision.datasets as ds
import torchvision.models as models
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from neural_net import Neural_Network as net

# Transforms, ImageFolder datasets and DataLoaders for the training,
# validation and testing sets are all built by generate_datasets below.

# Default per-channel statistics for tf.Normalize (the standard ImageNet
# mean/std that torchvision's pretrained models are documented to expect).
# Tuples rather than lists so they are safe to use as default arguments.
_DEFAULT_MEAN = (0.485, 0.456, 0.406)
_DEFAULT_STD = (0.229, 0.224, 0.225)

# Keys every per-type entry of params_dict must provide.
_REQUIRED_PARAM_KEYS = ('dir', 'batch', 'shuffle')


def _validate_params(params_dict, types_list):
    '''Raise KeyError naming the first missing dataset type or per-type setting.'''
    for t in types_list:
        if t not in params_dict:
            raise KeyError(f"params_dict has no entry for dataset type {t!r}")
        missing = [k for k in _REQUIRED_PARAM_KEYS if k not in params_dict[t]]
        if missing:
            raise KeyError(
                f"params_dict[{t!r}] is missing required key(s): {missing}")


def _build_transform(train, resize, crop_size, mean, std):
    '''Return the Compose pipeline for one split (random augmentation if train).'''
    if train:
        # Random crop + flip give the training set augmentation.
        crop = [tf.RandomResizedCrop(crop_size), tf.RandomHorizontalFlip()]
    else:
        # Deterministic crop so validation/test results are repeatable.
        crop = [tf.CenterCrop(crop_size)]
    return tf.Compose([
        tf.Resize(resize),
        *crop,
        tf.ToTensor(),
        tf.Normalize(mean=list(mean), std=list(std)),
    ])


def generate_datasets(params_dict, types_list, resize=300, crop_size=224,
                      mean=_DEFAULT_MEAN, std=_DEFAULT_STD):
    '''
    Generators and data manipulation.

    Builds one transform pipeline, ImageFolder dataset and DataLoader for
    each entry of ``types_list``. The 'train' type uses
    Resize -> RandomResizedCrop -> RandomHorizontalFlip -> ToTensor -> Normalize;
    every other type uses Resize -> CenterCrop -> ToTensor -> Normalize.

    Args:
        params_dict (dict): Nested dict mapping each type to a dict with the
            keys 'dir' (ImageFolder root directory), 'batch' (batch size)
            and 'shuffle' (whether the DataLoader shuffles).
        types_list (list of str): The params_dict keys to build, e.g.
            'train', 'validate', 'test'.
        resize (int): Size passed to tf.Resize before cropping.
        crop_size (int): Size of the center crop (or random resized crop
            for 'train').
        mean (sequence of float): Per-channel mean for tf.Normalize.
        std (sequence of float): Per-channel std for tf.Normalize.

    Raises:
        KeyError: If a type in types_list has no entry in params_dict, or
            its entry lacks 'dir', 'batch' or 'shuffle'. This is checked
            before any dataset is loaded.

    Returns:
        datasets, dataloaders (tuple): Two dicts keyed by type, holding the
        ImageFolder datasets and their DataLoaders.
    '''
    _validate_params(params_dict, types_list)

    # Define the transforms
    transforms = {t: _build_transform(t == 'train', resize, crop_size, mean, std)
                  for t in types_list}

    # Load the data sets, one ImageFolder per type
    datasets = {t: ds.ImageFolder(params_dict[t]['dir'], transforms[t])
                for t in types_list}

    # Define the loaders using the datasets and the per-type settings
    dataloaders = {
        t: torch.utils.data.DataLoader(datasets[t],
                                       batch_size=params_dict[t]['batch'],
                                       shuffle=params_dict[t]['shuffle'])
        for t in types_list
    }
    return datasets, dataloaders


data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'

# generate datasets and loaders
params_dict = {
    'train': {'dir': train_dir, 'batch': 64, 'shuffle': True},
    'validate': {'dir': valid_dir, 'batch': 64, 'shuffle': True},
    'test': {'dir': test_dir, 'batch': 64, 'shuffle': False},
}
datasets, dataloaders = generate_datasets(params_dict, list(params_dict))