Refactored and removed dead code in model and predict

Oscar Blue 2022-03-23 10:32:11 +00:00
parent e1b7a87bf2
commit 68ff4ea0ad
2 changed files with 47 additions and 49 deletions

View file

@@ -23,7 +23,7 @@ INTIIAL_MODEL_PATH = os.path.join(projectRoot, "src/output/model.pth")
 valDatasetPath = os.path.join(projectRoot, "src/autophotographer/valDataset.pt")
 trainDatasetPath = os.path.join(projectRoot, "src/autophotographer/trainDataset.pt")
-# define transformations
+# Declare transformations
 trainTransform = transforms.Compose([
     transforms.RandomResizedCrop(config.IMAGE_SIZE),
     transforms.RandomHorizontalFlip(),
@@ -37,12 +37,14 @@ valTransform = transforms.Compose([
     transforms.Normalize(mean=config.MEAN, std=config.STD)
 ])
+# Split dataset into val and train
 valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
 trainSetLen = len(dataset.df) - valSetLen
 trainSet = dataset.df[:trainSetLen]
 valSet = dataset.df[trainSetLen:]
 print("Using " + config.DEVICE + "...")
 # create data loaders
 print("Getting dataloaders...")
 #(trainDataset, trainLoader) = dataset.get_dataloader(trainSet,
@@ -52,6 +54,7 @@ print("Getting dataloaders...")
 #transforms=valTransform, batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)
 #torch.save(valDataset, 'valDataset.pt')
+# Load datasets tensors from disk
 valDataset = torch.load(valDatasetPath)
 valLoader = DataLoader(valDataset, batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False, num_workers=os.cpu_count(),
     pin_memory=True if config.DEVICE == "cuda" else False)
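
Aside (not part of the commit): the code above builds the datasets once with dataset.get_dataloader, caches them with torch.save (now commented out), and reloads the cached tensors with torch.load on later runs. A minimal sketch of that cache-then-reload pattern, with an illustrative cachePath and dummy tensors standing in for the real dataset:

import os
import torch
from torch.utils.data import TensorDataset, DataLoader

cachePath = "valDataset.pt"  # illustrative; the script uses valDatasetPath/trainDatasetPath

if os.path.exists(cachePath):
    # Reuse the dataset object serialized by an earlier run
    valDataset = torch.load(cachePath)  # recent PyTorch may need weights_only=False
else:
    # Build it once (dummy tensors here) and cache it to disk
    valDataset = TensorDataset(torch.randn(8, 3, 224, 224), torch.rand(8))
    torch.save(valDataset, cachePath)

valLoader = DataLoader(valDataset, batch_size=4, shuffle=False)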
@@ -66,24 +69,25 @@ model = resnet50(pretrained=True)
 for parameter in model.parameters():
     parameter.requires_grad = False
+# Replace last layer with a FC single output layer
 modelOutputFeatures = model.fc.in_features
 model.fc = nn.Linear(modelOutputFeatures, 1)
 model = model.to(config.DEVICE)
-# initialize loss function and optimizer
-lossFunction = nn.L1Loss() # mean absolute error
+# Initialize otimizer and loss function
 optimizer = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
+lossFunction = nn.L1Loss() # mean absolute error
-# calculate steps per epoch for training and validating set
+# Calculate steps for train and validation
 trainSteps = len(trainDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
 valSteps = len(valDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
-# initialize a dictionary to store training data
-H = {"train_loss": [], "val_loss": []}
-# loop over epochs
-print("Starting training...")
+# Store training data
+dataDict = {"train_loss": [], "val_loss": []}
+# Loop over epochs
+print("Training...")
 startTime = time.time()
 for epoch in tqdm(range(config.EPOCHS)):
     model.train()
@@ -91,8 +95,6 @@ for epoch in tqdm(range(config.EPOCHS)):
     totalTrainLoss = 0
     totalValLoss = 0
-    trainCorrect = 0
-    valCorrect = 0
     for (i, (x, y)) in enumerate(trainLoader):
         (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
@@ -108,8 +110,6 @@ for epoch in tqdm(range(config.EPOCHS)):
         optimizer.zero_grad()
         totalTrainLoss += loss
-        trainCorrect += (pred.argmax(1) == y).type(
-            torch.float).sum().item()
     with torch.no_grad():
         model.eval()
@@ -121,36 +121,32 @@ for epoch in tqdm(range(config.EPOCHS)):
             new_shape = (len(y), 1)
             y = y.view(new_shape)
             totalValLoss += lossFunction(pred, y)
-            valCorrect += (pred.argmax(1) == y).type(
-                torch.float).sum().item()
-    # calculate the average training and validation loss
+    # Calculate the average training and validation loss
     avgTrainLoss = totalTrainLoss / trainSteps
     avgValLoss = totalValLoss / valSteps
-    # calculate the training and validation accuracy
-    #trainCorrect = trainCorrect / len(trainDataset)
-    #valCorrect = valCorrect / len(valDataset)
-    # update our training history
-    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
-    H["val_loss"].append(avgValLoss.cpu().detach().numpy())
-    # print the model training and validation information
-    print("[INFO] EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
+    # Add to history
+    dataDict["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
+    dataDict["val_loss"].append(avgValLoss.cpu().detach().numpy())
+    # Print end of epoch progress
+    print("EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
     print("Train loss: {:.6f}, Val loss: {:.6f}".format(
         avgTrainLoss, avgValLoss))
-# display the total time needed to perform the training
+# Display time taken for training
 endTime = time.time()
-print("[INFO] total time taken to train the model: {:.2f}s".format(
+print("Total time taken to train the model: {:.2f}s".format(
     endTime - startTime))
-# plot the training loss and accuracy
+# Plot the training and validation loss
 plt.style.use("ggplot")
 plt.figure()
-plt.plot(H["train_loss"], label="train_loss")
-plt.plot(H["val_loss"], label="val_loss")
+plt.plot(dataDict["train_loss"], label="train_loss")
+plt.plot(dataDict["val_loss"], label="val_loss")
 plt.title("Training Loss on Dataset")
 plt.xlabel("Epoch #")
 plt.ylabel("Loss")
 plt.legend(loc="lower left")
 plt.savefig(INITIAL_PLOT_PATH)
-# serialize the model to disk
+# Save model
 torch.save(model, INTIIAL_MODEL_PATH)
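
Aside (not part of the diff): a self-contained sketch of the transfer-learning setup this training script uses, a pretrained ResNet-50 with every backbone weight frozen, a new single-output regression head, L1 loss, and Adam optimizing only the head. The learning rate and tensor shapes are placeholders for config values not shown here:

import torch
import torch.nn as nn
from torchvision.models import resnet50

device = "cuda" if torch.cuda.is_available() else "cpu"

model = resnet50(pretrained=True)  # newer torchvision prefers resnet50(weights="IMAGENET1K_V1")
for parameter in model.parameters():
    parameter.requires_grad = False  # freeze the backbone

model.fc = nn.Linear(model.fc.in_features, 1)  # trainable single-output regression head
model = model.to(device)

lossFunction = nn.L1Loss()  # mean absolute error
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)  # placeholder for config.LR

# One dummy step to show the shapes: images (N, 3, H, W) -> scores (N, 1)
x = torch.randn(4, 3, 224, 224, device=device)
y = torch.rand(4, 1, device=device)
loss = lossFunction(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()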

View file

@@ -24,19 +24,18 @@ parser.add_argument('image', type=os.path.abspath, metavar='image-location', nar
     help='path(s) to input image(s)')
 args = vars(parser.parse_args())
+# Load model
 model = torch.load(args["model"], config.DEVICE)
 model.to(config.DEVICE)
 # Declare transforms
-# build our data pre-processing pipeline
 valTransform = transforms.Compose([
     transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
     transforms.ToTensor(),
     transforms.Normalize(mean=config.MEAN, std=config.STD)
 ])
-# calulate the std dev and inverse mean
-# calculate the inverse mean and standard deviation
+# Calulate the inverse std and inverse mean
 invMean = [-m/s for (m, s) in zip(config.MEAN, config.STD)]
 invStd = [1/s for s in config.STD]
@@ -44,7 +43,7 @@ invStd = [1/s for s in config.STD]
 deNormalize = transforms.Normalize(mean=invMean, std=invStd)
 # load dataset and dataloader
-print("[INFO] loading the dataset...")
+print("Loading dataset...")
 valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
 trainSetLen = len(dataset.df) - valSetLen
 valSet = dataset.df[trainSetLen:]
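
Aside (not part of the commit): why the invMean/invStd values above undo the normalization. Normalize maps each channel to x' = (x - mean) / std, so a second Normalize with mean' = -mean/std and std' = 1/std gives (x' - mean') / std' = ((x - mean)/std + mean/std) * std = x, recovering the original pixel values. A quick check, using the standard ImageNet statistics as stand-ins for config.MEAN and config.STD:

import torch
from torchvision import transforms

mean = [0.485, 0.456, 0.406]  # ImageNet stats, standing in for config.MEAN
std = [0.229, 0.224, 0.225]   # ... and config.STD

normalize = transforms.Normalize(mean=mean, std=std)
deNormalize = transforms.Normalize(mean=[-m / s for m, s in zip(mean, std)],
                                   std=[1 / s for s in std])

x = torch.rand(3, 8, 8)  # fake image with values in [0, 1]
assert torch.allclose(deNormalize(normalize(x)), x, atol=1e-5)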
@@ -57,55 +56,58 @@ if torch.cuda.is_available():
 else:
     map_location = "cpu"
-print("[INFO] loading the model...")
+# Load in model
+print("Loading model...")
 model = torch.load(args["model"], map_location=map_location)
 model.to(config.DEVICE)
 model.eval()
+# Process batch
 batch = next(iter(valLoader))
 (images, ratings) = (batch[0], batch[1])
+# Declare figure
 fig = plt.figure("Results", figsize=(10, 10))
 with torch.no_grad():
-    # send the images to the device
+    # Send the images to CPU/GPU
     images = images.to(config.DEVICE)
-    # make the predictions
-    print("[INFO] predicting...")
+    # Make a prediction on the images
+    print("Predicting...")
     preds = model(images)
-    # loop over all the batch
+    # Loop over each element in the batch
     for i in range(0, config.PRED_BATCH_SIZE):
-        # initalize a subplot
+        # Initalize a subplot
         ax = plt.subplot(config.PRED_BATCH_SIZE, 1, i + 1)
-        # grab the image, de-normalize it, scale the raw pixel
-        # intensities to the range [0, 255], and change the channel
-        # ordering from channels first tp channels last
+        # De-normalize image, scale the pixel range to 255
         image = images[i]
         image = deNormalize(image).cpu().numpy()
         image = (image * 255).astype("uint8")
         image = image.transpose((1, 2, 0))
-        # grab the ground truth label 5 decimal places
+        # Retrieve the ground truth to 5 decimal places
         gtRating = round(ratings[i].cpu().numpy().tolist(), 5)
-        # grab the predicted label 5 decimal places
+        # Retrieve the prediction to 5 decimal places
         pred = round(preds[i].cpu().numpy().tolist()[0], 5)
-        # calculate percentage difference
+        # Calculate percentage difference
         if pred > gtRating:
            percentage = round(((pred/gtRating) - 1) * 100, 5)
            diff = "+" + str(percentage)
         else:
            percentage = round(((gtRating/pred) - 1) * 100, 5)
            diff = "-" + str(percentage)
-        # add the results and image to the plot
+        # Add the ground truth, prediction and % diff to the plot
         info = "Ground Truth: {}, Predicted: {}, Diff: {}%".format(gtRating,
            pred, diff)
         plt.imshow(image)
         plt.title(info)
         plt.axis("off")
-# show the plot
+# Save plot to disk
 plt.tight_layout()
 date = datetime.datetime.now()
 dateString = str(date.year) + "-" + str(date.month) + "-" + str(date.day) + "_" + str(date.hour) + "-" + str(date.minute) + "-" + str(date.second)
 PLOT_PATH = os.path.join(PLOT_PATH, "predict-plot-" + dateString + ".png")
 plt.savefig(PLOT_PATH)
-plt.show()
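
Aside (not from the diff): a worked example of the signed difference label built above. When the prediction is above the ground truth, the script reports how far the prediction overshoots relative to the ground truth; otherwise it reports how far the ground truth exceeds the prediction, relative to the prediction, so the two branches are not symmetric. The helper name below is hypothetical:

def diff_label(gtRating: float, pred: float) -> str:
    # Mirrors the branch in the prediction script; equal values fall into the "-" branch
    if pred > gtRating:
        return "+" + str(round(((pred / gtRating) - 1) * 100, 5))
    return "-" + str(round(((gtRating / pred) - 1) * 100, 5))

print(diff_label(5.0, 5.5))  # +10.0  (prediction 10% above the ground truth)
print(diff_label(5.0, 4.0))  # -25.0  (ground truth 25% above the prediction)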