This repository has been archived on 2022-07-15. You can view files and clone it, but cannot push or open issues or pull requests.
mmp-osp1/src/autophotographer/model.py

156 lines
No EOL
5.2 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
from torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
import os
from os.path import abspath
import matplotlib.pyplot as plt
import config
import dataset
# Resolve the repository root relative to this script's own location, so the
# script works regardless of the current working directory.
# projectRoot = "/src/"
script_directory = os.path.dirname(__file__)
projectRoot = abspath(os.path.join(script_directory, "../.."))
print("Project root: " + projectRoot)

# Output locations for the training-loss plot and the serialized model.
# (NOTE: "INTIIAL" is a typo, but the name is kept — other code refers to it.)
INITIAL_PLOT_PATH = os.path.join(projectRoot, "src/output/plot.png")
INTIIAL_MODEL_PATH = os.path.join(projectRoot, "src/output/model.pth")

# Pre-serialized datasets saved by a previous run of this script.
valDatasetPath = os.path.join(projectRoot, "src/autophotographer/valDataset.pt")
trainDatasetPath = os.path.join(projectRoot, "src/autophotographer/trainDataset.pt")
# define transformations
# Both pipelines end with the same normalization (stats from config, matching
# the pretrained backbone's expectations), so build it once and share it.
_normalize = transforms.Normalize(mean=config.MEAN, std=config.STD)

# Training pipeline: random crop/flip/rotation for augmentation.
trainTransform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    _normalize,
])

# Validation pipeline: deterministic resize only — no augmentation.
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    _normalize,
])
# Split the annotated dataframe into train/validation partitions by slicing:
# the last VAL_SPLIT fraction of rows becomes the validation set.
# (assumes dataset.df row order is fixed/shuffled upstream — TODO confirm)
totalRows = len(dataset.df)
valSetLen = int(totalRows * config.VAL_SPLIT)
trainSetLen = totalRows - valSetLen
trainSet, valSet = dataset.df[:trainSetLen], dataset.df[trainSetLen:]
print("Using " + config.DEVICE + "...")
# create data loaders from the pre-serialized datasets
# (the datasets were built once from trainSet/valSet with the transforms above
# and saved via torch.save; loading them here avoids rebuilding every run)
print("Getting dataloaders...")

# Pinned host memory only helps (and only applies) when batches are copied to
# a CUDA device.
usePinnedMemory = config.DEVICE == "cuda"

# NOTE(review): torch.load unpickles arbitrary Python objects — only ever
# load dataset files produced by this project, never untrusted files.
valDataset = torch.load(valDatasetPath)
valLoader = DataLoader(
    valDataset,
    batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE,
    shuffle=False,  # keep validation order deterministic
    num_workers=os.cpu_count(),
    pin_memory=usePinnedMemory,
)

trainDataset = torch.load(trainDatasetPath)
trainLoader = DataLoader(
    trainDataset,
    batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE,
    shuffle=True,
    num_workers=os.cpu_count(),
    pin_memory=usePinnedMemory,
)
# Load an ImageNet-pretrained ResNet-50 to use as a frozen feature extractor.
model = resnet50(pretrained=True)
# Freeze every existing layer so only the new head below is trainable.
for parameter in model.parameters():
    parameter.requires_grad = False

# Replace the classification head with a single-output regression head.
inFeatures = model.fc.in_features
model.fc = nn.Linear(inFeatures, 1)
model = model.to(config.DEVICE)

# L1 loss = mean absolute error, suited to the scalar score target; only the
# new head's parameters are handed to the optimizer.
lossFunction = nn.L1Loss()
optimizer = optim.Adam(model.fc.parameters(), lr=config.LR)

# Steps per epoch — used later to average the accumulated per-batch losses.
trainSteps = len(trainDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE

# Per-epoch loss history, consumed by the plot at the end of the script.
H = {"train_loss": [], "val_loss": []}
# loop over epochs
print("Starting training...")
startTime = time.time()
for epoch in tqdm(range(config.EPOCHS)):
    model.train()
    totalTrainLoss = 0.0
    totalValLoss = 0.0

    for (i, (x, y)) in enumerate(trainLoader):
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
        pred = model(x)
        # Targets arrive as (batch,); reshape to (batch, 1) to match the
        # single-output head so L1Loss compares element-wise.
        y = y.view(len(y), 1)
        loss = lossFunction(pred, y)
        loss.backward()
        # Gradient accumulation over pairs of batches: step on even indices
        # (equivalent to the original "(i + 2) % 2 == 0").
        # NOTE(review): this applies the very first batch alone, and a
        # trailing odd batch's gradients are never applied — confirm the
        # schedule is intended (usually "(i + 1) % 2 == 0" plus a final step).
        if i % 2 == 0:
            optimizer.step()
            optimizer.zero_grad()
        # Accumulate as a plain float (.item()) so we don't retain the
        # autograd graph / GPU tensors for every batch of the epoch.
        totalTrainLoss += loss.item()

    # Validation pass: no gradients, eval mode.
    with torch.no_grad():
        model.eval()
        for (x, y) in valLoader:
            (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
            pred = model(x)
            y = y.view(len(y), 1)
            totalValLoss += lossFunction(pred, y).item()

    # Average training and validation loss per step.
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps

    # Update training history (already plain Python floats).
    # (The original argmax-based "accuracy" was removed: argmax over a
    # 1-output regression head is always 0 and its consumers were dead code.)
    H["train_loss"].append(avgTrainLoss)
    H["val_loss"].append(avgValLoss)

    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
    print("Train loss: {:.6f}, Val loss: {:.6f}".format(
        avgTrainLoss, avgValLoss))
# Report how long the whole training run took.
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# Plot the per-epoch training and validation loss curves.
plt.style.use("ggplot")
plt.figure()
for curve in ("train_loss", "val_loss"):
    plt.plot(H[curve], label=curve)
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(INITIAL_PLOT_PATH)

# Serialize the entire model object (architecture + weights) to disk.
torch.save(model, INTIIAL_MODEL_PATH)