Refactored and removed dead code in model and predict

This commit is contained in:
Oscar Blue 2022-03-23 10:32:11 +00:00
parent e1b7a87bf2
commit 68ff4ea0ad
2 changed files with 47 additions and 49 deletions

View file

@ -23,7 +23,7 @@ INTIIAL_MODEL_PATH = os.path.join(projectRoot, "src/output/model.pth")
valDatasetPath = os.path.join(projectRoot, "src/autophotographer/valDataset.pt")
trainDatasetPath = os.path.join(projectRoot, "src/autophotographer/trainDataset.pt")
# define transformations
# Declare transformations
trainTransform = transforms.Compose([
transforms.RandomResizedCrop(config.IMAGE_SIZE),
transforms.RandomHorizontalFlip(),
@ -37,12 +37,14 @@ valTransform = transforms.Compose([
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
# Split dataset into val and train
valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
trainSetLen = len(dataset.df) - valSetLen
trainSet = dataset.df[:trainSetLen]
valSet = dataset.df[trainSetLen:]
print("Using " + config.DEVICE + "...")
# create data loaders
print("Getting dataloaders...")
#(trainDataset, trainLoader) = dataset.get_dataloader(trainSet,
@ -52,6 +54,7 @@ print("Getting dataloaders...")
#transforms=valTransform, batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)
#torch.save(valDataset, 'valDataset.pt')
# Load datasets tensors from disk
valDataset = torch.load(valDatasetPath)
valLoader = DataLoader(valDataset, batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False, num_workers=os.cpu_count(),
pin_memory=True if config.DEVICE == "cuda" else False)
@ -66,24 +69,25 @@ model = resnet50(pretrained=True)
for parameter in model.parameters():
parameter.requires_grad = False
# Replace last layer with a FC single output layer
modelOutputFeatures = model.fc.in_features
model.fc = nn.Linear(modelOutputFeatures, 1)
model = model.to(config.DEVICE)
# initialize loss function and optimizer
lossFunction = nn.L1Loss() # mean absolute error
# Initialize otimizer and loss function
optimizer = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
lossFunction = nn.L1Loss() # mean absolute error
# calculate steps per epoch for training and validating set
# Calculate steps for train and validation
trainSteps = len(trainDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDataset) // config.FEATURE_EXTRACTION_BATCH_SIZE
# initialize a dictionary to store training data
H = {"train_loss": [], "val_loss": []}
# Store training data
dataDict = {"train_loss": [], "val_loss": []}
# loop over epochs
print("Starting training...")
# Loop over epochs
print("Training...")
startTime = time.time()
for epoch in tqdm(range(config.EPOCHS)):
model.train()
@ -91,8 +95,6 @@ for epoch in tqdm(range(config.EPOCHS)):
totalTrainLoss = 0
totalValLoss = 0
trainCorrect = 0
valCorrect = 0
for (i, (x, y)) in enumerate(trainLoader):
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
@ -108,8 +110,6 @@ for epoch in tqdm(range(config.EPOCHS)):
optimizer.zero_grad()
totalTrainLoss += loss
trainCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
with torch.no_grad():
model.eval()
@ -121,36 +121,32 @@ for epoch in tqdm(range(config.EPOCHS)):
new_shape = (len(y), 1)
y = y.view(new_shape)
totalValLoss += lossFunction(pred, y)
valCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
# calculate the average training and validation loss
# Calculate the average training and validation loss
avgTrainLoss = totalTrainLoss / trainSteps
avgValLoss = totalValLoss / valSteps
# calculate the training and validation accuracy
#trainCorrect = trainCorrect / len(trainDataset)
#valCorrect = valCorrect / len(valDataset)
# update our training history
H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
H["val_loss"].append(avgValLoss.cpu().detach().numpy())
# print the model training and validation information
print("[INFO] EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
# Add to history
dataDict["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
dataDict["val_loss"].append(avgValLoss.cpu().detach().numpy())
# Print end of epoch progress
print("EPOCH: {}/{}".format(epoch + 1, config.EPOCHS))
print("Train loss: {:.6f}, Val loss: {:.6f}".format(
avgTrainLoss, avgValLoss))
# display the total time needed to perform the training
# Display time taken for training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
print("Total time taken to train the model: {:.2f}s".format(
endTime - startTime))
# plot the training loss and accuracy
# Plot the training and validation loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(dataDict["train_loss"], label="train_loss")
plt.plot(dataDict["val_loss"], label="val_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(INITIAL_PLOT_PATH)
# serialize the model to disk
# Save model
torch.save(model, INTIIAL_MODEL_PATH)

View file

@ -24,19 +24,18 @@ parser.add_argument('image', type=os.path.abspath, metavar='image-location', nar
help='path(s) to input image(s)')
args = vars(parser.parse_args())
# Load model
model = torch.load(args["model"], config.DEVICE)
model.to(config.DEVICE)
# Declare transforms
# build our data pre-processing pipeline
valTransform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
# calulate the std dev and inverse mean
# calculate the inverse mean and standard deviation
# Calulate the inverse std and inverse mean
invMean = [-m/s for (m, s) in zip(config.MEAN, config.STD)]
invStd = [1/s for s in config.STD]
@ -44,7 +43,7 @@ invStd = [1/s for s in config.STD]
deNormalize = transforms.Normalize(mean=invMean, std=invStd)
# load dataset and dataloader
print("[INFO] loading the dataset...")
print("Loading dataset...")
valSetLen = int(len(dataset.df) * config.VAL_SPLIT)
trainSetLen = len(dataset.df) - valSetLen
valSet = dataset.df[trainSetLen:]
@ -57,55 +56,58 @@ if torch.cuda.is_available():
else:
map_location = "cpu"
print("[INFO] loading the model...")
# Load in model
print("Loading model...")
model = torch.load(args["model"], map_location=map_location)
model.to(config.DEVICE)
model.eval()
# Process batch
batch = next(iter(valLoader))
(images, ratings) = (batch[0], batch[1])
# Declare figure
fig = plt.figure("Results", figsize=(10, 10))
with torch.no_grad():
# send the images to the device
# Send the images to CPU/GPU
images = images.to(config.DEVICE)
# make the predictions
print("[INFO] predicting...")
# Make a prediction on the images
print("Predicting...")
preds = model(images)
# loop over all the batch
# Loop over each element in the batch
for i in range(0, config.PRED_BATCH_SIZE):
# initalize a subplot
# Initalize a subplot
ax = plt.subplot(config.PRED_BATCH_SIZE, 1, i + 1)
# grab the image, de-normalize it, scale the raw pixel
# intensities to the range [0, 255], and change the channel
# ordering from channels first tp channels last
# De-normalize image, scale the pixel range to 255
image = images[i]
image = deNormalize(image).cpu().numpy()
image = (image * 255).astype("uint8")
image = image.transpose((1, 2, 0))
# grab the ground truth label 5 decimal places
# Retrieve the ground truth to 5 decimal places
gtRating = round(ratings[i].cpu().numpy().tolist(), 5)
# grab the predicted label 5 decimal places
# Retrieve the prediction to 5 decimal places
pred = round(preds[i].cpu().numpy().tolist()[0], 5)
# calculate percentage difference
# Calculate percentage difference
if pred > gtRating:
percentage = round(((pred/gtRating) - 1) * 100, 5)
diff = "+" + str(percentage)
else:
percentage = round(((gtRating/pred) - 1) * 100, 5)
diff = "-" + str(percentage)
# add the results and image to the plot
# Add the ground truth, prediction and % diff to the plot
info = "Ground Truth: {}, Predicted: {}, Diff: {}%".format(gtRating,
pred, diff)
plt.imshow(image)
plt.title(info)
plt.axis("off")
# show the plot
# Save plot to disk
plt.tight_layout()
date = datetime.datetime.now()
dateString = str(date.year) + "-" + str(date.month) + "-" + str(date.day) + "_" + str(date.hour) + "-" + str(date.minute) + "-" + str(date.second)
PLOT_PATH = os.path.join(PLOT_PATH, "predict-plot-" + dateString + ".png")
plt.savefig(PLOT_PATH)
plt.show()