Commit 13afb2ac authored by Oliver Wirth's avatar Oliver Wirth

Cleanup and additional documentation

parent 9e90b432
......@@ -94,6 +94,13 @@
" return df, label"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pre-processing transformations."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -108,8 +115,20 @@
"])\n",
"\n",
"# Load train and test set\n",
"trainset = RegionDataset(train_path / '004', all_labels, transform = transform)\n",
"testset = RegionDataset(test_path / 'dummy', all_labels, transform = transform)"
"trainset = RegionDataset(train_path / '004', all_labels, transform = transform)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Data split and batching hyperparameters.\n",
"Three different collate functions are available:\n",
"- `collate_pack` packs timeseries into a PyTorch `PackedSequence` object, which allows masking in recurrent networks\n",
"- `collate_crop` crops timeseries to the length of the shortest series in the batch; data points are removed from the front, as later data points likely contribute more information toward the classification goal\n",
"- `collate_pad` pads sequences to the length of the longest sequence in the batch\n",
"\n",
"Experiments showed that `collate_crop` works best for CNNs."
]
},
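The three collate strategies listed above can be sketched as follows. This is an illustrative sketch only; the notebook's actual implementations are not shown in this diff and may differ in detail. It assumes each sample is a `(series, label)` pair with the series shaped `(length, channels)`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence, pack_sequence

def collate_crop(batch):
    # Crop every series to the shortest length in the batch, dropping
    # points from the front (later points likely carry more signal).
    min_len = min(s.shape[0] for s, _ in batch)
    samples = torch.stack([s[-min_len:] for s, _ in batch])
    labels = torch.tensor([l for _, l in batch])
    return samples, labels

def collate_pad(batch):
    # Zero-pad every series to the longest length in the batch.
    samples = pad_sequence([s for s, _ in batch], batch_first=True)
    labels = torch.tensor([l for _, l in batch])
    return samples, labels

def collate_pack(batch):
    # Pack the batch (sorted by decreasing length) so recurrent
    # networks can skip the padded timesteps entirely.
    batch = sorted(batch, key=lambda p: p[0].shape[0], reverse=True)
    samples = pack_sequence([s for s, _ in batch])
    labels = torch.tensor([l for _, l in batch])
    return samples, labels
```

A `DataLoader` would receive one of these via its `collate_fn` argument, as done in the cell below.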
{
......@@ -151,7 +170,6 @@
"batch_size = 8\n",
"collate_fn = collate_crop\n",
"trainloader = DataLoader(trainset, batch_size = batch_size, collate_fn = collate_fn, shuffle = True)\n",
"testloader = DataLoader(testset , batch_size = batch_size, collate_fn = collate_fn, shuffle = False)\n",
"valloader = DataLoader(valset , batch_size = batch_size, collate_fn = collate_fn, shuffle = False)"
]
},
......@@ -159,7 +177,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Models"
"## Models\n",
"\n",
"Various RNN and CNN models, starting with an LSTM-based RNN."
]
},
{
......@@ -196,6 +216,13 @@
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similar to the network above, but with a GRU instead of an LSTM."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -230,6 +257,13 @@
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One-dimensional CNN with three convolutional layers."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -282,6 +316,13 @@
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Another CNN with depth-wise separable convolutions, which reduce the number of parameters while maintaining classification performance."
]
},
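A depth-wise separable convolution factors a full convolution into a per-channel (depth-wise) filter followed by a 1x1 point-wise mix. The block below is a minimal sketch of the idea, not the notebook's exact layer configuration:

```python
import torch
import torch.nn as nn

class SeparableConv1d(nn.Module):
    """Depth-wise separable 1D convolution (illustrative sketch)."""
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        # Depth-wise: one filter per input channel (groups = in_channels)
        self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size,
                                   padding=kernel_size // 2, groups=in_channels)
        # Point-wise: 1x1 convolution that mixes channels
        self.pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# Parameter comparison against a full convolution of the same shape
sep = SeparableConv1d(16, 32, 5)
full = nn.Conv1d(16, 32, 5, padding=2)
n_sep = sum(p.numel() for p in sep.parameters())
n_full = sum(p.numel() for p in full.parameters())
```

Here the separable variant needs roughly a quarter of the parameters of the full convolution while producing output of the same shape.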
{
"cell_type": "code",
"execution_count": null,
......@@ -347,6 +388,13 @@
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Additional regularization in the form of batch normalization. This and the previous model proved to be the best-performing on single training regions."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -508,7 +556,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train model"
"## Train model\n",
"\n",
"Train the model for the specified number of epochs, or stop early once the scheduler drives the learning rate below a threshold."
]
},
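The stopping criterion described above can be sketched like this. It is a minimal, self-contained illustration with a placeholder validation loss and assumed hyperparameter names (`min_lr`, `epochs`); the notebook's actual training loop is not shown in this diff.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# Plateau scheduler: shrink the learning rate when validation loss stalls
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=2)
min_lr, epochs = 1e-5, 100

for epoch in range(epochs):
    # ... training and validation passes would go here ...
    val_loss = 1.0  # placeholder: compute the real validation loss
    scheduler.step(val_loss)
    # Stop once the scheduler has reduced the learning rate below min_lr
    if optimizer.param_groups[0]['lr'] < min_lr:
        break
```

With a stalled validation loss, the scheduler repeatedly decays the learning rate until the loop exits well before the epoch limit.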
{
......@@ -564,6 +614,7 @@
"metadata": {},
"outputs": [],
"source": [
"# Plot train/val loss/accuracy\n",
"fig, ax = plt.subplots(1, 2, figsize = (16, 6))\n",
"g = sns.lineplot(data = loss_stats, ax = ax[0])\n",
"g = sns.lineplot(data = acc_stats, ax = ax[1])\n",
......@@ -575,25 +626,6 @@
"\n",
"fig.savefig(run_path / 'loss_acc_plot.png')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Evaluate trained model on test set\n",
"running_loss, running_correct = train_loop(model, testloader, criterion)\n",
"n = len(testset)\n",
"print(f'Evaluation after {epoch} epochs: loss = {(running_loss / n):.4f}, acc = {(running_correct / n):.4f}')"
]
}
],
"metadata": {
......
......@@ -98,7 +98,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Region model"
"## Region model\n",
"\n",
"CNN model trained on timeseries from a single region.\n",
"It employs depth-wise separable convolutions to reduce the number of parameters, and dropout and batch normalization for regularization."
]
},
{
......@@ -227,6 +230,13 @@
" return running_loss, running_correct"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define pre-processing operations and training hyperparameters."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -265,6 +275,13 @@
"patience = 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Routine for loading a region's dataset and training a model until convergence."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -354,14 +371,24 @@
" print(f'best validation acc = {acc_stats[\"val\"].max()}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Blacklist already trained regions to save time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# Exclude already trained models (from 2021-06-04 14:23:10.737430)\n",
"blacklist = ['006', '018', '020', '049', '052', '057', '060', '064']\n",
"# Exclude already trained models (from 2021-06-04 14:23:10.737430, 2021-06-05 11:33:17.875281)\n",
"# blacklist = ['004', '011', '037', '006', '017', '018', '020', '029', '049', '052', '055', '057', '060', '064']\n",
"blacklist = []\n",
"\n",
"# Train models for all regions\n",
"for region_path in train_path.iterdir() :\n",
......@@ -373,7 +400,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation"
"## Evaluation\n",
"\n",
"Evaluate the trained models on the test regions by combining the region models into an ensemble."
]
},
{
......@@ -386,21 +415,17 @@
"class RegionEnsemble(nn.Module) :\n",
" \n",
" def __init__(self, run_path) :\n",
" self.models = []\n",
" super().__init__()\n",
" self.models = nn.ModuleDict()\n",
" # Load all previously trained models\n",
" for region_run_path in run_path.iterdir() :\n",
" if region_run_path.is_dir() :\n",
" model = torch.load(region_run_path / 'model_best.pt')\n",
" self.models.append(model)\n",
" \n",
" def train(self, mode = True) :\n",
" for model in self.models :\n",
" model.train(mode)\n",
" return super().train(mode)\n",
" self.models[region_run_path.name] = model\n",
" \n",
" def forward(self, x) :\n",
" # Get region model outputs\n",
" x = torch.stack([model(x) for model in self.models], dim = 1)\n",
" x = torch.stack([model(x) for model in self.models.values()], dim = 1)\n",
" # Convert to probabilities\n",
" x = x.softmax(dim = 2)\n",
" # Average probabilities (confidence-weighted majority vote)\n",
......@@ -408,6 +433,13 @@
" return x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adjust `run_path` if necessary."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -415,10 +447,18 @@
"outputs": [],
"source": [
"# Load ensemble\n",
"ensemble = RegionEnsemble(runs_path / 'best')\n",
"# run_path = runs_path / 'best'\n",
"ensemble = RegionEnsemble(run_path)\n",
"ensemble = ensemble.to(DEVICE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluation routine analogous to the training routine."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -427,11 +467,14 @@
"source": [
"# Calculate binary classification statistics for a test region\n",
"def eval_routine(model, region_path) :\n",
" print(f'=== Region {region_path.name} ===')\n",
" print('Loading dataset...')\n",
" testset = RegionDataset(region_path, all_labels, transform = transform)\n",
" testloader = DataLoader(testset, batch_size = batch_size, collate_fn = collate_fn, shuffle = False)\n",
" model.train(False)\n",
" \n",
" # Pass through model\n",
" print('Evaluating model...')\n",
" tp = tn = fp = fn = 0\n",
" for samples, labels in testloader :\n",
" samples, labels = samples.to(DEVICE), labels.to(DEVICE)\n",
......@@ -460,6 +503,13 @@
" stats.loc[region_path.name] = eval_routine(ensemble, region_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Calculate additional binary classification metrics."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -477,6 +527,13 @@
" return stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluation on a per-region basis."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -489,6 +546,13 @@
"metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluation accumulated across all test regions."
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -507,7 +571,8 @@
" ).rename_axis('ground-truth', axis = 0).rename_axis('prediction', axis = 1),\n",
" annot = True,\n",
" cbar = False,\n",
" cmap = 'mako_r'\n",
" cmap = 'mako_r',\n",
" fmt = 'g'\n",
")\n",
"plt.savefig('confusion_matrix.png')\n",
"\n",
......