156 lines
No EOL
5.2 KiB
Python
156 lines
No EOL
5.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torchvision.models import resnet50
|
|
from torchvision import transforms
|
|
from torch.utils.data import DataLoader
|
|
|
|
from tqdm import tqdm
|
|
import time
|
|
import os
|
|
from os.path import abspath
|
|
import matplotlib.pyplot as plt
|
|
|
|
import config
|
|
import dataset
|
|
|
|
# projectRoot = "/src/"
|
|
script_directory = os.path.dirname(__file__)
|
|
projectRoot = abspath(os.path.join(script_directory, "../.."))
|
|
print("Project root: " + projectRoot)
|
|
INITIAL_PLOT_PATH = os.path.join(projectRoot, "src/output/plot.png")
|
|
INTIIAL_MODEL_PATH = os.path.join(projectRoot, "src/output/model.pth")
|
|
valDatasetPath = os.path.join(projectRoot, "src/autophotographer/valDataset.pt")
|
|
trainDatasetPath = os.path.join(projectRoot, "src/autophotographer/trainDataset.pt")
|
|
|
|
# define transformations
|
|
trainTransform = transforms.Compose([
|
|
transforms.RandomResizedCrop(config.IMAGE_SIZE),
|
|
transforms.RandomHorizontalFlip(),
|
|
transforms.RandomRotation(90),
|
|
transforms.ToTensor(),
|
|
transforms.Normalize(mean=config.MEAN, std=config.STD)
|
|
])
|
|
valTransform = transforms.Compose([
|
|
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
|
|
transforms.ToTensor(),
|
|
transforms.Normalize(mean=config.MEAN, std=config.STD)
|
|
])
|
|
|
|
valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
|
|
trainSetLen = len(dataset.df) - valSetLen
|
|
trainSet = dataset.df[:trainSetLen]
|
|
valSet = dataset.df[trainSetLen:]
|
|
|
|
print("Using " + config.DEVICE + "...")
|
|
# create data loaders
|
|
print("Getting dataloaders...")
|
|
#(trainDataset, trainLoader) = dataset.get_dataloader(trainSet,
|
|
#transforms=trainTransform, batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
|
|
#torch.save(trainDataset, 'trainDataset.pt')
|
|
#(valDataset, valLoader) = dataset.get_dataloader(valSet,
|
|
#transforms=valTransform, batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)
|
|
#torch.save(valDataset, 'valDataset.pt')
|
|
|
|
valDataset = torch.load(valDatasetPath)
|
|
valLoader = DataLoader(valDataset, batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False, num_workers=os.cpu_count(),
|
|
pin_memory=True if config.DEVICE == "cuda" else False)
|
|
trainDataset = torch.load(trainDatasetPath)
|
|
trainLoader = DataLoader(trainDataset, batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=True, num_workers=os.cpu_count(),
|
|
pin_memory=True if config.DEVICE == "cuda" else False)
|
|
|
|
# Load the resnet model
|
|
model = resnet50(pretrained=True)
|
|
|
|
# Freeze all existing layers
|
|
for parameter in model.parameters():
|
|
parameter.requires_grad = False
|
|
|
|
modelOutputFeatures = model.fc.in_features
|
|
model.fc = nn.Linear(modelOutputFeatures, 1)
|
|
model = model.to(config.DEVICE)
|
|
|
|
# initialize loss function and optimizer
|
|
lossFunction = nn.L1Loss() # mean absolute error
|
|
optimizer = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
|
|
|
|
|
|
# calculate steps per epoch for training and validating set
|
|
trainSteps = len(trainDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
|
|
valSteps = len(valDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
|
|
|
|
# initialize a dictionary to store training data
|
|
H = {"train_loss": [], "val_loss": []}
|
|
|
|
# loop over epochs
|
|
print("Starting training...")
|
|
startTime = time.time()
|
|
for epoch in tqdm(range(config.EPOCHS)):
|
|
model.train()
|
|
|
|
totalTrainLoss = 0
|
|
totalValLoss = 0
|
|
|
|
trainCorrect = 0
|
|
valCorrect = 0
|
|
|
|
for (i, (x, y)) in enumerate(trainLoader):
|
|
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
|
|
pred = model(x)
|
|
new_shape = (len(y), 1)
|
|
y = y.view(new_shape)
|
|
loss = lossFunction(pred, y)
|
|
|
|
loss.backward()
|
|
|
|
if (i + 2) % 2 == 0:
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
|
|
totalTrainLoss += loss
|
|
trainCorrect += (pred.argmax(1) == y).type(
|
|
torch.float).sum().item()
|
|
|
|
with torch.no_grad():
|
|
model.eval()
|
|
|
|
for (x, y) in valLoader:
|
|
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
|
|
|
|
pred = model(x)
|
|
new_shape = (len(y), 1)
|
|
y = y.view(new_shape)
|
|
totalValLoss += lossFunction(pred, y)
|
|
valCorrect += (pred.argmax(1) == y).type(
|
|
torch.float).sum().item()
|
|
|
|
# calculate the average training and validation loss
|
|
avgTrainLoss = totalTrainLoss / trainSteps
|
|
avgValLoss = totalValLoss / valSteps
|
|
# calculate the training and validation accuracy
|
|
#trainCorrect = trainCorrect / len(trainDataset)
|
|
#valCorrect = valCorrect / len(valDataset)
|
|
# update our training history
|
|
H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
|
|
H["val_loss"].append(avgValLoss.cpu().detach().numpy())
|
|
# print the model training and validation information
|
|
print("[INFO] EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
|
|
print("Train loss: {:.6f}, Val loss: {:.6f}".format(
|
|
avgTrainLoss, avgValLoss))
|
|
|
|
# display the total time needed to perform the training
|
|
endTime = time.time()
|
|
print("[INFO] total time taken to train the model: {:.2f}s".format(
|
|
endTime - startTime))
|
|
# plot the training loss and accuracy
|
|
plt.style.use("ggplot")
|
|
plt.figure()
|
|
plt.plot(H["train_loss"], label="train_loss")
|
|
plt.plot(H["val_loss"], label="val_loss")
|
|
plt.title("Training Loss on Dataset")
|
|
plt.xlabel("Epoch #")
|
|
plt.ylabel("Loss")
|
|
plt.legend(loc="lower left")
|
|
plt.savefig(INITIAL_PLOT_PATH)
|
|
# serialize the model to disk
|
|
torch.save(model, INTIIAL_MODEL_PATH) |