# Imports here
import torch
import torchvision.transforms as tf
import torchvision.datasets as ds
import torchvision.models as models
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from neural_net import Neural_Network as net
# TODO: Define your transforms for the training, validation, and testing sets
# TODO: Load the datasets with ImageFolder
# TODO: Using the image datasets and the transforms, define the dataloaders
def generate_datasets(params_dict, types_list, resize=300, crop_size=224, *,
                      mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build an ImageFolder dataset and a DataLoader for each requested split.

    The 'train' split gets random augmentation: a random resized crop
    followed by a random horizontal flip. Every other split gets a
    deterministic center crop. Every split is first resized, and finally
    converted to a tensor and normalized with ``mean`` / ``std``.

    Args:
        params_dict (dict): Nested dict mapping each split name to a dict
            with keys 'dir' (root directory handed to ImageFolder),
            'batch' (DataLoader batch size) and 'shuffle' (DataLoader
            shuffle flag).
        types_list (list of str): Split names to build, e.g.
            ['train', 'validate', 'test']. Each must be a key of params_dict.
        resize (int): Size passed to ``Resize`` before cropping.
        crop_size (int): Output size of the random (train) or center
            (other splits) crop.
        mean (sequence of float): Per-channel normalization mean. The
            default is the value previously hard-coded here (the common
            ImageNet statistics).
        std (sequence of float): Per-channel normalization std. The default
            is the value previously hard-coded here.

    Raises:
        KeyError: If a name in types_list has no entry in params_dict, or an
            entry lacks one of the 'dir', 'batch' or 'shuffle' keys.

    Returns:
        tuple: (datasets, dataloaders). Both are dicts keyed by split name,
        holding the ImageFolder datasets and their DataLoaders respectively.
    """
    # Fail before touching the filesystem, with a message naming the bad splits.
    missing = [split for split in types_list if split not in params_dict]
    if missing:
        raise KeyError(f"params_dict has no entry for split(s): {missing}")

    transforms = {}
    for split in types_list:
        if split == 'train':
            # Augment only the training split. Evaluation splits stay
            # deterministic so that results are comparable between runs.
            crop_steps = [tf.RandomResizedCrop(crop_size),
                          tf.RandomHorizontalFlip()]
        else:
            crop_steps = [tf.CenterCrop(crop_size)]
        transforms[split] = tf.Compose(
            [tf.Resize(resize),
             *crop_steps,
             tf.ToTensor(),
             tf.Normalize(mean=list(mean), std=list(std))]
        )

    # One ImageFolder per split, each using that split's transform pipeline.
    datasets = {
        split: ds.ImageFolder(params_dict[split]['dir'],
                              transform=transforms[split])
        for split in types_list
    }

    # Loaders use the per-split batch size and shuffle flag.
    dataloaders = {
        split: torch.utils.data.DataLoader(
            datasets[split],
            batch_size=params_dict[split]['batch'],
            shuffle=params_dict[split]['shuffle'],
        )
        for split in types_list
    }
    return datasets, dataloaders
# Dataset root and the sub-directory for each split.
data_dir = 'flowers'
train_dir = f'{data_dir}/train'
valid_dir = f'{data_dir}/valid'
test_dir = f'{data_dir}/test'

# Per-split loader settings: image directory, batch size and shuffle flag.
# The key order here fixes the order of the dicts returned below.
params_dict = {
    'train': dict(dir=train_dir, batch=64, shuffle=True),
    'validate': dict(dir=valid_dir, batch=64, shuffle=True),
    'test': dict(dir=test_dir, batch=64, shuffle=False),
}

# Build the datasets and data loaders for every configured split.
datasets, dataloaders = generate_datasets(params_dict, list(params_dict))