# Load a checkpoint and rebuild the model (with its optimizer and criterion) from it.
def load_neural_net(filepath, mode='train'):
    """Load a checkpoint file and rebuild a ``Neural_Network`` around its model.

    The model object stored in the checkpoint is restored with its saved
    weights and wrapped in a new ``Neural_Network``. It is then put in the
    requested mode and moved to the device recorded in the checkpoint. Finally
    it is given an Adam optimizer (with the saved optimizer state restored)
    and an ``nn.NLLLoss`` criterion.

    Args:
        filepath (str): Full path to the checkpoint file.
        mode (str): Mode to put the model in, either ``'train'`` (default)
            or ``'eval'``.

    Raises:
        ValueError: If ``mode`` is not ``'train'`` or ``'eval'``. This is
            checked before the checkpoint is read.
        KeyError: If the checkpoint lacks one of the entries read below.
        Errors raised by ``torch.load`` (e.g. for a missing file) propagate
        unchanged.

    Returns:
        nn_model.Neural_Network: Network whose ``model``, ``epochs_completed``,
        ``optimizer`` and ``criterion`` are restored from the checkpoint.
    """
    # TODO: Update exceptions with error_handling class.
    # Validate the cheap argument first so a bad mode fails before any file I/O
    # or network construction.
    if mode not in ('train', 'eval'):
        raise ValueError('Error mode needs to be either train or eval')

    print('loading_net')
    # NOTE(review): torch.load unpickles arbitrary Python objects (this checkpoint
    # stores a whole model object), so only load trusted files. On torch >= 2.6
    # the default weights_only=True rejects pickled modules; pass
    # weights_only=False there (trusted checkpoints only).
    checkpoint = torch.load(filepath)
    data = checkpoint['data']

    # Hyper-parameters the network was originally constructed with.
    inputs = data['input_count']
    hidden_layers = data['hidden_sizes']
    outputs = data['outputs']
    activation = data['h_activation']
    dropout = data['dropout']
    learn_rate = data['learn_rate']
    device = checkpoint['device']

    # Restore the saved model object and its weights.
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])

    net = Neural_Network(inputs, hidden_layers, outputs, activation,
                         device, dropout, learn_rate)
    net.model = model
    net.epochs_completed = data['epochs_completed']
    if mode == 'train':
        net.model.train()
    else:
        net.model.eval()
    net.model.classifier.class_to_idx = checkpoint['class_to_idx']

    # Move to the processing device BEFORE creating the optimizer.
    # Optimizer.load_state_dict places the restored state (e.g. Adam's exp_avg)
    # on the device of the parameters at that moment. Moving the model
    # afterwards would leave that state on the old device and break the next
    # optimizer.step().
    net.model.to(device)

    # Only the classifier's parameters are optimized, matching the saved state.
    optimizer = torch.optim.Adam(net.model.classifier.parameters(), lr=learn_rate)
    optimizer.load_state_dict(checkpoint['optimizer.state_dict'])
    net.optimizer = optimizer
    net.criterion = nn.NLLLoss()
    return net
# Rebuild the saved network from its checkpoint, with the model put in eval mode.
loaded_net = load_neural_net(filepath='checkpoint_1.pth', mode='eval')