Added initial training models and results
This commit is contained in:
parent
eccc80425a
commit
e1b7a87bf2
15 changed files with 4146 additions and 2 deletions
2
.gitattributes
vendored
2
.gitattributes
vendored
|
@ -2,3 +2,5 @@
|
||||||
*.webm filter=lfs diff=lfs merge=lfs -text
|
*.webm filter=lfs diff=lfs merge=lfs -text
|
||||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.png filter=lfs diff=lfs merge=lfs -text
|
||||||
|
|
BIN
docs/mid-project.odp
Normal file
BIN
docs/mid-project.odp
Normal file
Binary file not shown.
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
|
@ -8,7 +8,7 @@ import numpy as np
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
import cv2
|
import cv2
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import config
|
from . import config
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
|
@ -71,7 +71,7 @@ model.fc = nn.Linear(modelOutputFeatures, 1)
|
||||||
model = model.to(config.DEVICE)
|
model = model.to(config.DEVICE)
|
||||||
|
|
||||||
# initialize loss function and optimizer
|
# initialize loss function and optimizer
|
||||||
lossFunction = nn.L1Loss()
|
lossFunction = nn.L1Loss() # mean absolute error
|
||||||
optimizer = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
|
optimizer = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
|
||||||
|
|
||||||
|
|
||||||
|
|
BIN
src/output/21-03-2022/model.pth
(Stored with Git LFS)
Normal file
BIN
src/output/21-03-2022/model.pth
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
src/output/21-03-2022/plot.png
(Stored with Git LFS)
Normal file
BIN
src/output/21-03-2022/plot.png
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
src/output/22-03-2022/model.pth
(Stored with Git LFS)
Normal file
BIN
src/output/22-03-2022/model.pth
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
src/output/22-03-2022/plot.png
(Stored with Git LFS)
Normal file
BIN
src/output/22-03-2022/plot.png
(Stored with Git LFS)
Normal file
Binary file not shown.
4007
src/output/22-03-2022/slurm-3608.out
Normal file
4007
src/output/22-03-2022/slurm-3608.out
Normal file
File diff suppressed because it is too large
Load diff
BIN
src/output/model.pth
(Stored with Git LFS)
Normal file
BIN
src/output/model.pth
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
src/output/plot.png
(Stored with Git LFS)
Normal file
BIN
src/output/plot.png
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
src/output/predict-plot-2022-3-23_4-40-19.png
(Stored with Git LFS)
Normal file
BIN
src/output/predict-plot-2022-3-23_4-40-19.png
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
src/output/predict-plot.png
(Stored with Git LFS)
Normal file
BIN
src/output/predict-plot.png
(Stored with Git LFS)
Normal file
Binary file not shown.
111
src/predict.py
Normal file
111
src/predict.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from os.path import abspath
|
||||||
|
from autophotographer import config
|
||||||
|
from autophotographer import dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
# set project root for fetching files using relative file paths
|
||||||
|
script_directory = os.path.dirname(__file__)
|
||||||
|
projectRoot = abspath(os.path.join(script_directory, "../"))
|
||||||
|
print(projectRoot)
|
||||||
|
PLOT_PATH = os.path.join(projectRoot, "src/output/")
|
||||||
|
|
||||||
|
# parse arguments for image to predict and model to use
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-m", "--model", type=os.path.abspath, required=True,
|
||||||
|
help="path to trained model model")
|
||||||
|
parser.add_argument('image', type=os.path.abspath, metavar='image-location', nargs='+',
|
||||||
|
help='path(s) to input image(s)')
|
||||||
|
args = vars(parser.parse_args())
|
||||||
|
|
||||||
|
model = torch.load(args["model"], config.DEVICE)
|
||||||
|
model.to(config.DEVICE)
|
||||||
|
|
||||||
|
# Declare transforms
|
||||||
|
# build our data pre-processing pipeline
|
||||||
|
valTransform = transforms.Compose([
|
||||||
|
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(mean=config.MEAN, std=config.STD)
|
||||||
|
])
|
||||||
|
|
||||||
|
# calulate the std dev and inverse mean
|
||||||
|
# calculate the inverse mean and standard deviation
|
||||||
|
invMean = [-m/s for (m, s) in zip(config.MEAN, config.STD)]
|
||||||
|
invStd = [1/s for s in config.STD]
|
||||||
|
|
||||||
|
# define de-normalization transform
|
||||||
|
deNormalize = transforms.Normalize(mean=invMean, std=invStd)
|
||||||
|
|
||||||
|
# load dataset and dataloader
|
||||||
|
print("[INFO] loading the dataset...")
|
||||||
|
valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
|
||||||
|
trainSetLen = len(dataset.df) - valSetLen
|
||||||
|
valSet = dataset.df[trainSetLen:]
|
||||||
|
(valDataset, valLoader) = dataset.get_dataloader(valSet,
|
||||||
|
transforms=valTransform, batchSize=config.PRED_BATCH_SIZE,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
map_location = lambda storage, loc: storage.cuda()
|
||||||
|
else:
|
||||||
|
map_location = "cpu"
|
||||||
|
|
||||||
|
print("[INFO] loading the model...")
|
||||||
|
model = torch.load(args["model"], map_location=map_location)
|
||||||
|
|
||||||
|
model.to(config.DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch = next(iter(valLoader))
|
||||||
|
(images, ratings) = (batch[0], batch[1])
|
||||||
|
|
||||||
|
fig = plt.figure("Results", figsize=(10, 10))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
# send the images to the device
|
||||||
|
images = images.to(config.DEVICE)
|
||||||
|
# make the predictions
|
||||||
|
print("[INFO] predicting...")
|
||||||
|
preds = model(images)
|
||||||
|
# loop over all the batch
|
||||||
|
for i in range(0, config.PRED_BATCH_SIZE):
|
||||||
|
# initalize a subplot
|
||||||
|
ax = plt.subplot(config.PRED_BATCH_SIZE, 1, i + 1)
|
||||||
|
# grab the image, de-normalize it, scale the raw pixel
|
||||||
|
# intensities to the range [0, 255], and change the channel
|
||||||
|
# ordering from channels first tp channels last
|
||||||
|
image = images[i]
|
||||||
|
image = deNormalize(image).cpu().numpy()
|
||||||
|
image = (image * 255).astype("uint8")
|
||||||
|
image = image.transpose((1, 2, 0))
|
||||||
|
# grab the ground truth label 5 decimal places
|
||||||
|
gtRating = round(ratings[i].cpu().numpy().tolist(), 5)
|
||||||
|
# grab the predicted label 5 decimal places
|
||||||
|
pred = round(preds[i].cpu().numpy().tolist()[0], 5)
|
||||||
|
# calculate percentage difference
|
||||||
|
if pred > gtRating:
|
||||||
|
percentage = round(((pred/gtRating) - 1) * 100, 5)
|
||||||
|
diff = "+" + str(percentage)
|
||||||
|
else:
|
||||||
|
percentage = round(((gtRating/pred) - 1) * 100, 5)
|
||||||
|
diff = "-" + str(percentage)
|
||||||
|
# add the results and image to the plot
|
||||||
|
info = "Ground Truth: {}, Predicted: {}, Diff: {}%".format(gtRating,
|
||||||
|
pred, diff)
|
||||||
|
plt.imshow(image)
|
||||||
|
plt.title(info)
|
||||||
|
plt.axis("off")
|
||||||
|
# show the plot
|
||||||
|
plt.tight_layout()
|
||||||
|
date = datetime.datetime.now()
|
||||||
|
dateString = str(date.year) + "-" + str(date.month) + "-" + str(date.day) + "_" + str(date.hour) + "-" + str(date.minute) + "-" + str(date.second)
|
||||||
|
PLOT_PATH = os.path.join(PLOT_PATH, "predict-plot-" + dateString + ".png")
|
||||||
|
plt.savefig(PLOT_PATH)
|
||||||
|
plt.show()
|
Reference in a new issue