# Network utilities for building models from torchvision
def _find_head(model):
    '''
    Locates the final classification head of a torchvision model.
    Args:
        model (torch.nn.Module): The model to inspect.
    Raises:
        ValueError: If no 'classifier' or 'fc' attribute containing a linear layer is found.
    Returns:
        head_name (str): The attribute name holding the head ('classifier' or 'fc').
        feature_count (int): The in_features of the first linear layer inside that head.
    '''
    # VGG and DenseNet expose their head as `classifier`; ResNet exposes it as `fc`.
    for head_name in ('classifier', 'fc'):
        head = getattr(model, head_name, None)
        if not isinstance(head, nn.Module):
            continue
        # modules() yields the head itself first and then its children in order, so this
        # handles both a bare nn.Linear head and a Sequential whose first layer is Linear.
        for layer in head.modules():
            if isinstance(layer, nn.Linear):
                return head_name, layer.in_features
    raise ValueError('Unable to locate a linear classifier head on the model')


def net_from_torchvision(hidden_sizes, outputs, hidden_activation, device,
                         optimizer_name = 'adam', dropout = 0.3, learn_rate = 0.002,
                         name = 'vgg16', trained = True):
    '''
    Generates a model from torchvision, and instantiates a new Neural_Network instance whose
    classifier replaces the torchvision model's head. The combined model becomes the instance's
    active model. A new optimizer and criterion are also generated and assigned to the instance.
    Args:
        hidden_sizes (list of ints): The hidden layer sizes.
        outputs (int): The number of outputs.
        hidden_activation (str): The hidden layer activation function (ex. relu, sigmoid, tanh).
        device (str): The gpu or the cpu.
        optimizer_name (str): The optimizer name ('sgd' or 'adam') used to update the weights.
        dropout (float): The dropout rate, passed through to Neural_Network.
        learn_rate (float): The learning rate, passed to Neural_Network and to the optimizer.
        name (str): The pretrained model name ('vgg16', 'resnet50', 'densenet121').
        trained (bool): Passed to get_pretrained_model; whether pretrained weights are loaded.
    Raises:
        ValueError: If optimizer_name is not 'adam' or 'sgd', if name is rejected by
            get_pretrained_model, or if the model has no recognizable classifier head.
    Returns:
        net (nn_model.Neural_Network): An instance of the Neural_Network class with the
            torchvision model (carrying the new head) as its model.
    '''
    # Validate first so a bad optimizer name fails before any model is built or downloaded.
    if optimizer_name not in ('adam', 'sgd'):
        raise ValueError('Please use either SGD or Adam as optimizers')

    model = get_pretrained_model(name, trained)
    head_name, feature_count = _find_head(model)

    net = Neural_Network(feature_count, hidden_sizes, outputs,
                         hidden_activation, device, dropout, learn_rate)

    # Swap the freshly built classifier in as the torchvision model's head, then make the
    # full model the network's active model.
    head = net.model
    setattr(model, head_name, head)
    net.model = model

    # Only the new head is handed to the optimizer; get_pretrained_model froze the rest.
    if optimizer_name == 'adam':
        net.optimizer = torch.optim.Adam(head.parameters(), lr = learn_rate)
    else:
        net.optimizer = torch.optim.SGD(head.parameters(), lr = learn_rate)
    net.criterion = nn.NLLLoss()
    return net
def get_pretrained_model(name = 'vgg16', trained = True):
    '''Builds one of the supported torchvision models and freezes all of its parameters.
    Args:
        name (str): The model name ('vgg16', 'resnet50', 'densenet121').
        trained (bool): Forwarded to the model constructor as its `pretrained` argument.
    Raises:
        ValueError: If name is not one of the supported model names.
    Returns:
        network: The model built for `name`, with requires_grad set to False on every parameter.
    '''
    supported_names = ('vgg16', 'resnet50', 'densenet121')
    if name not in supported_names:
        raise ValueError('Please select from either vgg16, resnet50 or '
                         'densenet121 pre-trained models')

    # Each supported name is also the name of its constructor in `models`
    # (presumably torchvision.models; the import is not visible in this section).
    build = getattr(models, name)
    network = build(pretrained = trained)

    # Freeze every parameter so none of the loaded weights receive gradient updates.
    for weight in network.parameters():
        weight.requires_grad = False
    return network