Commit 62e4070c authored by PauTheu's avatar PauTheu
Browse files
parents d4be135b 50dd7e41
......@@ -411,6 +411,10 @@
"metadata": {},
"outputs": [],
"source": [
"def output_to_onehot(x, dim = -1) :\n",
" maxidx = x.argmax(dim = dim, keepdims = True)\n",
" return torch.zeros_like(x).scatter(dim = dim, index = maxidx, value = 1)\n",
"\n",
"# Ensemble model consisting of pre-trained region models\n",
"class RegionEnsemble(nn.Module) :\n",
" \n",
......@@ -423,13 +427,18 @@
" model = torch.load(region_run_path / 'model_best.pt')\n",
" self.models[region_run_path.name] = model\n",
" \n",
" def confidence(self, x, dim = -1) :\n",
" x = x.softmax(dim = dim)\n",
" x = 1 + (x * torch.log(x)).sum(dim = dim, keepdims = True)\n",
" return x\n",
" \n",
" def forward(self, x) :\n",
" # Get region model outputs\n",
" x = torch.stack([model(x) for model in self.models.values()], dim = 1)\n",
" # Convert to probabilities\n",
" x = x.softmax(dim = 2)\n",
" # Average probabilities (confidence-weighted majority vote)\n",
" x = x.mean(dim = 1)\n",
" # Convert to confidence-weighted probabilities\n",
" x = x.softmax(dim = 2) * self.confidence(x, dim = 2)\n",
" # Sum weighted probabilities\n",
" x = x.sum(dim = 1)\n",
" return x"
]
},
......@@ -574,7 +583,7 @@
" cmap = 'mako_r',\n",
" fmt = 'g'\n",
")\n",
"plt.savefig('confusion_matrix.png')\n",
"plt.savefig(run_path / 'confusion_matrix.png')\n",
"\n",
"metrics"
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment