Refactored for portability
parent 9cfdb0b91f
commit eccc80425a
4 changed files with 255518 additions and 8 deletions
@@ -10,11 +10,9 @@ VAL_SPLIT = 0.1
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

 FEATURE_EXTRACTION_BATCH_SIZE = 256
-FINETUNE_BATCH_SIZE = 64
 PRED_BATCH_SIZE = 4
 EPOCHS = 20
 LR = 0.001
-LR_FINETUNE = 0.0005 # REMOVE
 IMAGE_SIZE = 32

 WARMUP_PLOT = os.path.join("output", "plot.png")
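For context on the hunk above: the finetuning constants are removed, and the settings that remain already follow the pattern this commit pushes for, a DEVICE that falls back to CPU and paths built with os.path.join. A minimal sketch of how such a config module reads as a whole; the imports are an assumption, since the hunk does not show the top of the file.

    import os
    import torch

    # Use the GPU when one is available, otherwise fall back to the CPU.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    FEATURE_EXTRACTION_BATCH_SIZE = 256
    PRED_BATCH_SIZE = 4
    EPOCHS = 20
    LR = 0.001
    IMAGE_SIZE = 32

    # os.path.join uses the separator of the host OS, so the plot path
    # stays valid on POSIX and Windows alike.
    WARMUP_PLOT = os.path.join("output", "plot.png")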
src/autophotographer/dataframe.csv: new file, 255509 additions (file diff suppressed because it is too large)
@@ -24,10 +24,10 @@ projectRoot = abspath(os.path.join(script_directory, "../.."))
 #projectRoot = "/src/"
 print(projectRoot)
 tensorImagesPath = os.path.join(projectRoot, "src/autophotographer/tensorImages.pt")
-tensorImagesPath = os.path.join(projectRoot, "src/autophotographer/tensorImages.pt")
 tensorRatingsPath = os.path.join(projectRoot, "src/autophotographer/tensorRatings.pt")
 tensorArrayPath = os.path.join(projectRoot, "src/autophotographer/tensorArray.pt")
 filePathRatings = os.path.join(projectRoot, "data/ratings.txt")
+dataframePath = os.path.join(projectRoot, "src/autophotographer/dataframe.csv")

 if not datasetDir == "":
     filePathStyle = datasetDir + "AVA/style_image_lists/test.multilab"
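All of the paths in this hunk hang off projectRoot, which the hunk header shows being derived from the module's own location rather than from a fixed mount point like the commented-out "/src/". A short illustration of how that resolution works; the checkout location in the comments is hypothetical.

    import os
    from os.path import abspath

    # Suppose this module lives at /home/user/autophotographer/src/autophotographer/dataset.py
    script_directory = os.path.dirname(__file__)                    # .../src/autophotographer
    projectRoot = abspath(os.path.join(script_directory, "../.."))  # .../autophotographer, the repo root

    # Artefacts are then addressed relative to the repo root, so the code
    # works wherever the repository happens to be cloned.
    dataframePath = os.path.join(projectRoot, "src/autophotographer/dataframe.csv")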
@@ -151,7 +151,8 @@ def build_dataframe(df, imgPath):
     df['path'] = imagePaths
     return df

-df = build_dataframe(remove_entries_for_missing_images(load_image_ratings(), imgPath), imgPath)
+#df = build_dataframe(remove_entries_for_missing_images(load_image_ratings(), imgPath), imgPath)
+df = pd.read_csv(dataframePath, index_col = 0)

 def create_tensor_array():
     tensorArray = []
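The change above replaces the expensive build_dataframe(...) call with a read of the dataframe.csv that this commit checks in. The write side is not shown in the diff; a sketch of the likely round trip, where the to_csv call and the stand-in path are assumptions.

    import pandas as pd

    # Stand-in for the dataframePath built from projectRoot in dataset.py.
    dataframePath = "dataframe.csv"

    # One-off (assumed): build the dataframe once and cache it as CSV.
    # df = build_dataframe(remove_entries_for_missing_images(load_image_ratings(), imgPath), imgPath)
    # df.to_csv(dataframePath)

    # Subsequent runs reload the cached copy instead of rebuilding it;
    # index_col=0 restores the index that to_csv wrote as the first column.
    df = pd.read_csv(dataframePath, index_col=0)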
@@ -18,8 +18,10 @@ import dataset
 script_directory = os.path.dirname(__file__)
 projectRoot = abspath(os.path.join(script_directory, "../.."))
 print("Project root: " + projectRoot)
-INITIAL_PLOT_PATH = projectRoot + "/src/output/plot.png"
-INTIIAL_MODEL_PATH = projectRoot + "/src/output/model.pth"
+INITIAL_PLOT_PATH = os.path.join(projectRoot, "src/output/plot.png")
+INTIIAL_MODEL_PATH = os.path.join(projectRoot, "src/output/model.pth")
+valDatasetPath = os.path.join(projectRoot, "src/autophotographer/valDataset.pt")
+trainDatasetPath = os.path.join(projectRoot, "src/autophotographer/trainDataset.pt")

 # define transformations
 trainTransform = transforms.Compose([
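The two replaced constants are the heart of the portability fix: string concatenation hard-codes "/" and silently depends on projectRoot not ending in a separator, while os.path.join inserts the host's separator between its arguments. A small comparison with a hypothetical checkout path.

    import os

    # Hypothetical checkout location; on Windows it might be r"C:\work\autophotographer".
    projectRoot = "/home/user/autophotographer"

    # Old style: manual concatenation with a hard-coded "/".
    old_plot = projectRoot + "/src/output/plot.png"

    # New style: os.path.join supplies the separator between its arguments.
    new_plot = os.path.join(projectRoot, "src/output/plot.png")

    print(old_plot)
    print(new_plot)  # identical on POSIX; joined with "\" before "src" on Windows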
@@ -50,10 +52,10 @@ print("Getting dataloaders...")
 #transforms=valTransform, batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)
 #torch.save(valDataset, 'valDataset.pt')

-valDataset = torch.load("/src/src/autophotographer/valDataset.pt")
+valDataset = torch.load(valDatasetPath)
 valLoader = DataLoader(valDataset, batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False, num_workers=os.cpu_count(),
     pin_memory=True if config.DEVICE == "cuda" else False)
-trainDataset = torch.load("/src/src/autophotographer/trainDataset.pt")
+trainDataset = torch.load(trainDatasetPath)
 trainLoader = DataLoader(trainDataset, batch_size=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=True, num_workers=os.cpu_count(),
     pin_memory=True if config.DEVICE == "cuda" else False)

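With the hard-coded "/src/src/..." mount paths gone, the serialized datasets load from wherever the repository lives, and the DataLoader settings stay keyed to the config module quoted in the first hunk. A self-contained sketch of that loading pattern; DEVICE, the batch size, and the toy TensorDataset below are stand-ins for config.DEVICE, config.FEATURE_EXTRACTION_BATCH_SIZE, and the torch.load(...) results.

    import os
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-ins for the values that train.py pulls from config.py.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    FEATURE_EXTRACTION_BATCH_SIZE = 256

    # A toy in-memory dataset standing in for torch.load(trainDatasetPath).
    dataset = TensorDataset(torch.randn(1024, 3, 32, 32), torch.randint(0, 10, (1024,)))

    # pin_memory only pays off when batches are copied to a CUDA device,
    # hence the check against DEVICE; num_workers spreads loading over CPU cores.
    loader = DataLoader(dataset,
                        batch_size=FEATURE_EXTRACTION_BATCH_SIZE,
                        shuffle=True,
                        num_workers=os.cpu_count(),
                        pin_memory=(DEVICE == "cuda"))

    if __name__ == "__main__":  # guard needed for worker processes on spawn-based platforms
        images, labels = next(iter(loader))
        print(images.shape, labels.shape)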