logo

classifier-load

# TODO: Write a function that loads a checkpoint and rebuilds the model

def load_neural_net(filepath, mode = 'train'):
    '''
    Loads a checkpoint from disk and rebuilds a Neural_Network instance around the model stored in it.
    The saved model weights, optimizer state, class mapping and epoch count are restored, and a new
    Adam optimizer (over the classifier parameters) and NLLLoss criterion are attached to the instance.

    The checkpoint is read as a dict with the keys:
        'data' (dict with 'input_count', 'hidden_sizes', 'outputs', 'h_activation', 'dropout',
        'learn_rate', 'epochs_completed'), 'device', 'model', 'state_dict', 'class_to_idx'
        and 'optimizer.state_dict'.

    Args:
        filepath (str): The full path to the checkpoint
        mode (str): Mode to set the model to ('train', 'eval')

    Raises:
        ValueError: If mode is neither 'train' nor 'eval'. Raised before the checkpoint is read.
        KeyError: If the checkpoint is missing one of the keys listed above.
        TODO: Update exceptions with error_handling class.

    Returns:
        net (nn_model.Neural_Network): An instance of the Neural_Network class with the loaded model
            as its model, and with epochs_completed, criterion and optimizer set.
    '''

    # Fail fast on a bad mode instead of after the (potentially slow) checkpoint load.
    if mode not in ('train', 'eval'):
        raise ValueError('Error mode needs to be either train or eval')

    print('loading_net')
    #TODO: Path validation
    # NOTE(security): torch.load unpickles the file, and this checkpoint stores a whole pickled
    # model object, so only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which rejects
    # pickled model objects - confirm the installed version before relying on this call.
    # map_location='cpu' lets a checkpoint saved on a GPU be read on a CPU-only machine; the model
    # is moved to its target device further down.
    checkpoint = torch.load(filepath, map_location='cpu')

    # Set Params
    data = checkpoint['data']
    inputs = data['input_count']
    hidden_layers = data['hidden_sizes']
    outputs = data['outputs']
    activation = data['h_activation']
    dropout = data['dropout']
    learn_rate = data['learn_rate']

    device = checkpoint['device']
    # Fall back to the CPU when the checkpoint was saved for CUDA but CUDA is not available here,
    # keeping the same kind of value (str or torch.device) that the checkpoint stored.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        device = 'cpu' if isinstance(device, str) else torch.device('cpu')

    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])

    # Make Network
    net = Neural_Network(inputs, hidden_layers, outputs, activation, device, dropout, learn_rate)
    net.model = model
    net.epochs_completed = data['epochs_completed']
    net.model.classifier.class_to_idx = checkpoint['class_to_idx']

    # Move to processing device BEFORE building the optimizer, so the restored optimizer state is
    # placed on the same device as the parameters it belongs to.
    net.model.to(device)

    if mode == 'train':
        net.model.train()
    else:
        net.model.eval()

    optimizer = torch.optim.Adam(net.model.classifier.parameters(), learn_rate)
    optimizer.load_state_dict(checkpoint['optimizer.state_dict'])
    net.optimizer = optimizer
    net.criterion = nn.NLLLoss()

    return net
    

# Rebuild the network from the saved checkpoint, with the model set to evaluation mode
loaded_net = load_neural_net(filepath='checkpoint_1.pth', mode='eval')
  • Share