"""Feature-extraction training of a ResNet-50 regression head.

Loads pre-serialized train/val datasets, freezes the ResNet-50 backbone,
trains only a new 1-output FC layer with L1 (mean absolute error) loss,
records per-epoch losses, then saves a loss plot and the trained model.
"""

import os
import time
from os.path import abspath

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet50
from tqdm import tqdm

import config
import dataset

# Resolve all paths relative to the project root (two levels above this file).
script_directory = os.path.dirname(__file__)
projectRoot = abspath(os.path.join(script_directory, "../.."))
print("Project root: " + projectRoot)
INITIAL_PLOT_PATH = os.path.join(projectRoot, "src/output/plot.png")
# Typo fixed: INTIIAL -> INITIAL (constant is only referenced in this script).
INITIAL_MODEL_PATH = os.path.join(projectRoot, "src/output/model.pth")
valDatasetPath = os.path.join(projectRoot, "src/autophotographer/valDataset.pt")
trainDatasetPath = os.path.join(projectRoot, "src/autophotographer/trainDataset.pt")

# Transformations: augmentation for training, deterministic resize for val.
trainTransform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD),
])
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD),
])

# Sequential train/val split of the dataframe (no shuffling here;
# presumably dataset.df arrives in an acceptable order -- TODO confirm).
valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
trainSetLen = len(dataset.df) - valSetLen
trainSet = dataset.df[:trainSetLen]
valSet = dataset.df[trainSetLen:]

print("Using " + config.DEVICE + "...")

# Create data loaders from the pre-serialized datasets.
# (The datasets were originally built with dataset.get_dataloader(...) using
# the transforms above and saved with torch.save; we just reload them here.)
print("Getting dataloaders...")
valDataset = torch.load(valDatasetPath)
valLoader = DataLoader(
    valDataset,
    batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE,
    shuffle=False,
    num_workers=os.cpu_count(),
    pin_memory=config.DEVICE == "cuda",
)
trainDataset = torch.load(trainDatasetPath)
trainLoader = DataLoader(
    trainDataset,
    batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE,
    shuffle=True,
    num_workers=os.cpu_count(),
    pin_memory=config.DEVICE == "cuda",
)

# Load the ResNet-50 backbone and freeze every pretrained layer; only the
# replacement FC head below receives gradients.
model = resnet50(pretrained=True)
for parameter in model.parameters():
    parameter.requires_grad = False
modelOutputFeatures = model.fc.in_features
model.fc = nn.Linear(modelOutputFeatures, 1)  # single regression output
model = model.to(config.DEVICE)

# Loss (mean absolute error) and an optimizer over the head's params only.
lossFunction = nn.L1Loss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=config.LR)

# Steps per epoch for loss averaging. max(1, ...) guards against a
# ZeroDivisionError when a dataset is smaller than one batch.
trainSteps = max(1, len(trainDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE)
valSteps = max(1, len(valDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE)

# Training history for plotting.
H = {"train_loss": [], "val_loss": []}

# Gradients are summed (not averaged) over this many batches per optimizer
# step, matching the original two-batch accumulation intent.
ACCUM_STEPS = 2

print("Starting training...")
startTime = time.time()
for epoch in tqdm(range(config.EPOCHS)):
    model.train()
    totalTrainLoss = 0.0
    totalValLoss = 0.0

    # Start every epoch with clean gradients; the original code never zeroed
    # before the loop, so stale grads leaked across epochs.
    optimizer.zero_grad()
    pendingGrads = False
    for i, (x, y) in enumerate(trainLoader):
        x, y = x.to(config.DEVICE), y.to(config.DEVICE)
        pred = model(x)
        y = y.view(len(y), 1)  # match the (batch, 1) shape of the FC output
        loss = lossFunction(pred, y)
        loss.backward()
        pendingGrads = True
        # Step once per ACCUM_STEPS batches. The original `(i + 2) % 2 == 0`
        # stepped at i=0 after a single backward, breaking the accumulation.
        if (i + 1) % ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()
            pendingGrads = False
        # .item() detaches the scalar; accumulating the loss tensor itself
        # kept every batch's autograd graph alive (memory leak).
        totalTrainLoss += loss.item()
    # Flush a leftover partial accumulation so its gradients aren't dropped.
    if pendingGrads:
        optimizer.step()
        optimizer.zero_grad()

    # NOTE: the original also tallied pred.argmax(1) == y "accuracy", but
    # argmax over a size-1 dim is always 0, so for this regression head the
    # figures were meaningless (and were never reported); removed.
    model.eval()
    with torch.no_grad():
        for x, y in valLoader:
            x, y = x.to(config.DEVICE), y.to(config.DEVICE)
            pred = model(x)
            y = y.view(len(y), 1)
            totalValLoss += lossFunction(pred, y).item()

    # Average the losses and record them (plain floats now, so no
    # .cpu().detach().numpy() conversion is needed).
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps
    H["train_loss"].append(avgTrainLoss)
    H["val_loss"].append(avgValLoss)

    print("[INFO] EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
    print("Train loss: {:.6f}, Val loss: {:.6f}".format(
        avgTrainLoss, avgValLoss))

# Display the total time needed to perform the training.
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# Plot the training and validation loss curves.
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(INITIAL_PLOT_PATH)

# Serialize the model to disk.
# NOTE(review): this pickles the whole module; torch.save(model.state_dict(),
# ...) is more portable, but kept as-is so existing torch.load callers work.
torch.save(model, INITIAL_MODEL_PATH)