nn_extracted_features.ipynb 58.3 KB
Newer Older
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "6b0d97d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "import seaborn as sns\n",
    "\n",
    "from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from tqdm import trange"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02c7a9b0",
   "metadata": {},
   "source": [
    "# Windturbineausfall Vorhersage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "7911fb2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prep_data(data):\n",
    "    # split data from labels\n",
    "    X,y = data.drop(columns=['label']), data.label\n",
    "\n",
    "    categorical = pd.DataFrame()\n",
    "    categorical_embedding_sizes = []\n",
    "    numerical = pd.DataFrame()\n",
    "\n",
    "    for column in X.columns:\n",
50
    "        # transform region into numerical vector\n",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
51
52
    "        if column == 'region':\n",
    "            X[column].fillna('not specified', inplace=True)\n",
53
    "            # transform region into categories\n",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    "            c = X[column].astype('category')\n",
    "            # add numerical code of categories to dataset\n",
    "            categorical[column] = c.cat.codes.values\n",
    "            # calculate number of embeddings for categorical input, max 50\n",
    "            categorical_embedding_sizes.append((len(c.cat.categories),\n",
    "                                                min(50, (len(c.cat.categories) + 1) // 2)))\n",
    "        else:\n",
    "            numerical[column] = X[column]\n",
    "\n",
    "    return [torch.tensor(numerical.to_numpy(), dtype=torch.float),\n",
    "            torch.tensor(categorical.to_numpy(), dtype=torch.int64),\n",
    "            categorical_embedding_sizes,\n",
    "            torch.tensor(y.values, dtype=torch.int64).flatten()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "5c15a7f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassificationDataSet(Dataset):\n",
    "    def __init__(self, dataframe):\n",
    "        numerical_data, self.categorical_data, self.embeddings, self.labels = prep_data(data=dataframe)\n",
    "        self.scaler = MinMaxScaler()\n",
    "        self.numerical_data = torch.tensor(self.scaler.fit_transform(numerical_data), dtype=torch.float)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return {'numerical': self.numerical_data[idx],\n",
    "                'categorical': self.categorical_data[idx],\n",
88
    "                'label': self.labels[idx]}"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "859718c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Classifier(nn.Module):\n",
    "\n",
    "    def __init__(self, num_numerical_columns, output_size, layers_size, embeddings_size, p=0.4):\n",
    "        super().__init__()\n",
    "        \n",
    "        # a list of BatchNorm1d objects for all the numerical columns\n",
    "        self.batch_norm_num = nn.BatchNorm1d(num_numerical_columns)\n",
    "\n",
    "        # objects for all categorical columns\n",
    "        self.embeddings = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in embeddings_size])\n",
    "        # dropout for embeddings to avoid overfitting\n",
    "        self.embedding_dropout = nn.Dropout(p)\n",
    "        num_categorical_columns = sum(nf for ni, nf in embeddings_size)\n",
    "        #num_categorical_columns = 0\n",
    "\n",
    "        all_layers = nn.ModuleList()\n",
    "        input_size = num_numerical_columns + num_categorical_columns\n",
    "\n",
    "        for i in layers_size:\n",
    "            all_layers.append(nn.Linear(input_size, i))\n",
    "            # activation function\n",
    "            all_layers.append(nn.ReLU(inplace=True))\n",
    "            # batch normalization to the numerical columns\n",
    "            all_layers.append(nn.BatchNorm1d(i))\n",
    "            all_layers.append(nn.Dropout(p))\n",
    "            input_size = i\n",
    "\n",
    "        all_layers.append(nn.Linear(layers_size[-1], output_size))\n",
    "        all_layers.append(nn.Softmax(dim=1))\n",
    "        self.layers = nn.Sequential(*all_layers)\n",
    "\n",
    "    def forward(self, x_num, x_cat):\n",
    "        embs = []\n",
    "\n",
    "        x = self.batch_norm_num(x_num.float())\n",
    "\n",
    "        if x_cat is not None:\n",
    "            for i, e in enumerate(self.embeddings):\n",
    "                embs.append(e(x_cat[:, i].long()))\n",
    "\n",
    "            x_cat = torch.cat(embs, 1)\n",
    "            x_cat = self.embedding_dropout(x_cat)\n",
    "\n",
    "            x = torch.cat([x_cat, x], 1)\n",
    "\n",
    "        x = self.layers(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
149
   "execution_count": 277,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
   "id": "337ffeb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, train_dataloader, validation_data, epochs, lr):\n",
    "    aggregated_train_losses = []\n",
    "    aggregated_val_losses = []\n",
    "    aggregated_val_accs = []\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0003)\n",
    "    loss_function = nn.NLLLoss()\n",
    "    #loss_function = nn.NLLLoss()\n",
    "\n",
    "    for i in trange(epochs):\n",
    "        oracle.train()\n",
    "        train_loss = 0\n",
    "        for i_batch, samples in enumerate(train_dataloader):\n",
    "            output = model(samples['numerical'], samples['categorical'])\n",
    "            single_loss = loss_function(output, samples['label'])\n",
    "            train_loss += single_loss.item()\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            single_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        train_loss = train_loss / len(train_dataloader)\n",
    "        aggregated_train_losses.append(train_loss)\n",
    "\n",
    "        oracle.eval()\n",
    "        with torch.no_grad():\n",
    "            output = oracle(validation_data.numerical_data, validation_data.categorical_data)\n",
    "            val_loss = loss_function(output, validation_data.labels).item()\n",
    "            aggregated_val_losses.append(val_loss)\n",
    "            output = torch.argmax(output, dim=1)\n",
    "            val_acc = accuracy_score(validation_data.labels, output)\n",
    "            aggregated_val_accs.append(val_acc)\n",
    "\n",
    "    return aggregated_train_losses, aggregated_val_losses, aggregated_val_accs"
   ]
  },
  {
   "cell_type": "code",
191
   "execution_count": 367,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
192
193
194
195
   "id": "f2bdcc50",
   "metadata": {},
   "outputs": [],
   "source": [
196
    "train_df = pd.read_pickle('data/train_20_best_features_over_reg.pkl')\n",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
197
198
199
200
201
202
203
204
205
206
207
208
    "train_df, test_df = train_test_split(train_df, test_size=0.2, random_state=1)\n",
    "train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=1)\n",
    "\n",
    "train_dataset = ClassificationDataSet(dataframe=train_df)\n",
    "val_dataset = ClassificationDataSet(dataframe=val_df)\n",
    "test_dataset = ClassificationDataSet(dataframe=test_df)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=30, shuffle=True, drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
209
   "execution_count": 368,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
210
211
212
213
214
215
   "id": "1cbb6022",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
216
       "(21288, 22)"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
217
218
      ]
     },
219
     "execution_count": 368,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
220
221
222
223
224
225
226
227
228
229
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df.shape"
   ]
  },
  {
   "cell_type": "code",
230
   "execution_count": 369,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
231
232
233
234
235
236
237
   "id": "927c9ce4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
238
      "100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [04:55<00:00,  2.95s/it]\n"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
239
240
241
242
243
     ]
    }
   ],
   "source": [
    "oracle = Classifier(num_numerical_columns = train_dataset.numerical_data.shape[1], \n",
244
    "                    output_size = 2,\n",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
245
246
247
248
249
250
251
252
253
254
255
256
    "                    layers_size = [train_df.shape[1] - 1, 40, 10, 5],\n",
    "                    embeddings_size = train_dataset.embeddings, \n",
    "                    p=0.4)\n",
    "\n",
    "training_epochs = 100\n",
    "lr = 0.0004\n",
    "\n",
    "train_losses, val_losses, val_accs = train(oracle, train_loader, val_dataset, training_epochs, lr)"
   ]
  },
  {
   "cell_type": "code",
257
   "execution_count": 370,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
258
259
260
261
262
263
264
265
266
   "id": "2f34c976",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
267
     "execution_count": 370,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
268
269
270
271
272
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
273
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD5CAYAAAAqaDI/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAABE8klEQVR4nO3dd3hUVfrA8e9JJ70CoSb0ToDQRaSKooCKgsoKlrWuuhYWXF3bLr9FZVGxYFfsIgI2UHovEnqVGiAESEjvbc7vjzNpMCFlQgKZ9/M8eWZuP3cu3PeeepXWGiGEEI7LqbYTIIQQonZJIBBCCAcngUAIIRycBAIhhHBwEgiEEMLBSSAQQggH52LPxkqpQOA7IAyIBm7TWifZWK8A2G2dPKG1HmWdHw58CwQBW4G/aK1zyztucHCwDgsLsyfpQgjhcLZu3XpOax1y/nxlTz8CpdSrQKLWerpSaioQoLWeYmO9dK21t435c4H5WutvlVLvATu11rPLO25kZKSOioqqcrqFEMIRKaW2aq0jz59vb9HQaGCO9fscYEwlEqSAwcC8qmwvhBCietgbCBporU9bv58BGpSxnodSKkoptUkpNcY6LwhI1lrnW6djgMZ2pkcIIUQllVtHoJRaBjS0sejZkhNaa62UKqucqbnW+pRSqgWwQim1G0ipTEKVUvcD9wM0a9asMpsKIYS4iHIDgdZ6aFnLlFJnlVKhWuvTSqlQIK6MfZyyfh5VSq0CugE/AP5KKRdrrqAJcOoi6fgA+ABMHUF56RZCVK+8vDxiYmLIzs6u7aSIcnh4eNCkSRNcXV0rtL5drYaAn4CJwHTr54/nr6CUCgAytdY5SqlgoD/wqjUHsRIYi2k5ZHN7IcTlISYmBh8fH8LCwjBVfOJypLUmISGBmJgYwsPDK7SNvXUE04FhSqlDwFDrNEqpSKXUR9Z12gNRSqmdwEpgutZ6n3XZFOBJpdRhTJ3Bx3amRwhxiWRnZxMUFCRB4DKnlCIoKKhSOTe7cgRa6wRgiI35UcB91u8bgM5lbH8U6GVPGoQQNUeCwJWhstfJoXoWL9x+ii83Ha/tZAghxGXFoQLBr7tP8/nG6NpOhhCiCpKTk3n33XertO31119PcnLyRdd5/vnnWbZsWZX2f76wsDDOnTtXLfuqCQ4VCBr71+N0srR4EOJKdLFAkJ+fb3N+oUWLFuHv73/RdV5++WWGDi2zkWSd5lCBINTPg7ScfFKz82o7KUKISpo6dSpHjhwhIiKCyZMns2rVKgYMGMCoUaPo0KEDAGPGjKFHjx507NiRDz74oGjbwif06Oho2rdvz1//+lc6duzI8OHDycrKAmDSpEnMmzevaP0XXniB7t2707lzZw4cOABAfHw8w4YNo2PHjtx33300b9683Cf/mTNn0qlTJzp16sQbb7wBQEZGBiNHjqRr16506tSJ7777rugcO3ToQJcuXXj66aer9fe7GHubj15RQv3rAXA6ORvfhhVrXyuEuNBLP+9lX2xqte6zQyNfXrixY5nLp0+fzp49e9ixYwcAq1atYtu2bezZs6eomeQnn3xCYGAgWVlZ9OzZk1tuuYWgoKBS+zl06BDffPMNH374Ibfddhs//PADEyZMuOB4wcHBbNu2jXfffZcZM2bw0Ucf8dJLLzF48GCeeeYZfvvtNz7++OINHbdu3cqnn37K5s2b0VrTu3dvBg4cyNGjR2nUqBG//vorACkpKSQkJLBgwQIOHDiAUqrcoqzq5FA5gsb+HgDEpmTVckqEENWhV69epdrKz5o1i65du9KnTx9OnjzJoUOHLtgmPDyciIgIAHr06EF0dLTNfd98880XrLNu3TrGjx8PwIgRIwgICLho+tatW8dNN92El5cX3t7e3Hzzzaxdu5bOnTuzdOlSpkyZwtq1a/Hz88PPzw8PDw/uvfde5s+fj6enZyV/japzrByBX3GOQAhRdRd7cq9JXl5eRd9XrVrF
smXL2LhxI56enlxzzTU229K7u7sXfXd2di4qGiprPWdn53LrICqrTZs2bNu2jUWLFvHcc88xZMgQnn/+ef744w+WL1/OvHnzePvtt1mxYkW1HrcsDpUjqO/jjpOC2GTJEQhxpfHx8SEtLa3M5SkpKQQEBODp6cmBAwfYtGlTtaehf//+zJ07F4AlS5aQlHTB61dKGTBgAAsXLiQzM5OMjAwWLFjAgAEDiI2NxdPTkwkTJjB58mS2bdtGeno6KSkpXH/99bz++uvs3Lmz2tNfFofKEbg4O9HQ10OKhoS4AgUFBdG/f386derEddddx8iRI0stHzFiBO+99x7t27enbdu29OnTp9rT8MILL3D77bfzxRdf0LdvXxo2bIiPj0+Z63fv3p1JkybRq5fpN3vffffRrVs3fv/9dyZPnoyTkxOurq7Mnj2btLQ0Ro8eTXZ2NlprZs6cWe3pL4tdL6apLfa8mOaW2RtwdVZ8e3/fak6VEHXb/v37ad++fW0no1bl5OTg7OyMi4sLGzdu5KGHHiqqvL7c2LpeZb2YxqFyBGCakO4+VakRsIUQAoATJ05w2223YbFYcHNz48MPP6ztJFULhwsEjf3rsWTfWSwWjZOTjJsihKi41q1bs3379tpORrVzqMpiMDmC3HwLCRm5tZ0UIYS4LDhcIGhU2KlMKoyFEAJw4EAQK30JhBACcMBAEOpn7V0sfQmEEAJwwEAQ6OWGu4uTFA0J4QC8vb0BiI2NZezYsTbXueaaayivOfobb7xBZmZm0XRFhrWuiBdffJEZM2bYvR97OVwgUErRyL8esSlSNCSEo2jUqFHRyKJVcX4gqMiw1lcShwsEYIqHpGhIiCvL1KlTeeedd4qmC5+m09PTGTJkSNGQ0T/++OMF20ZHR9OpUycAsrKyGD9+PO3bt+emm24qNdbQQw89RGRkJB07duSFF14AzEB2sbGxDBo0iEGDBgGlXzxja5jpiw13XZYdO3bQp08funTpwk033VQ0fMWsWbOKhqYuHPBu9erVREREEBERQbdu3S469EZFOFw/AjAVxusOXTlvDxLisrN4KpzZXb37bNgZrpte5uJx48bx97//nUceeQSAuXPn8vvvv+Ph4cGCBQvw9fXl3Llz9OnTh1GjRpX53t7Zs2fj6enJ/v372bVrF927dy9aNm3aNAIDAykoKGDIkCHs2rWLxx57jJkzZ7Jy5UqCg4NL7ausYaYDAgIqPNx1obvuuou33nqLgQMH8vzzz/PSSy/xxhtvMH36dI4dO4a7u3tRcdSMGTN455136N+/P+np6Xh4eFT0V7bJsXIES/4FPz1KIz8P4tKyySuw1HaKhBAV1K1bN+Li4oiNjWXnzp0EBATQtGlTtNb885//pEuXLgwdOpRTp05x9uzZMvezZs2aohtyly5d6NKlS9GyuXPn0r17d7p168bevXvZt2/fRdNU1jDTUPHhrsEMmJecnMzAgQMBmDhxImvWrClK45133smXX36Ji4t5du/fvz9PPvkks2bNIjk5uWh+VTlWjiDxKCQeo1FkPSwazqZm0ySg5sb8FqLOuMiT+6V06623Mm/ePM6cOcO4ceMA+Oqrr4iPj2fr1q24uroSFhZmc/jp8hw7dowZM2awZcsWAgICmDRpUpX2U6iiw12X59dff2XNmjX8/PPPTJs2jd27dzN16lRGjhzJokWL6N+/P7///jvt2rWrclodK0fg5g05acVvKpMKYyGuKOPGjePbb79l3rx53HrrrYB5mq5fvz6urq6sXLmS48ePX3QfV199NV9//TUAe/bsYdeuXQCkpqbi5eWFn58fZ8+eZfHixUXblDUEdlnDTFeWn58fAQEBRbmJL774goEDB2KxWDh58iSDBg3ilVdeISUlhfT0dI4cOULnzp2ZMmUKPXv2LHqVZlU5Vo7A3Qdy02gkfQmEuCJ17NiRtLQ0GjduTGhoKAB33nknN954I507dyYyMrLcJ+OHHnqIu+++m/bt
29O+fXt69OgBQNeuXenWrRvt2rWjadOm9O/fv2ib+++/nxEjRtCoUSNWrlxZNL+sYaYvVgxUljlz5vDggw+SmZlJixYt+PTTTykoKGDChAmkpKSgteaxxx7D39+ff/3rX6xcuRInJyc6duzIddddV+njleRYw1AvexE2vEX6P87Q6cUlTBnRjoeuaVnt6ROiLpJhqK8slRmG2rGKhtx9wJKPt3MBvh4u0qlMCCFwtEDgZn2TUE6a6VQmRUNCCOFggcDdGghyCwOBVBYLURlXYlGyI6rsdbIrECilApVSS5VSh6yfAWWsV6CU2mH9+6nE/M+UUsdKLIuwJz3lcjfjjpCTRgNf05dACFExHh4eJCQkSDC4zGmtSUhIqFQnM3tbDU0Flmutpyulplqnp9hYL0trHVHGPiZrras+CEhlFOYIctKp7xNCQkYu+QUWXJwdK2MkRFU0adKEmJgY4uPjazspohweHh40adKkwuvbGwhGA9dYv88BVmE7EFweStQRhPg0QWtIyMilga993bOFcASurq6Eh4fXdjLEJWDvo3ADrfVp6/czQIMy1vNQSkUppTYppcact2yaUmqXUup1pZS7rY0BlFL3W/cRVeUnkqI6gnRCfMyh4tNyqrYvIYSoI8rNESillgENbSx6tuSE1lorpcoqPGyutT6llGoBrFBK7dZaHwGewQQQN+ADTG7iZVs70Fp/YF2HyMjIqhVSFtURpBJSXwKBEEJABQKB1npoWcuUUmeVUqFa69NKqVAgrox9nLJ+HlVKrQK6AUdK5CZylFKfAk9X9gQqpUQdQYi3CQRSYSyEcHT2Fg39BEy0fp8IXDAQuFIqoLDIRykVDPQH9lmnQ62fChgD7LEzPRfn6mU+c9KkaEgIIazsrSyeDsxVSt0LHAduA1BKRQIPaq3vA9oD7yulLJjAM11rXTi261dKqRBAATuAB+1Mz8U5OZkK49x0PFyd8fVwkUAghHB4dgUCrXUCMMTG/CjgPuv3DUDnMrYfbM/xq8TdG3JSAQjxcSc+XQKBEMKxOV4DencfyEkHTCCIS5VAIIRwbI4XCKzvJACo7+MhOQIhhMNzvEDgbuoIwFo0JHUEQggH55iBwJojCPFxJzO3gIyc/FpOlBBC1B4HDQTWHEFRXwLJFQghHJfjBQK30q2GQPoSCCEcm+MFgsI6Aq2p7yuBQAghHDAQeIMlH/Kzi4qG4mWYCSGEA3PAQOBrPnPSCfB0w9lJSRNSIYRDc7xA4FY8AqmTkyLY2006lQkhHJrjBYIS7yQA6VQmhBAOGAiK31sM0qlMCCEcMBAUv5MATF8CCQRCCEfmeIGgxHuLweQIzqXnUGCp2kvPhBDiSud4gaCojqA4EFg0JGbk1mKihBCi9jhgIChdR1BfehcLIRyc4wUCVy9AlXonASAth4QQDsvxAoGTU6l3EhQGgrhU6V0shHBMjhcIwBQP5ZYOBJIjEEI4KgcNBMXvJPB0c8HbXV5iL4RwXI4ZCNy8i+oIQDqVCSEcm2MGghI5AjCdyuTlNEIIR+W4gSC3dI7gnAQCIYSDctxAUDJH4OPOmdRstJbexUIIx+OYgaBE81GAsCBPMnMLpOWQEMIhOWYgKMwRWHMAYcFeAESfy6zNVAkhRK1w0EDgDboA8k0nshbBZtiJY+fSL7aVEELUSQ4aCApfV2mKhxoH1MPVWXH0XEYtJkoIIWqHXYFAKRWolFqqlDpk/QwoY71mSqklSqn9Sql9Sqkw6/xwpdRmpdRhpdR3Sik3e9JTYW6lB55zdlI0D/LiWLwEAiGE47E3RzAVWK61bg0st07b8jnwmta6PdALiLPOfwV4XWvdCkgC7rUzPRXjXvqdBABhQV5EJ0ggEEI4HnsDwWhgjvX7HGDM+SsopToALlrrpQBa63StdaZSSgGDgXkX2/6SKByKukRfghYhXkQnZMoLaoQQDsfeQNBAa33a+v0M0MDGOm2AZKXUfKXUdqXUa0opZyAISNZa51vX
iwEal3UgpdT9SqkopVRUfHy8fam2kSMID/YiN99CbHKWffsWQogrTLmBQCm1TCm1x8bf6JLradMby9bjtAswAHga6Am0ACZVNqFa6w+01pFa68iQkJDKbl6aW+n3FoMJBADHpMJYCOFgXMpbQWs9tKxlSqmzSqlQrfVppVQoxWX/JcUAO7TWR63bLAT6AJ8A/kopF2uuoAlwqgrnUHlFOYLUolktCvsSJGRwNXYGGiGEuILYWzT0EzDR+n0i8KONdbZgbviFd9fBwD5rDmIlMLac7aufjTqCEB93vNycOSoth4QQDsbeQDAdGKaUOgQMtU6jlIpUSn0EoLUuwBQLLVdK7QYU8KF1+ynAk0qpw5g6g4/tTE/FFL2usriOQClFWLCXFA0JIRxOuUVDF6O1TgCG2JgfBdxXYnop0MXGekcxzUlrVtHrKkv3JA4P9mJXTEqNJ0cIIWqTY/YshgtGIAVTTxCTlEluvqWWEiWEEDXPgQNB8XuLC4UFe2HRcCJRBp8TQjgOBw4EF+YIpAmpEMIROW4gKKOOAGQUUiGEY3HcQGAjR+Dv6Uagl5vkCIQQDsVxA4GHH2Rf2EIoLMhTAoEQwqE4biDwDISsxKK3lBUKD/aWQCCEcCgOHAiCzBvKckvf9FvV9+Zsag6JGbm1lDAhhKhZDhwIgs1nZkKp2b3Czbt1Nh9NOH8LIYSokxw4EASZz8xzpWZ3aeJPPVdnNkogEEI4CMcNBF6FOYLEUrNdnZ2IDAtgkwQCIYSDcNxAUJgjyDh3waK+LYM4eDadc+k5NZwoIYSoeRIIMi988u/bwiyTXIEQwhE4biDw8AMnlwvqCAA6N/bDy82ZjUckEAgh6j7HDQRKmVyBjRyBi7MTPcMDJUcghHAIjhsIwASCDNs3+74tgjgSn0FcanYNJ0oIIWqWBAIbOQIwFcaANCMVQtR5EgjKCAQdQn3xcXeR4iEhRJ3n2IHAK9hmZTGYeoJe4YFsOppoc7kQQtQVjh0IPIMgKxkK8m0u7tsyiGPnMjiTIvUEQoi6y8EDQTCgISvJ5uLe4aaeYPMxKR4SQtRdDh4IAs1nGfUE7UN98HZ34Y9jUjwkhKi7HDsQFI03VHY9QY/mARIIhBB1mmMHgosMM1GoV3ggh+LSSZBxh4QQdZSDBwJrjsDGwHOF+rQwxUdboiVXIISomxw8EBTWEZR9k+/c2B93Fyc2S/GQEKKOcuxA4OIObj5l1hEAuLk40b2Z1BMIIeouuwKBUipQKbVUKXXI+hlQxnrNlFJLlFL7lVL7lFJh1vmfKaWOKaV2WP8i7ElPlXiV3bu4UK/wQPadTiU1O6+GEiWEEDXH3hzBVGC51ro1sNw6bcvnwGta6/ZALyCuxLLJWusI698OO9NTeZ5BF60jAOjdIhCtIUrqCYQQdZC9gWA0MMf6fQ4w5vwVlFIdABet9VIArXW61jrTzuNWH8/gcnME3ZoG4OqspJ5ACFEn2RsIGmitT1u/nwEa2FinDZCslJqvlNqulHpNKeVcYvk0pdQupdTrSil3O9NTeZ5BF60sBqjn5kyXJv5STyCEqJPKDQRKqWVKqT02/kaXXE9rrQFtYxcuwADgaaAn0AKYZF32DNDOOj8QmHKRdNyvlIpSSkXFx8dX4NQqyCvIVBZrW0kv1is8kN0xKWTm2h6XSAghrlTlBgKt9VCtdScbfz8CZ5VSoQDWzzgbu4gBdmitj2qt84GFQHfrvk9rIwf4FFN/UFY6PtBaR2qtI0NCQip9omXyDIL8bMi7eGlV/5bB5Fs06w5dvD5BCCGuNPYWDf0ETLR+nwj8aGOdLYC/Uqrw7j0Y2AdFwQOllMLUL+yxMz2VV4FOZWAqjP3qubJ4z5kaSJQQQtQcewPBdGCYUuoQMNQ6jVIqUin1EYDWugBTLLRcKbUbUMCH1u2/ss7bDQQD/7EzPZVXgWEmAFydnRjeoQHL9p8lJ7+gBhImhBA1w8WejbXWCcAQG/OjgPtKTC8FuthY
b7A9x68WRQPPlT/U9HWdG/L91hg2HE5gULv6lzhhQghRMxy7ZzFUOEcA0L9VMD7uLizafbrcdYUQ4kohgaAwEJRTRwDg7uLM0A4NWLr/LHkFlkucMCGEqBkSCDz8wMmlQjkCgOs6NSQ5M09eai+EqDMkEChl7VRWsWahV7cJwdPNWVoPCSHqDAkEUKHexYU8XJ0Z3K4+v+85Q4Hl4p3QhBDiSiCBAKyBoOJFPdd3DiUhI5fVB231nxNCiCuLBAKo0AikJQ3r0IAmAfV4c9khdDlDUwghxOVOAgGYvgQZFX+6d3V24tHBrdgZk8KKA5IrEEJc2SQQAAS1huwUSK14/4CbuzehWaAnb0iuQAhxhZNAABDa1Xye3lHhTQpzBbtPpbBsv+QKhBBXLgkEAA07AwpO76zUZjd1a0xYkCevLz0ouQIhxBVLAgGAuzcEt4HYHZXazMXZiUcHt2bf6VQ+WR99SZImhBCXmgSCQqFdK50jABjTrTHDOzTg37/sY+H2U5cgYUIIcWlJICjUKALSYiG9cuX9zk6KWbd3o0+LQJ7+ficrpRWREOIKI4GgUGiE+axk8RCY3sYf3hVJu1AfHvpqK3tOpVRr0oQQ4lKSQFCoYWfzWYXiIQAfD1fm3N0LHw9Xnv9xj1QeCyGuGBIICnn4QlCrSjUhPV+QtztPD2/DthPJ/CrvLBBCXCEkEJQUGlHlHEGhsT2a0q6hD9MXHyA7T15pKYS4/EkgKCm0K6SchIwSA9BZKnczd3ZSPDeyAzFJWXy2Ibp60yeEEJeABIKSGkWYz9PbzeeaGfBqC8hJq9RurmodzJB29XlnxWHiUrOrN41CCFHNJBCU1LCL+Ty9E3bPgxX/huxkSDpe6V09c317svMLGPjaKqbM28WumORqTaoQQlQXCQQl1fOHgHDY/QMsfBh8m5j5aZWv+G1V35sfH7mK0RGN+GlnLKPeXs/0xQeqN71CCFENJBCcr1EExO0F30Yw/iszLzW2Srvq0MiX6bd0YfOzQ7i1RxPeW31EOpwJIS47EgjO12IQeAbDHXOhfgczrwo5gpJ8PVz595hOtGvow1Pf75R6AyHEZUUCwfl6TISnD0JIG3BxA6+QKucISvJwdebtO7qRmZvPE3N3YCnxvmOLRXM4Lp0fd5ySugQhRI1zqe0EXJacnIu/+4TanSMo1Kq+Dy/e2JGp83czZOZq3Jyd0GhOJWWRkWuaqfp4uPD736+mkX+9ajmmEEKURwJBeXwbQUpMte1uXM+mJGTkFj35KxS9w4Po3MSPRn71eOCLKJ7+fidf3tsbJycFQIFFk51XgJe7XC4hRPWTO0t5fELh5B/VtjulFI8MalXm8n/d0IGp83fz6YZo7r0qnH2xqTzx3Q6Ss3JZ8sRA/Oq5VltahBAC7KwjUEoFKqWWKqUOWT8DbKwzSCm1o8RftlJqjHVZuFJqs1LqsFLqO6WUmz3puSR8G0FWIuTVTAXvuJ5NGdq+Pq/8doD/Lt7PmHfWk5CRQ1xaDm8uO1QjaRBCOBZ7K4unAsu11q2B5dbpUrTWK7XWEVrrCGAwkAkssS5+BXhda90KSALutTM91c8n1HxWUz1BeZRS/PfmLni7u/D+6qMMblefJU8M5PZezZizMZqDZyvXy1kIIcpjbyAYDcyxfp8DjCln/bHAYq11plJKYQLDvEpsX/N8azYQAIT4uDPn7l68N6EHsyd0J9DLjcnD2+Lt7sKLP+2VIa6FENXK3kDQQGtdeIc8AzQoZ/3xwDfW70FAstY63zodAzQua0Ol1P1KqSilVFR8fLw9aa4cn0bmsxqakFZG5yZ+jOjUEBMvIcDLjaeHt2HDkQR+2hnL4bh0fttzml93yXDXQgj7lFtZrJRaBjS0sejZkhNaa62UKvNRVSkVCnQGfq9sIq37/wD4ACAyMrLmHolrIUdQljt6N+frP07y+Lc7Ss13de7B8I62LpEQQpSv3ECgtR5a1jKl1FmlVKjW+rT1Rn+x8RNu
AxZorfOs0wmAv1LKxZoraAJcfm9/9/AHl3qQWvuBwNlJ8eb4CH7ZGUtYsBctQ7yZ8sMuXvxpL/1bBUvzUiFEldhbNPQTMNH6fSLw40XWvZ3iYiG0Keheiak3qMj2tUMpkytIq9miobK0aeDDk8PbcnP3JnRt6s+0mzoRm5LNG8sOFq2TkpnHgu0xZOXKi3GEEOWzNxBMB4YppQ4BQ63TKKUilVIfFa6klAoDmgKrz9t+CvCkUuowps7gYzvTc2n4NLoscgS29GgeyO29mvLJ+mj2xaby667TDJm5mie+28m1b6xhw5FzNrfT2gxrUWCRimchHJ26ElugREZG6qioqJo74A/3wcnN8PfdNXfMSkjOzGXI/1aTk28hPSefTo19mdQvnLdXHCI6IZPbIpswqG19GgfUw9fDlSX7zjA3KobDcelM6hfGi6M61vYpCCFqgFJqq9Y68vz5UqhcET6hkHYGtDZFRZcZf083Xh7diWcX7ubZ69tzd/8wXJyduKFLKK8vO8hHa48xN6r0MBk9mgcwuF195myMZmSXUHqGBdZS6oUQtU1yBBWxaTb8NhUmHwGv4Jo7biVprYuam5aUlp3HycQsYpIyiUvLoU+LIFrV9yYzN59r31iDi5MTix4bQD03Zxt7FULUFWXlCGQY6ooo7F1cw30JKstWEADw8XClQyNfhndsyIQ+zWlV3xsATzcXXrmlC8fOZTBz6Z9F6ydl5JJfYKmRNAshap8UDVWEr7VTWdppCO1Su2mpZv1aBnNn72Z8tO4Ye2NTOXg2jXPpuYT6eXBP/3DG92qKj4crBRZNbHIWwd7uNnMOOfkF/HEskeX748jJt/D4kNY09PMoWn4yMZOtx5MY0r4+Ph4ycJ4QlxMJBBVxheQIquqZ69uzNzaV9Jx8BrerT4sQb1b9Gce0RfuZteIQoX4eRCdkkptvoYGvO+/e2YMezc34gpm5+fxvyUG+/eMEGbkFuLuYTOYvO2OZen07RnYO5Z2Vh5mz4Ti5BRb86rly31XhTOofJgFBiMuE1BFUREEe/DsEBv4DBv2z5o5by3aeTOazDdGkZefRIsSbJgH1+GjtMU6nZPH8DR1o29CXyfN2cjwhk5u7NWZkl1D6tQwmLi2bZ+bvZsORBFycFAVac2uPJtzQpRFzNkSz/EAcfvVceWBgCyb1C8PTzYW07Dw+WnuM76NOMu3mzgxqW7+2T1+IOqesOgIJBBU1ow20Hg6j367Z415mUjLzeGLuDlYcMJ3ImwV68urYLvRpEVRqPa0130fFsP1kEnf1DaN9qG/Rst0xKby+7CArDsQR7O3OjV1D+XFHLIkZufh7uuKkFL89PoD6vh4IIaqPBAJ7vT/QtBia8EPNHvcyZLFoPlx7lKTMPB4b0gpPt6qVMG49nsirv/3J5mOJ9GsZxJQR7fByd+aGt9YR2TyQz+/pVfSWtoT0HAI83YqmC+05lUJ9X3fq+0jQEKI80o/AXr6NIOl4bafisuDkpHhgYEu799OjeSDf3t+H5Mw8AryK30n0/A0d+eeC3Xy07ij9WgYze9URFu05TZ/wIN6+oxtB3u5orfl43TGmLdpPl8Z+LHykf5mtpoQQFyeBoKJ8QuHExtpORZ2jlCoVBABu79WUNQfjmb74ABYNPu4u3NqjCT/uiOXGt9bx1h3dWbA9hi83naBNA292xqTw6+7T3NCl0SVLZ3ZeAXOjTvLTjljaNvRhcLv69GsZLH0vRJ0ggaCifEMhKwnyssC1Xm2npk5TSjH9ls7kWzTdmvnzl77N8fVw5a6+YTzwxVZumb0BgAeubsHT17blxrfW8epvfzK8Q0PcrK2WfttzhtSsPG7p0QRnJ9s5hdx8C8mZuReti8jOK+CLjcf5YO1R4tNyaNPAmwXbT/HV5hPUc3Vmzj296BUuvbLFlU3qCCpqx9ew8CG4cx60HlazxxZFEjNymfbrfnqHB3Jbz6YArPozjkmfbuGFGztwd/9wPlhzhP9bdACA
DqG+/HtMR3o0L32zXn/4HP9auIcTiZk8MawNDw5sWSpgWCyan3bG8upvB4hNyeaqVsH8bXAreocHkltg4Y9jiTw1dyftQn35/J5e1XZ+Z1OzOZOSTdem/tW2TyEKSWWxvTIT4ZMRkHwcbvsC2gyv2eOLMmmtmfDxZvbFpnJTtyZ8sv4YI7uEMrxDA/676ABnUrMZ0DqYNg18CAv2Iio6kR93xNI8yJM2DXxYuu8sPcMCeGlUJxIyctgbm8ri3afZGZNCp8a+PDeywwWtogBmLT/EzKUHWfHUQFqEeBfNzyuwsPV4EisPxLHxaAL9WgYz+dq2ZeZMCmXnFTBy1lqiEzJ5987uXFviZUO7Y1JYsP0UXu7O+NVzpWWIN9e0DXGYehGLRfPk3B3k5FsY0akhg9tJx8SqkEBQHTLOwZc3w9m9cMtH0PGmyu9Da9j3I9RvDyFtqz+NDmrPqRRueGsdAHf2bsbLozvh7KTIyMnn3VWHWb4/juiEDLLzLLg5O/HgNS15+JqWuLs4sXDHKZ5fuJe0nPyi/YUHe/HIoFbc3K3xBS2VCsWlZdN/+gom9GnOCzeaEVwT0nMY8+56TiZm4eqsaNPAh72xqQzr0IA3x0cUtbA6di4DXw8Xgrzdi/b338X7eX/1UVoEexGTlMWnd/ekf6tgftgawzMLdoOGPIuFwv+yr43twq2RTYu211qzNzaVliHel33dxYbD5/hk/TH6tQxmfK+m5bY8W7A9hie+24mvhwup2fm4OTvx5HCTkxMVJ4GgumSnwNfjzLDUf1kALa6p+LaWAvjtGfjjfWg5BP4y/5Il0xG9t/oIzkpx34Bwm0/KFosmLi0HF2dFcIkbMEBMUiar/oynRbAXHRr54u/pdsH2tjz+7XZW7I9j0z+H4OnmzF8/38qag/G8dmuXoqfWORuieennvXRq7Mew9g34dfdpDpxJw8fDhVnjuzGoXX22nUhi7OwNjOvZlCkj2jHu/U2cTMpkRMeGzN9+in4tg3j7ju7413MlLTufB7/cys6YZH559CpahHijteblX/bx6fpo3F2c6NcyiCHtGzC2RxM8XC+foJCZm8/0xQf4fONxfDxcSMvOJ8DTlYn9wugZFkhDPw9C/TxKBYbM3HwGz1hNfV935j/Uj50xyby14jDrDp3j9yeupmWJ3Jg9zqRkMzfqJHfX4V7vEgiqU24GvDcALHnw0EZwr8A/xLxsWPAA7FsIvk0g8xxMiS5d8bzpPdAF0PeRS5VyUc22nUji5nc38O8xnXBS8OyCPTw3sj33DWhRar1l+87y6DfbycorILJ5ACM6NWT+tlPsP5PKU8PaMH/7KXLyLPz29wH4eLgSl5rNre9v5HhCJnf3D+PZ69vj4lw8RuSZlGxGvLmGpgGe/PBQP2Ys+ZMP1hzltsgmeLm7sOJAHMcTMunYyJf3JvSgaaDnRc8jLi2b/y46QFxaNo3969HY35O2Db3p0TyQEB93cvILWL4/ju+jTlLPzZk3x3fDtUR6svMKiE3OKlVEdr4TCZn85ZPNnEjM5O5+4Uy+ti17Y1OYveoIyw8Uv+VWKZjYN4znRppzfn3pQd5cfoh5D/Yl0jpcenxaDoNmrKJXeCCfTOpZqWtW1vmPf38TR89lMKB1MJ9M6lnq/OoKCQTV7fhG+PQ66HkfjJxR9nppZ+DAr7Dtczi9A4b/B4Lbwte3woT50GqIWS8/B15rbf4X/OMoOF0+T3GibFprRr29nqTMXM6l59AzLJA5d/eyWZx0Lj2HvAILoX4m+GflFvCPH3bx804zhtWX9/bmqtbFw5zHpWVzJC6Dvi0vrJ8A+H3vGR74YivtGvpw4Ewaf+nTnJdHd0Qphdaa5fvjeGLuDpydFLPGd6NZoCfbTyax51Qq4cFeDOvQgAa+HizZe4ap83eTkZNP+1BfTiVnEZ+WU3Sc8GAvkjNzScrMI8THnfi0HCb2bc5LozsB5on9ro//YOuJJP4+pA2PDm51wfnn5Bcw
dvZGjidk8MFdkRfUucQmZ3E8IZMzqVlsPprIt1tOMqB1MM+ObM+Yd9YztH0D3r6je6lt3l99hP8uPsCce3oxsE1IRS/ZBRIzchn/wUZikrK4s3czPlx7jHGRTZl+S+c6VwcjgeBSWDwFNr8Hk36FsKtKL4vdAUueg+h1gIbAFjD4Oeh0i8lRvBIGve6Ha6eZ9Q8uMcEB4L4V0KRHDZ6IsMf3USeZPG8XAZ6u/Pb3q2lQiaExtNZ8uek4eQWae64Kr/Sxn12wm682n+D2Xk2ZNqbzBTfg6HMZPPDFVv48m1Y0z9VZkVdg/t+3qu/N4bh0Ojby5c3xEbSq7wOYJ/y9salERSeyJToJD1cnxvZowoDWIUxfvJ8P1x7j1Vu6MCqiEfd8toVNRxPo3yqYtYfOMahtCG+M64afZ3Hxygs/7mHOxuN8eFckwzo0KPe85m45ybMLd2PR4OykWPHUQJoElM7V5OQXMPz1Nbg6O7H48QG4OjuRm28hMzcfL3cXXJ2dsFg0CRm5nE7JQmvo1NivVKX94bg0Hv92B4fj0vl0Uk/6tQpm5pI/mbXiME8MtR3UKqKsd4PUNgkEl0JuBszuByi48U1T+evqCSunwR8fgGewyTG0v9FUDpf8hzFnFKTHwSObzPTCh00lcm4GDHoWBk6ulVMSlZedV8DT3+9kXM+mDGhd9SfTqsjNN01Z+7UMKvOGlZmbz+cbj+NXz5VuzfxpXd+HI/HpLN13lrWH4ukZFsijg1sX9cEoT36BhUmfbuGPY4l0aeLH1hNJzLytK2MiGvPlpuO8/Ms+QrzdmdgvjLE9mrDpaCKPfL2N+64K57kbOlT43LZEJ/K3r7cxsV8YD1/TyuY6S/ae4f4vtnJVq2CSs3I5eCadXOu7NNxdnNCaommAAE9XBrYJoWmgJ0v2nuXPs2m4uTjx/l96FA10qLXmqbk7mb/9FM5OimBvNxr4etAzLJDB7erTMyzwgt/qcFwaz8zfzeG4dLLyCsjJtzAmojEzbu1abmuxiopJyiQqOokbuzaq8j4lEFwqx9aalkQFuWZaOZmWQT3vMzmAev62t1s/C5b+C57YB14hMKMVtLkO4g+YeoN7fquxUxCispIychn1zjpOJmbx35s7c3uvZkXLtp1I4v9+3U/U8STcnJ1wcoJ2DX2Z+0DfCgebQuU9WWutuf+LrURFJ9KpsR8dGvlS38eDzJx80q2twEL9PAj1r0d2XgGr/4xn1cF4kjJz6dk8kJFdQrmuU8MLOhXm5ltYuP0UJxIziUvLJiYpi6jjSeTmW/B2d+H6zuYlT50b+/H91hhe+HEv9dycGdk5FE83Z5Iyc5kbFcNdfZvz0qiOVcodFFg0qw/G8fPO0/xxLJFTyVkA/PLoVXRq7Ffp/YEEgksrPR7O7oFzByH5BHS8ufyinbN7TW5i1Num1/KXt8D4byBmC6x/01Qke/hefB9C1KJTyVkcP5dBv1a2X9/655k0vvnjBDtjkpk1vlu5FdY1pcCiSc/Jx69e5VoGZebms+FwAkv2neHnnafJyiugsX89TiVn0a9lEK+PiyhVLPjfRft5f81RJl/blkcGmRxNUkYuW48nseloAhuPJgBwZ+/m3NStMfXcnNFacyQ+nd/3nuXrzSc4lZxFoJcbfVoE0isskF7hQbRt6CM5ArgMA0FVaA0z20OzPuDmDXsXwuTDcCoKPhsJ47+GdiNrO5VCCBtSs/P4cfspft19mqvbhPDA1S0vuDlbLJqnvt/Jgu2n6NsiiGPnMjiTmg2Am4sT3Zv5k5adz97YVPw9XYlo6s/Ok8kkZeYB0L9VEHf2bs6wDg2qrQWTjD56uVEKWg42LYqUE7QdAa4e0KSXCQyHl0sgEOIy5evhyl/6hvGXvmFlruPkpHjlli4oYN/pVPq2DKJdQx86N/Gje7MAPFxNDmBLdBKfrDvGkfh0hrZvQM+wQPq2DKrRHJQEgtrUcjDs+Mp87zDafLq4QdgAOLKi9tIlhKgW
bi5OzBwXUeZypRS9wgNrfeDCutdj4krScjCgwNULWg0tPT/pGCQerdp+T++Ed/pAwpFqSaYQom6TQFCbPAPNSKZdx5XuYdxysPmsaq5gzQyI3w8r/8/+NAoh6jwJBLXtzu/hhtdLzwtqCf7NYOd3EP9n5faXFA0HfjFNUvf8AGf3VVtShRB1k12BQCkVqJRaqpQ6ZP0MsLHOIKXUjhJ/2UqpMdZlnymljpVYFmFPeuoMpaDXA6Yp6Tu9YHZ/2PZFxbbd/L6pfP7LQlPpvOq/lzSpQogrn705gqnAcq11a2C5dboUrfVKrXWE1joCGAxkAktKrDK5cLnWeoed6ak7+v0NnjoA171qxh366W9mfKOLyU4xAaPjTdCwE/R9GPb/BKd31UyahRBXJHsDwWhgjvX7HGBMOeuPBRZrrTPtPK5j8GkIvR+AuxeDX1P45e+Qn1v2+tu+gNw06POwme7zMHj4SV2BEOKi7A0EDbTWp63fzwDljSY1HvjmvHnTlFK7lFKvK6XcbW3k8Ny84PoZZviJDbOK58cfhI3vwP5fTF3C5veheX9obB2lsZ4/9HsUDi6GPz6EK7DzoBDi0iu3H4FSahnQ0MaiZ0tOaK21UqrMO41SKhToDPxeYvYzmADiBnwATAFeLmP7+4H7AZo1a2Zrlbqt7QhoPwrWvAZtr4Pd82DDW+adCCWNOK9OoM/DcGITLHratEIa9bZprZR41ASW8IEVe5+CEHVRepx5H3m/Rx166He7hphQSv0JXKO1Pm290a/SWtt8/6JS6nGgo9b6/jKWXwM8rbW+obzj1okhJqoiNRbe7gW56YCGrnfAwH9AVqLJHeRlQo+7wem8jJ7FYt6KtvR5k7vQGrKTzbLIe+GGmTV9JkJcHtb+D5a/bIpfm/er7dRccmUNMWFv0dBPwETr94nAjxdZ93bOKxayBg+UGZpvDLDHzvTUbb6NYOT/oFlfmLQIbpoNgeHQuAdE3A49770wCICZ1+ch+OsKU3TUYbQZNrvTWPPCnKTjNX8uV6KCPIj6BHKliqvOiN1hPh28J7+9gWA6MEwpdQgYap1GKRWplPqocCWlVBjQFFh93vZfKaV2A7uBYOA/dqan7us6Du5ZDGH9K79tw84w/isYNQt6TIJhL5umpmteLV4nNxN++GvFm6s6kr0L4ZcnYMuHtZ0SUV0KA8Hh5bWajNpmVyDQWidorYdorVtrrYdqrROt86O01veVWC9aa91Ya205b/vBWuvOWutOWusJWut0e9IjKsmvsclF7PgGzh2GgnyYdw/sngs/P152c9W8LPhuAnx7p2NVQO/+3nxu+RgsBbWbliuF1qYc/nKUmQgpJ8AzCGK3m2kHJT2LHd1VT4CLO6z6P/j1CdPCaMgLENAc5t1t3rVQUl4WfHM77P/Z9GA+uqrsfRfkm57OdUFGAhxZDiHtIfk4HF5W2ym6vGQlw8k/Lpy/d74Zbr2q42ZdSrHbzWefhwENR1fWanJqkwQCR+dd3/RV2PODqS8Y8BQMeBJu+xyykuCHe4qffnMz4Zvx5uZ/45vg28S8lrOsXMHif8CsbrZvEJez/Bw4srL0ee1bAJZ8GPMueDc0zXEvBYvF1EVcaX7/J3wyAtLOlJ6/e5753S72wFBbCgNB5D2mv40D1xNIIBDQ7zHwaWRaHA3+l5nXsLOpmD62Bt7sCjM7wv/awtHVMGa2qWO4+mkzDIatp+OEI7D1M9AWWPCAeRezLXH74ePhsO4NyE69RCdYSVs/gy/GmGaFhXbPM7mBRt3MuR9edmmecn9+1ATP5BPVv+/qsOqVC4sE0+NNsZkuMPUohXIzim+u0evtP/aWj01T6OpyegcEhJvm1OEDLwz+DkQCgTD/Ef6+G258w4xzVKjbBLj2/6BJJLQYCF1ugzu+My2UACLuNIPj2coVrPgPuHjA2E8h8Zhpunq+/BxTMR27A5a9AG90ghXTqv5EnJFQPcHkoLWry29TIeWUuSmf2Aidx5rf
p8ck0+Z8y8f2H6uk2B2w/UtIOQlfjr38yqzT42DdTFMkeOCX4vnbPjPv7PZpBHvmFc8/vAzys8GvGRxfb99NNjsFFk2GJc9VfR/ni90JjSLM91ZDIPVU5Qd5rCMkEAjDuYy+hX0fgVs/M0UiI/8Hba4tXubiBlf/w2SxD/5WPD92hykb7vMQdLrZ7GPLRxfmHFb9F87uhtvmmKatYQNMC6bCl/VURtpZeLe3efdzRW84qbHw5+LS83IzIXodtLvBFIn99LfiSuLOY82nb6hZvv1L201Jq1qRvPwlqBdg3l2ddAy+vQPysqu2r0th02wTvP2awrKXTB1QQZ4JiC0HQ+/7TQ6xsF5o/y9QLxD6PwZpp+3LQUWvMzmOmC3VkxPLSDAVxY26mekWg8yngxYPSSAQ9uk63mSvf30Ktn9lbgzLXzY3tP6PmXUG/wtC2sGCh0wLpfxck8Vf/yZ0v8v0lG7cA8Z9Cf7NL7w5l8diLX7KiIeYP+DYea2Uc9IuLLsG+OlRU+dR8sYSvQ4Kcky58fCXzY1h9WvQtDcEhBWv1/M+0ylv7/zS+8zNMEVpcyeaivWKOrLSHOvqydDuerjpPZMLWfDA5VFckZ1ignmH0TBiOiQcgh1fmkEN005D7weh0y1m3T3zzTU++Du0vd4Uu4D5bavqyAqTwwRTTGev09b6gdAI8xnQHIJamQYBl9KhZbDvYt2taocEAmEfZ1e4+QNz4//xYXi9o/nPdNWTpgIOzLuYx34KXsGw8EFzo5x3j3myvLbEgHhKmRvH0VVl1ynYsvEt0+JjxCumInft/4qX5eeYSsx3+0LGueL5xzcW51BK9pk4tARcPU3Hux73mJtYfhZ0vrX0McOugqDWsHVO6fl7F5qinX0LYc6NpY9ZFosFlr1ofo/Ie828TreY1lv7FpqK/JKOrjZFR6mx5e8bzG/w8+OmeO7Arxe2BCt0dq8pCrNly8eQk2oaErQbad6tvWq6GeYkIBxaDTPFhE17m/RGr4WcFGh/AwS3Bq/6pnioqo6shPCrzXXZNdf+4FjYfyC0a/G8lkNMXcalyoVpbQaOnH+/7QeTWiSBQNivaS94cB3c8T0EtoT6HaDXX0uv06ADPLQB7pxnXryTEQ83vQ/uPqXXa3udKVeuaCuTmK0mB9J+lGn91O9vpoL75BazfOU0OLvHPNEunmLmaQ0r/g3eDaDFNaYoqiDPzD+81NxwXD1Mj+wx75pK9C63lT6uUiY3E/OHqfAutPUzCG4Lt30BZ3bDR0NLL7dl3wJTcTnoWXPcQv0fh0bdTV1FYX1BSgx8P8mk8+vbTG6nPDu/Mena+I4pbprRypS1l7yZHl0FH1wD7/S+8Ik7Lws2vWtulKFdzbkPe8nkBGK3m9+9sEd7p1vM7712pnkFa4tBZv2w/tbinTJu4Be7sSefgMQjZl+dbzW5kdM7i7f7/VlTVFUZp3dAYAszMGOh1sNM0D+0pKyt7HNqq3lIyM8u/bByGZBAIKqHUtBmuOn1/PDG0q/eLLlO62Ew6Rd4Jgaa971wneb9wN0P/lxU/jGPbzQ3Np9Q01taKXPTrhdg/qNFr4f11l7UA6eYisw/fzPFDMfXw4CnzQuA0s+aYoyEw6Z8u/Ww4mP4NTGV6IW5m5Ii7gAn1+Jcwdm9JjD0mAgdRsHEX8yNenZ/U3RmK3dw7rBZ1qDzhcHGydk0081MNJXpBXkmJ1WQC9e9Zt4+N+8eU1ZfloJ80yKrUTfzm9/zO0RMME/yi6eYG+mpbaYlUFArE7B/uBcWPmwC2eHl5iabEW9yAyWvU9vrwd3X/A6FOt5keqsfXwethxYHtub9TWXs+f1Kzh2CH/8G0xqWXWRyxNq+v+VgUzTl5Fpcb7PxHdj4tqnErkyRUeyO4mKhQi0GgW9j2PppxfdTUmbixa/Fnvng7AYdb4aoTy+rlmHljj4qxCXhUsaI486u5kb852+m0tXWiJAWC6x/3bQw8m9mhs2oZ305nru36SC0
chrEbjPl+sOnmf+A+xaaISK8gk0xTI+JoJxNINk2p7jCsNWwC49pi1ewKSbZ9S0MfdEEBGc36DLeLG/aEx7ZbIpQoj4xrx69+mmTPhc3Exi+usXcOMd9bvtcQ7uYyvYNs0yF+MnNcMvHpuLa2dUUNSx6Cob9Gzx8L9x+30JT8TzsCxOcm/UxxTf1/M0NNCvRBEbPQJgw37zidPV0897rkpX2LYeYm3lJN39gzqFkkPSub3JUR1dBuxuL54ddZT6PrzfjY2WcM9di/8/m34K7j8nZtbvhwt/hyApzjULaWh8mhpubfquhprir3Q0mUP3yhGnhFhBmcjFLnzfzx35aujVcRoJ5Mj8/1+rsYnJ5q/5rWroFhl/4e2ptHhhyM8w+LflwbK1pRRWzBbrebup3zmexmGvRcggM/7dZf81rMOotszw3w/xmxzeY+rP0OPObtRxs/ryCLtxnNZJAIC4/ba8zT++ntppip9RYUwySHm9uOrrA/GfseLN5Yj7/BtjrryYnkBFvnoALh9ke9TZ8PBTSYs33wmAUcad5okyNNcU6Ac0rntYeE81/8N3fm4DQflTp/7RewTByBvS6H5b+yzzZb//C3LjXzTRlxRN/McUUZblmqjnGod9NHUJh66XIu80T9vo3TD1H4+4mmPV92ARGrWHd6xDcxtwsCykFw/8DTi5mW68Q82pT31CzfPBzZv3EI+YG7NPQNAEteTMFc/M+v2gPTEV64jGTQywU0s4M5RC9zhTHfT7G3IyvtubKTmyAuXeZ+oWSOSNLgan8b3Nd8fG73Ap//gpfjzO5mJveM0/j711lmiOPfhvm3WtapIEpTmo3snif51cUl9T9Llj9qnkwGPpi8fxja02O5c/FkBpz4XahEdBmhCmG6zTW5IZKivnD5IiGvGBymT3uNpXvvR8y9VtrZ0LmOfMg0biHadZ6cDHs/Nq8cvavKyGkzYXHrSZ2DUNdWxx2GGpHkZUMr7U0Y8RfPdlU9iYeg45jTFFLbrq54Xa/68KbU6H9v5iy2MKbZqE1M8yNZcKC4iazSdGmAhug79/g2mkVT6vFArMiIDPBpGviz+aJuCyHlpoimcQjgDI9uDuMKv84J7fAru/MDbxkPYLW5in76CpTiXwqyvT4Hvux6Rn+9W2mA2DJ4puS2+75ARp2uaQ3mSLf/cU87bq4mzqbO74rHvrZYjE38oJck4sqzBWc2gYfDoKbPywOEHlZ8Fprc+3/usJURoPJJfxwr8lhefiZ8/79n6au4oE1xfUY8+835z35SOk6gkLf3GFu3E/sMzm3TbNNPY2rpwm0bYabyu/CodNCu5icaX4OvDfADAf/8KbS7/lYPMUUB00+bB5c0s7AmxGmhZq2mEYJVz1hRhYuvL6WAvMw9NWtUL+9GXHY1ujClVDWMNRora+4vx49emhRx312o9ZvRWr9zR1av+iv9Z+/X9rjzRmt9Qu+Wh9ZWfltV79qtn0zQmuLpfz187K13viu1ru+r/yxynMySuvXO2v9YoDW/2uv9cyOWufnVv9xqmLTe+Z3eqWF1qe2X7h8zwKzvOTvsvo1My/tbOl1j662vY9FU8y1TD5ppnd8a7bfu7D0MVZMKzudB5eadXb/YF3fz/w7zM0s/xyPbzLrL5pSPK+gQOvX2ph9lLThba0/v0nro2suvs9tX5j0/PFh+ccvBxClbdxTJUcgLk+b3oPfrK18RrwCfR68tMeLXm/K4W/7wjwFVkbqaXiru8n2X+p0VkR2Cvz0mClOun7GhWXhtSU93hSNXfVE8VN8SRYLvNffPAk/tN5U4n9/t3nCfqiKTU8tBaYllLMrTPgBZvczdQj3LjXzbG5jgVldwdndVOg26gZ3LbTdAMKWX582xT63f2PqM05sgs+uL67bqSytzZAnMVtNbsmvceX3YVVWjkACgbg8JZ+At3ua8vuR/yu7COhykZUEHv6XTzq1Nq8iDWl3+aSpIvbMN6PeOrmYilgwgaNkeX1lFRYZ
+TUz9UYPrrUdiEpaM8M0MQ5uY+qZPAMrfrzsVBPQkk+Yupp6gaZ+YPKRqr8WNvGYCWLhV8Pt31b5mpYVCKSyWFye/JvBk/vNf6Qr4UZW2GrpcqGUKVe+0nQYAwP2mndx+zc3FffNr7Jvnx1vMjf2+P0mh1ReEADzno7MRJPDq0wQAFMH8OA602Hx0DLT6qnr7fa9Gzww3PQzWfIsxESZFmnVSHIEQoi679Q2c0Me8NSV8WBhi6XADOluq/9NBUmOQAjhuBp3N39XMidnu4LARXd9SfYqhBDiiiGBQAghHJwEAiGEcHASCIQQwsFJIBBCCAcngUAIIRycBAIhhHBwEgiEEMLBXZE9i5VS8cDxKm4eDFTgRbJ1jiOetyOeMzjmecs5V0xzrXXI+TOvyEBgD6VUlK0u1nWdI563I54zOOZ5yznbR4qGhBDCwUkgEEIIB+eIgeCD2k5ALXHE83bEcwbHPG85Zzs4XB2BEEKI0hwxRyCEEKIEhwoESqkRSqk/lVKHlVJTazs9l4JSqqlSaqVSap9Saq9S6nHr/ECl1FKl1CHr52X2Si37KaWclVLblVK/WKfDlVKbrdf7O6VUJV9GfPlTSvkrpeYppQ4opfYrpfrW9WutlHrC+m97j1LqG6WUR1281kqpT5RScUqpPSXm2by2yphlPf9dSqlKvXzBYQKBUsoZeAe4DugA3K6U6lC7qbok8oGntNYdgD7AI9bznAos11q3BpZbp+uax4H9JaZfAV7XWrcCkoB7ayVVl9abwG9a63ZAV8z519lrrZRqDDwGRGqtOwHOwHjq5rX+DBhx3ryyru11QGvr3/3A7MocyGECAdALOKy1Pqq1zgW+BUbXcpqqndb6tNZ6m/V7GubG0BhzrnOsq80BxtRKAi8RpVQTYCTwkXVaAYOBedZV6uI5+wFXAx8DaK1ztdbJ1PFrjXmzYj2llAvgCZymDl5rrfUaIPG82WVd29HA59rYBPgrpUIreixHCgSNgZMlpmOs8+ospVQY0A3YDDTQWp+2LjoDNKitdF0ibwD/ACzW6SAgWWudb52ui9c7HIgHPrUWiX2klPKiDl9rrfUpYAZwAhMAUoCt1P1rXaisa2vX/c2RAoFDUUp5Az8Af9dap5Zcpk1TsTrTXEwpdQMQp7XeWttpqWEuQHdgtta6G5DBecVAdfBaB2CefsOBRoAXFxafOITqvLaOFAhOAU1LTDexzqtzlFKumCDwldZ6vnX22cKsovUzrrbSdwn0B0YppaIxRX6DMWXn/tbiA6ib1zsGiNFab7ZOz8MEhrp8rYcCx7TW8VrrPGA+5vrX9WtdqKxra9f9zZECwRagtbV1gRumgumnWk5TtbOWjX8M7Ndazyyx6CdgovX7RODHmk7bpaK1fkZr3URrHYa5riu01ncCK4Gx1tXq1DkDaK3PACeVUm2ts4YA+6jD1xpTJNRHKeVp/bdeeM51+lqXUNa1/Qm4y9p6qA+QUqIIqXxaa4f5A64HDgJHgGdrOz2X6ByvwmQXdwE7rH/XY8rMlwOHgGVAYG2n9RKd/zXAL9bvLYA/gMPA94B7bafvEpxvBBBlvd4LgYC6fq2Bl4ADwB7gC8C9Ll5r4BtMPUgeJvd3b1nXFlCYVpFHgN2YVlUVPpb0LBZCCAfnSEVDQgghbJBAIIQQDk4CgRBCODgJBEII4eAkEAghhIOTQCCEEA5OAoEQQjg4CQRCCOHg/h8J+pmYoGGzsgAAAABJRU5ErkJggg==\n",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
274
275
276
277
278
279
280
281
282
283
284
285
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "sns.lineplot(data= train_losses, label='training loss')\n",
286
    "sns.lineplot(data= val_losses, label='validation loss')"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
287
288
289
290
   ]
  },
  {
   "cell_type": "code",
291
   "execution_count": 371,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
   "id": "1cb5e2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the trained network on test_dataset in evaluation mode.\n",
    "# NOTE(review): the loss printed for this run is negative, which NLLLoss cannot\n",
    "# produce from log-probabilities - confirm the network ends in log_softmax.\n",
    "oracle.eval()\n",
    "loss_func = nn.NLLLoss()\n",
    "\n",
    "# single forward pass over the whole dataset, without building an autograd graph\n",
    "with torch.no_grad():\n",
    "    y_val = oracle(test_dataset.numerical_data, test_dataset.categorical_data)\n",
    "\n",
    "loss = loss_func(y_val, test_dataset.labels)\n",
    "# predicted class = index of the largest score along dim 1\n",
    "out = y_val.argmax(dim=1)"
   ]
  },
  {
   "cell_type": "code",
309
   "execution_count": 372,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
310
311
312
313
314
315
316
   "id": "18b5e4f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
317
318
319
320
      "Loss: -0.7749953866004944\n",
      "Accuracy: 0.774537802495115\n",
      "Recall: 0.8294385432473445\n",
      "Precision: 0.7444837918823209\n"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
     ]
    }
   ],
   "source": [
    "# compute the three classification metrics first, then report everything at once\n",
    "acc = accuracy_score(test_dataset.labels, out)\n",
    "recall = recall_score(test_dataset.labels, out)\n",
    "precision = precision_score(test_dataset.labels, out)\n",
    "print(\"\\n\".join(\n",
    "    f\"{metric_name}: {metric_value!s}\"\n",
    "    for metric_name, metric_value in [\n",
    "        (\"Loss\", loss.item()),\n",
    "        (\"Accuracy\", acc),\n",
    "        (\"Recall\", recall),\n",
    "        (\"Precision\", precision),\n",
    "    ]\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
336
   "execution_count": 373,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
337
338
339
340
341
342
343
344
345
   "id": "cde855e6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AxesSubplot:>"
      ]
     },
346
     "execution_count": 373,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
347
348
349
350
351
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
352
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWcAAAD4CAYAAAAw/yevAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAcSUlEQVR4nO3deZxU1bnu8d9jA8ogg4ITIIOiEbwGFQ3OJA4MMYKa65V4FZUTHDDRHGPikJPBKZ44xGOOR4OBoNeBoEblGFCROMREjKhcBRFtELUZI6iIGqC73vNHbbDUHqqhumuzfb75rA9V7961h0ReV9699lqKCMzMLF22KvcFmJnZFzk5m5mlkJOzmVkKOTmbmaWQk7OZWQq1aOoTfPKnGz0cxL7g2LOmlvsSLIVmVD2mzT3G+ncXFp1zWnbuvdnnaypNnpzNzJpVrqbcV1ASTs5mli2RK/cVlISTs5llS87J2cwsdcI9ZzOzFKqpLvcVlISTs5llix8ImpmlkMsaZmYp5AeCZmbp4weCZmZp5J6zmVkK1awv9xWUhJOzmWWLyxpmZinksoaZWQq552xmlkLuOZuZpU/k/EDQzCx93HM2M0sh15zNzFLIEx+ZmaVQRnrOXn3bzLIllyu+1UNSd0lPSHpV0lxJ5yfxn0taLGl20oYV/OYSSZWS5ksaXBAfksQqJV1czG2452xm2VK6yfargQsj4kVJ2wIvSJqebPt1RFxXuLOkvsDJQD9gF+BxSXskm28GjgaqgOclTYmIV+s7uZOzmWVLiUZrRMRSYGny+UNJ84Cu9fxkODApItYCb0qqBA5MtlVGxEIASZOSfetNzi5rmFmmRNQU3SSNkTSroI2p7ZiSegL7As8lofMkvSxpgqROSawr8E7Bz6qSWF3xejk5m1m2NKLmHBHjImJAQRv3+cNJagfcD1wQEauBW4DdgP7ke9bXN8VtuKxhZtlSwtEaklqST8x3RcQfASJiecH224CHk6+Lge4FP++WxKgnXif3nM0sW0o3WkPAeGBeRNxQEN+5YLfjgTnJ5ynAyZK2ltQL6AP8HXge6COpl6RW5B8aTmnoNtxzNrNsKd1ojUOAU4FXJM1OYpcCIyX1BwJYBJwFEBFzJU0m/6CvGhgbETUAks4DHgUqgAkRMbehkzs5m1m2lKisERHPAKpl09R6fnMVcFUt8an1/a42Ts5mli2e+MjMLIWcnM3MUigjc2s4OZtZtpTugWBZOTmbWba4rGFmlkIua5iZpZB7zmZmKeTkbGaWQhHlvoKScHI2s2yp9mgNM7P08QNBM7MUcs3ZzCyFXHM2M0sh95zNzFLIydnMLH2ipqbcl1ASXqbKzLKldMtUdZf0hKRXJc2VdH4Sv1bSa8nq2w9I6pjEe0r6RNLspN1acKz9Jb0iqVLSTckSWPVycjazbIlc8a1+1cCFEdEXGAiMldQXmA7sHRH7AK8DlxT8ZkFE9E/a2QXxW4Dvkl9XsA8wpKGTOzmbWbbkovhWj4hYGhEvJp8/BOYBXSPisYjY8KbLTPKradcpWRC2fUTMjIgA7gBGNHQbTs5mli2NKGtIGiNpVkEbU9shJfUE9gWe+9ymM4FpBd97SXpJ0lOSDktiXYGqgn2qkli9/EDQzLKlEQ8EI2IcMK6+fSS1A+4HLoiI1QXxy8iXPu5KQkuBXSNipaT9gQcl9Wvk1W/k5LwZlr23hp/cPYNVaz4B4MSD+nLK4fts3H7Hk7O5YcqzPHH56XRq15o/vfA6E//8EhHQZpuWXHbi4ezZtTMAf533Nr968BlyueD4gXtx5pH7leWerPROGD2CYSOHIcGf7p7GH8c/wOk/HMUhgw8ilwvef/d9fvWv17Jy+SrabtuGS266mB26dqGiooLJv72PRyc/Vu5b2LKUcCidpJbkE/NdEfHHgvjpwLHAkUmpgohYC6xNPr8gaQGwB7CYz5Y+uiWxejk5b4aKCnHh8IPZq1sXPvrn
Okb++j4G7tGN3XbajmXvreHZ+VXs3Kndxv27btee8WNH0L7N1jwz7y2uuPcp7rzgRGpyOX75x79w69nfYscObTnl1/dzRL+e7LbTdmW8OyuFnnv2ZNjIYYw99nusX7+ea+68mpkznmPyrfcy8brbATj+zBGcesH/5cZLbmL4qON46423+MkZP6XDdh2Y+PR4ZjzwZ6rXZ2Myn2bRQC25WMmIivHAvIi4oSA+BPgRcEREfFwQ7wKsiogaSb3JP/hbGBGrJK2WNJB8WeQ04DcNnb/BmrOkr0j6cTL846bk816NvdEs6tK+LXt16wJA221a0XuHTqz44CMArnvor1xw7EDg0xEz/XvtRPs2WwOwT4+dWP5+ft85b6+ge+cOdNu+PS1bVDB43915cs6iZr0Xaxq77t6d12a/xtp/riVXk+Plma9w2NBD+HjNxr/TbNN6G5LOFxHQum0bAFq3bc2H739ITXU2xu02m9KN1jgEOBX4RsHwuGHAfwLbAtM/N2TucOBlSbOB+4CzI2JVsu1c4HdAJbCAz9apa1Vvz1nSj4GRwCTg70m4G3CPpEkRcU1DJ/iyWLxqNa8tfpf/1WNHnpjzJl06tN1YsqjNA8/N49C9ugOw4oOP2Klj243bduzYllfeWtHk12xNb9H8RYz+8Rm077gta/+5jq994wDmv/w6AGf+6HSO/vbRfLT6Iy486SIAHpz4EFf+/hdMfuEe2rRrwxXnXLUxcVuRStRzjohnKOxdfWpqHfvfT74EUtu2WcDejTl/Qz3n0cABEXFNRNyZtGuAA5NttSp8Ajr+kb815nq2SB+vXc8PJz7KRSMOoWIrMf7xFzl3yAF17v/8G4t58Ll5nH/sQc14lVYOb1e+w6T/msy/330N19x5NZVzF5CryffYJvxqIiMPPIUZD/yZEWccB8ABgwZQOXchJ+0/kjGDz+F7V55Hm3ZtynkLW5zI5YpuadZQcs4Bu9QS3znZVquIGBcRAyJiwOghB2/O9aXe+poaLpz4KMP224Mj9+lN1burWbxqNSdddy9Dr7iTFR+sYeQN9/Hu6vz/jX19yUp+MflJbjxzKB3bbgPADh3asiwpcQAsf/8jdujQttbz2ZZn2qRHOGfYWH7w7QtZ88EaqhZ+9lnQjAdmcNjQ/KirwScdwzPTngFgyaIlLHtnGd13797s17xFq6kpvqVYQw8ELwBmSHoDeCeJ7QrsDpzXhNe1RYgIfvGHJ+m1Q0dOHfRVAPrssj1PXH7Gxn2GXnEnd//gRDq1a83S9z7kwt8/wpXfOZIeO3TcuE+/7jvw9j/eZ/HK1ezQoS2PvlTJ1ace1dy3Y02k4/YdeX/l++ywSxcOHXoo5x33fbr22oXFby4B4ODBB/POgvxfrxWLV7Dvofvyyt/n0KlzR7rv1o2lby0t5+VveUpU1ii3epNzRDwiaQ/yZYwNg6YXA89HRLr/tdMMZr+5jIdnvU6fnbfjpOsmA/C9YV/jsL49at1/3GOzeP/jf3L1/U8D0GKrrbj7X79Ni4qtuPiEwzhn3MPkcsHwA7/C7h6pkRk/H/dvtO/Unurqam667Dd8tPojfnjdv9K9d3ciciyvWsGNl/wHAHf+x1386IaLuO3x3yLEbVePZ/V7qxs4g31GyssVxVJTP2z45E83ZuNfY1ZSx55V6zMV+5KbUfVYgxMCNeSjn55cdM5pe/mkzT5fU/E4ZzPLFq8haGaWQl+GmrOZ2ZYmMvLSjpOzmWWLe85mZinkmrOZWQq552xmlj7h5GxmlkJ+IGhmlkLuOZuZpZCTs5lZ+mRl/muvvm1m2ZKL4ls9JHWX9ISkVyXNlXR+Et9O0nRJbyR/dkriSlaLqpT0sqT9Co41Ktn/DUmjirkNJ2czy5YSJWfyK2tfGBF9gYHAWEl9gYuBGRHRB5iRfAcYSn7dwD7AGOAWyCdz4GfA18jP8PmzDQm9Pk7OZpYpUZ0rutV7nIilEfFi8vlDYB75qZOHA7cnu90OjEg+
DwfuiLyZQEdJOwODgekRsSoi3gOmA0Maug8nZzPLllzxrXBJvaSNqe2QknoC+5JfPXvHiNiwAsIyYMfkc1c+XZQEoCqJ1RWvlx8ImlmmNOYllIgYB4yrbx9J7cgv3HpBRKyWPp0COiJCUpM8gXTP2cyypXQ1ZyS1JJ+Y74qIPybh5Um5guTPFUl8MVC44GO3JFZXvF5OzmaWLY0oa9RH+S7yeGBeRNxQsGkKsGHExSjgoYL4acmojYHAB0n541HgGEmdkgeBxySxermsYWaZUsK5NQ4BTgVekTQ7iV0KXANMljQaeAs4Kdk2FRgGVAIfA2cARMQqSVcAzyf7XR4Rqxo6uZOzmWVKVJcmOUfEM0BdawweWcv+AYyt41gTgAmNOb+Ts5llSzamc3ZyNrNsychc+07OZpYxTs5mZunjnrOZWQpFdbmvoDScnM0sU9xzNjNLISdnM7M0irqGJm9ZnJzNLFPcczYzS6HIuedsZpY6uRonZzOz1HFZw8wshVzWMDNLoWiSdUman5OzmWWKe85mZimUlQeCXqbKzDIlciq6NUTSBEkrJM0piP1B0uykLdqwSoqknpI+Kdh2a8Fv9pf0iqRKSTepcJXYOrjnbGaZEqV9Q3Ai8J/AHZ8eP/7Phs+Srgc+KNh/QUT0r+U4twDfBZ4jv5zVEGBafSd2z9nMMiVyxbcGjxXxNFDren9J7/ck4J76jpGs0N0+ImYmS1ndAYxo6NxOzmaWKblQ0U3SGEmzCtqYRpzqMGB5RLxREOsl6SVJT0k6LIl1BaoK9qlKYvVyWcPMMqUxZY2IGAeM28RTjeSzvealwK4RsVLS/sCDkvpt4rGdnM0sW5pjtIakFsAJwP4bYhGxFlibfH5B0gJgD2Ax0K3g592SWL1c1jCzTCnlaI16HAW8FhEbyxWSukiqSD73BvoACyNiKbBa0sCkTn0a8FBDJ3ByNrNMaUzNuSGS7gGeBfaUVCVpdLLpZL74IPBw4OVkaN19wNkRseFh4rnA74BKYAENjNQAlzXMLGNKOZQuIkbWET+9ltj9wP117D8L2Lsx53ZyNrNM8dwaZmYpVEy5Ykvg5GxmmZLzxEdmZunjnnORtj3+2qY+hW2BPlnyl3JfgmVUiefWKBv3nM0sU9xzNjNLoYwM1nByNrNsqcll4906J2czy5SMLL7t5Gxm2RK45mxmljq5jBSdnZzNLFNy7jmbmaWPyxpmZilU4+RsZpY+Hq1hZpZCWUnO2RitbWaWCFR0a4ikCZJWSJpTEPu5pMWSZidtWMG2SyRVSpovaXBBfEgSq5R0cTH34eRsZpmSU/GtCBOBIbXEfx0R/ZM2FUBSX/LLV/VLfvNfkiqSdQVvBoYCfYGRyb71clnDzDKllEPpIuJpST2L3H04MClZhftNSZXAgcm2yohYCCBpUrLvq/UdzD1nM8uUmkY0SWMkzSpoY4o8zXmSXk7KHp2SWFfgnYJ9qpJYXfF6OTmbWabkpKJbRIyLiAEFbVwRp7gF2A3oDywFrm+K+3BZw8wypanf3o6I5Rs+S7oNeDj5uhjoXrBrtyRGPfE6uedsZpmSa0TbFJJ2Lvh6PLBhJMcU4GRJW0vqBfQB/g48D/SR1EtSK/IPDac0dB73nM0sU0q5vquke4BBQGdJVcDPgEGS+pPvpC8CzgKIiLmSJpN/0FcNjI2ImuQ45wGPAhXAhIiY29C5nZzNLFNK+fp2RIysJTy+nv2vAq6qJT4VmNqYczs5m1mmlLLnXE5OzmaWKVl5fdvJ2cwyJSNz7Ts5m1m2uKxhZpZCLmuYmaVQjXvOZmbp456zmVkKOTmbmaWQR2uYmaWQR2uYmaWQyxpmZilUU+4LKBEnZzPLFJc1zMxSyGUNM7MU8mgNM7MUymUkPXuZKjPLlMasvt2QZHXtFZLmFMSulfRasvr2A5I6JvGekj6RNDtptxb8Zn9Jr0iqlHSTpAYr407OZpYpJV5DcCIw5HOx6cDeEbEP8DpwScG2BRHR
P2lnF8RvAb5Lfl3BPrUc8wucnM0sU3IqvjUkIp4GVn0u9lhEVCdfZ5JfTbtOyYKw7SNiZkQEcAcwoqFzOzmbWabkiKKbpDGSZhW0MY083ZnAtILvvSS9JOkpSYclsa5AVcE+VUmsXn4gaGaZ0pjHgRExDhi3KeeRdBn5VbbvSkJLgV0jYqWk/YEHJfXblGODk7OZZUxzjHOWdDpwLHBkUqogItYCa5PPL0haAOwBLOazpY9uSaxeLmuYWabUEEW3TSFpCPAj4LiI+Lgg3kVSRfK5N/kHfwsjYimwWtLAZJTGacBDDZ3HPWczy5RS9pwl3QMMAjpLqgJ+Rn50xtbA9GRE3MxkZMbhwOWS1ieXcXZEbHiYeC75kR+tydeoC+vUtXJyNrNMKeVLKBExspbw+Dr2vR+4v45ts4C9G3NuJ2czy5RsvB/o5GxmGeOJj8zMUmhTH/SljZOzmWVKViY+cnIuocrXZ/LhmjXU1OSorq5m4EHDABh77hmcc87p1NTUMG3aDC6+5CqOOvIwrrrqUlq1asm6deu5+OIreeLJv5b5DqwUli7/B5decR0r33sPIb49fCinnjSCC//tlyx6O/+i2Idr1rBtu3bcf/vNvPLqfH7+7zcBEATnnnkKRx1xCGvXrmPU2ItYt349NdU1HP31QznvX04t561tEbKRmp2cS+6oo/83K1e+t/H7oCMO5rhvDWa//Y9m3bp1dOmyPQDvrlzFiONPZ+nS5fTrtydTH76LHr0GlOuyrYRaVFRw0fe+S989d+ejjz7mpNHf5+AD9uX6Kz6dH+fa39xGu7ZtANi9dw/+MP4mWrSo4B/vruLEUecy6JCBtGrVkgk3XUObNq1ZX13Naef8kMMGDuCre+9VrlvbImSl5+yXUJrYWWedxq+uvZl169YB8I9/rARg9uy5LF26HIC5c+fTuvU2tGrVqmzXaaXTpfN29N1zdwDatm1D7x7dWZ787w4QETzy56cZdvQgAFpvsw0tWlQAsHbdOkhmk5REmzatAaiurqa6upoiZpr80ivxrHRl4+RcQhHBtKn38NzMafzL6FMA6NOnN4ceeiB/e+a/+fPj9zFg/69+4XcnnPBNXnppzsYEbtmxeOly5r2xgH367bkx9sL/n8P2nTrRo/unc9+8PPc1hp9yFsefdg4/vei8jcm6pqaGE0eN5fBjR3LQAfuyT7+vNPs9bGmiEf9Js00ua0g6IyJ+X8e2McAYAFV0YKut2m7qabYoR3z9eJYsWUaXLtvzyLRJzJ9fSYsWFXTq1JGDD/0WBwzozz1330qfPQ/a+Ju+fffgl1ddytBvfqeMV25N4eOPP+EHl13Jj79/Fu3afvp3YOr0Jxl29BGf2Xeffl/hobt+y4JFb3PZlddz2MAD2HrrVlRUVHD/7Tez+sM1nH/JFbyxcBF9evds5jvZsmRltMbm9Jx/UdeGiBgXEQMiYsCXJTEDLFmyDMiXLh56aBoHHNCfxVVLefDB/Juaz8+aTS6Xo3Pn7QDo2nVn7rt3PGeceT4LF75Vtuu20ltfXc0Fl13JN4/5OkcPOmRjvLq6hsef+htDjjy81t/t1nNX2rRuzRsLF30m3n7bdhy43z48M3NWU152JnwpyhrJMiy1tVeAHZvpGrcIbdq0pl27ths/H33UEcydO5+HpjzKoEEHA/kSR6tWrXj33VV06NCeKQ/dwaWXXc3fnvVfuCyJCH76yxvp3aM7o04+4TPbZs56id49urHTDl02xqqWLKO6Or9o0pJly3nzrXfouvOOrHrvfVZ/uAaAf65dy7PPv0SvHt2b70a2ULmIoluaNVTW2BEYDLz3ubiAvzXJFW2hdtyxC/fdm3/lvkWLCiZNepBHH3uSli1b8rvbrmf2SzNYt249Z46+AMgPr9t9t5785LIf8JPLfgDA0GEjNz4wtC3XSy/P5b8fmUGf3Xpy4qixAJx/1igOP/hApj3+FEOPGvSZ/V98eS7j/99kWrRowVZbiZ/8cCyd
OnZgfuWbXHblddTkckQuGPyNwxh0yNfKcEdblnSn3OIp6vm3h6TxwO8j4platt0dEQ0WSlu06pqV/66shD5Z8pdyX4KlUMvOvTd7OMp3ehxfdM65+60HUjv8pd6ec0SMrmebn2CZWeqkfRRGsfwSipllSrWTs5lZ+mSl5+yXUMwsU0o5lE7SBEkrJM0piG0nabqkN5I/OyVxSbpJUmUyqm2/gt+MSvZ/Q9KoYu7DydnMMiUiim5FmAgM+VzsYmBGRPQBZiTfAYaSXzewD/mX8G6BfDInv7zV14ADgZ9tSOj1cXI2s0zJEUW3hkTE08Cqz4WHA7cnn28HRhTE74i8mUBHSTuTH448PSJWRcR7wHS+mPC/wDVnM8uUxry+XTjVRGJcRIxr4Gc7JitqAyzj0xfyugLvFOxXlcTqitfLydnMMqUxU4YmibihZFzf70NSkzyBdFnDzDKlxDXn2ixPyhUkf65I4ouBwvfruyWxuuL1cnI2s0xphomPpgAbRlyMAh4qiJ+WjNoYCHyQlD8eBY6R1Cl5EHhMEquXyxpmlimlHOcs6R5gENBZUhX5URfXAJMljQbeAk5Kdp8KDAMqgY+BMwAiYpWkK4Dnk/0uj4jPP2T8AidnM8uUUi5TFREj69h0ZC37BjC2juNMACY05txOzmaWKTWR9pmai+PkbGaZkpXXt52czSxT0j6JfrGcnM0sU7KRmp2czSxjSvlAsJycnM0sU5yczcxSyKM1zMxSyKM1zMxSaDPmzEgVJ2czyxTXnM3MUsg9ZzOzFKrZnPnmUsTJ2cwyxW8ImpmlkEdrmJmlkHvOZmYplJWes5epMrNMyUUU3eojaU9JswvaakkXSPq5pMUF8WEFv7lEUqWk+ZIGb859uOdsZplSqte3I2I+0B9AUgX5RVkfIL/81K8j4rrC/SX1BU4G+gG7AI9L2iMiajbl/O45m1mmRCP+0whHAgsi4q169hkOTIqItRHxJvm1BA/c1PtwcjazTInIFd0kjZE0q6CNqeOwJwP3FHw/T9LLkiYkK2oDdAXeKdinKoltEidnM8uUHFF0i4hxETGgoI37/PEktQKOA+5NQrcAu5EveSwFrm+K+3DN2cwypQle3x4KvBgRy5PjL9+wQdJtwMPJ18VA94LfdUtim8Q9ZzPLlMb0nIs0koKShqSdC7YdD8xJPk8BTpa0taReQB/g75t6H+45m1mm1ORKN7eGpLbA0cBZBeFfSepPfrnCRRu2RcRcSZOBV4FqYOymjtQAUFPP4NSiVddsjAi3kvpkyV/KfQmWQi0799bmHmOnjnsVnXOWvT9vs8/XVNxzNrNM8ZShZmYp5Mn2zcxSyD1nM7MUKuUDwXJycjazTHFZw8wshVzWMDNLIU+2b2aWQlmZbN/J2cwyxT1nM7MUypVosv1yc3I2s0zxA0EzsxRycjYzS6FspOZmmJXOPiVpTG0rLdiXm/+5sNp4sv3mVdf6ZPbl5n8u7AucnM3MUsjJ2cwshZycm5frilYb/3NhX+AHgmZmKeSes5lZCjk5m5mlkJNzM5E0RNJ8SZWSLi739Vj5SZogaYWkOeW+FksfJ+dmIKkCuBkYCvQFRkrqW96rshSYCAwp90VYOjk5N48DgcqIWBgR64BJwPAyX5OVWUQ8Dawq93VYOjk5N4+uwDsF36uSmJlZrZyczcxSyMm5eSwGuhd875bEzMxq5eTcPJ4H+kjqJakVcDIwpczXZGYp5uTcDCKiGjgPeBSYB0yOiLnlvSorN0n3AM8Ce0qqkjS63Ndk6eHXt83MUsg9ZzOzFHJyNjNLISdnM7MUcnI2M0shJ2czsxRycjYzSyEnZzOzFPofPTbQGwNSE5AAAAAASUVORK5CYII=\n",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# raw (un-normalised) counts; rows are the true labels, columns the predictions\n",
    "cf_matrix = confusion_matrix(y_true=test_dataset.labels, y_pred=out)\n",
    "sns.heatmap(data=cf_matrix, annot=True, fmt='d')"
   ]
  },
  {
   "cell_type": "code",
370
   "execution_count": 374,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
371
372
   "id": "7a0d2c30",
   "metadata": {},
373
   "outputs": [],
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
374
   "source": [
375
376
377
    "# Persist the trained weights together with the RNG seed of this session.\n",
    "# torch.seed() would re-seed the global RNG with a fresh random value and return\n",
    "# that new value; torch.initial_seed() reports the seed the RNG was started with.\n",
    "torch.save({'model_state_dict': oracle.state_dict(),\n",
    "            'seed': torch.initial_seed()},\n",
    "           'nn_cls.pt')"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
378
379
380
381
382
383
384
385
386
387
388
389
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3342901f",
   "metadata": {},
   "source": [
    "# Region Vorhersage"
   ]
  },
  {
   "cell_type": "code",
390
   "execution_count": 469,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
391
392
393
394
395
396
   "id": "d01750f1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
397
       "0.9722222222222222"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
398
399
      ]
     },
400
     "execution_count": 469,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
401
402
403
404
405
406
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Find region similarity: train a decision tree that predicts a sample's region\n",
    "# from all of its other columns.\n",
    "# NOTE(review): X keeps every column except 'region' - if the pickle also holds the\n",
    "# failure 'label', that label is used as a feature here; confirm this is intended.\n",
    "region_data = pd.read_pickle('data/train_20_best_features_over_reg.pkl')\n",
    "X = region_data.drop(columns=['region'])\n",
    "y = region_data.region\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state = 1)\n",
    "# fixed random_state so that re-running the cell rebuilds the same tree\n",
    "dtree_model = DecisionTreeClassifier(max_depth = 12, random_state = 1).fit(X_train, y_train)\n",
    "dtree_predictions = dtree_model.predict(X_test)\n",
    "# hold-out accuracy of the region classifier\n",
    "accuracy_score(y_test, dtree_predictions)"
   ]
  },
  {
   "cell_type": "code",
418
419
   "execution_count": 438,
   "id": "d7f2c3ae-45f4-4e34-9a83-b9176d9bb832",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
420
421
422
   "metadata": {},
   "outputs": [],
   "source": [
423
    "# persist the fitted region tree; the with-block closes (and flushes) the file\n",
    "# instead of leaving the handle open until garbage collection\n",
    "with open(\"dtree_model.pkl\", \"wb\") as model_file:\n",
    "    pickle.dump(dtree_model, model_file)"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
424
425
426
427
428
429
430
431
432
433
434
435
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcae5717",
   "metadata": {},
   "source": [
    "# Windturbineausfall Vorhersage nach Regionähnlichkeit"
   ]
  },
  {
   "cell_type": "code",
436
   "execution_count": 488,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
437
438
439
440
441
   "id": "9a1b7e72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the new data and keep only the rows that have a label to score against\n",
    "test_df = (\n",
    "    pd.read_pickle('data/test_20_best_features_over_reg.pkl')\n",
    "    .loc[lambda frame: frame.label.notna()]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
448
   "execution_count": 489,
449
   "id": "0f133720-bbda-445f-b172-c68127012f18",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
450
   "metadata": {},
451
   "outputs": [],
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
452
   "source": [
453
454
455
456
457
458
459
460
461
462
463
    "# load model: rebuild the network and restore the weights saved in nn_cls.pt\n",
    "data = torch.load('nn_cls.pt')\n",
    "torch.manual_seed(data['seed'])\n",
    "# NOTE(review): the column arithmetic presumably excludes 'label' and 'region' - confirm\n",
    "model = Classifier(num_numerical_columns = test_df.shape[1] - 2, \n",
    "                    output_size = 2,\n",
    "                    layers_size = [test_df.shape[1] - 1, 40, 10, 5],\n",
    "                    embeddings_size = train_dataset.embeddings, \n",
    "                    p=0.4)\n",
    "model.load_state_dict(data['model_state_dict'])\n",
    "# a freshly constructed module starts in training mode; switch to inference mode\n",
    "model.eval()\n",
    "\n",
    "# Despite its name, failure_predict is whatever was pickled into dtree_model.pkl\n",
    "# (the region decision tree in this notebook).\n",
    "# Unpickling can execute arbitrary code, so only load files this notebook wrote itself.\n",
    "with open(\"dtree_model.pkl\", \"rb\") as model_file:\n",
    "    failure_predict = pickle.load(model_file)"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
464
465
466
467
   ]
  },
  {
   "cell_type": "code",
468
   "execution_count": 490,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
469
470
471
472
   "id": "1e0b3cf6",
   "metadata": {},
   "outputs": [],
   "source": [
473
474
475
    "# predict region by similarity: overwrite each row's region with the one the\n",
    "# loaded tree assigns, then rebuild the dataset from the relabelled frame\n",
    "similar_region = failure_predict.predict(\n",
    "    test_df.drop(columns=['region'])\n",
    ")\n",
    "test_df['region'] = similar_region\n",
    "test_dataset = ClassificationDataSet(dataframe=test_df)"
   ]
  },
  {
   "cell_type": "code",
481
   "execution_count": 491,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
   "id": "8cbc99a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the checkpoint restored into `model` rather than the in-memory `oracle`\n",
    "# left over from the training section, so the saved weights are what gets measured.\n",
    "# NOTE(review): the recorded loss is negative although NLLLoss expects\n",
    "# log-probabilities - confirm the network ends in log_softmax.\n",
    "loss_func = nn.NLLLoss()\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad(): \n",
    "    y_val = model(test_dataset.numerical_data,\n",
    "                  test_dataset.categorical_data)\n",
    "    \n",
    "loss = loss_func(y_val, test_dataset.labels)\n",
    "# predicted class = index of the largest score along dim 1\n",
    "out = torch.argmax(y_val, dim=1)"
   ]
  },
  {
   "cell_type": "code",
499
   "execution_count": 492,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
500
501
502
503
504
505
506
   "id": "ab9dd086",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
507
508
509
510
      "Loss: -0.5739957690238953\n",
      "Accuracy: 0.5741274201112079\n",
      "Recall: 0.45077089532053016\n",
      "Precision: 0.5921122757150471\n"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
     ]
    }
   ],
   "source": [
    "# metrics of the run that uses the predicted regions; compute first, print after\n",
    "acc = accuracy_score(test_dataset.labels, out)\n",
    "recall = recall_score(test_dataset.labels, out)\n",
    "precision = precision_score(test_dataset.labels, out)\n",
    "print(f\"Loss: {loss.item()!s}\")\n",
    "print(f\"Accuracy: {acc!s}\")\n",
    "print(f\"Recall: {recall!s}\")\n",
    "print(f\"Precision: {precision!s}\")"
   ]
  },
  {
   "cell_type": "code",
526
   "execution_count": 493,
527
   "id": "1e4cc4e8",
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
528
529
530
531
532
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
533
       "<AxesSubplot:>"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
534
535
      ]
     },
536
     "execution_count": 493,
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
537
538
     "metadata": {},
     "output_type": "execute_result"
539
540
541
542
543
544
545
546
547
548
549
550
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWcAAAD4CAYAAAAw/yevAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAadUlEQVR4nO3de5xVddXH8c+aM8PITQYVucyMQoqVmqIgaGoiykXUsCyjTFFRrOQpH7QSSzQUTV8mSaWBgOANRBMlRQwFNU2uSqggDyNqMEACgxC3uZyznj/OdjzBXA5yZs5m+337+r04e+37y5k1v1n7t39j7o6IiIRLTrYvQERE9qTkLCISQkrOIiIhpOQsIhJCSs4iIiGU29AnqNy4SsNBZA9NO5ye7UuQEKqqKLV9Pcbe5Jy8Q760z+drKA2enEVEGlUinu0ryAglZxGJFk9k+woyQslZRKIloeQsIhI6rp6ziEgIxauyfQUZoeQsItGiB4IiIiGksoaISAjpgaCISPjogaCISBip5ywiEkLxymxfQUYoOYtItKisISISQipriIiEUER6zprPWUSiJZFIv9XDzD40s7fNbImZLQpiB5nZbDNbGfzbOoibmY0xsxIzW2pmJ6YcZ1Cw/UozG5TObSg5i0ikeKIy7ZamM929i7t3C5ZvAF5y987AS8EywDlA56ANAe6HZDIHbgZ6AN2Bmz9N6HVRchaRaMlgz7kWA4DJwefJwAUp8Yc8aR5QYGbtgb7AbHcvc/fNwGygX30nUXIWkWjxRNrNzIaY2aKUNmT3owF/M7PFKevauvu64PN6oG3wuRBYnbLvmiBWW7xOeiAoItGyFxMfufs4YFwdm5zm7qVmdigw28ze221/N7MG+VN86jmLSLTsRc+53kO5lwb/fgxMJ1kz/ndQriD49+Ng81KgOGX3oiBWW7xOSs4iEi0ZqjmbWXMza/npZ6AP8A4wA/h0xMUg4Jng8wzg0mDUxsnAlqD88QLQx8xaBw8C+wSxOqmsISLRkrnJ9tsC080MkrnyMXefZWYLgWlmNhj4CLgo2H4m0B8oAXYAlwO4e5mZ3QosDLYb6e5l9Z1cyVlEoiVDbwi6+yrg+Brim4Czaog7cE0tx5oITNyb8ys5i0ikuOsvoYiIhI/m1hARCaGIzK2h5Cwi0aKes4hICGVutEZWKTmLSLSorCEiEkIqa4iIhJCSs4hICKmsISISQnogKCISQipriIiEkMoaIiIhpJ6ziEgIKTmLiISQN8hfjWp0Ss4iEi1VGq0hIhI+eiAoIhJCqjmLiISQas4iIiGknrOISAgpOYuIhI/H9QdeRUTCRz1nEZEQ0lA6EZEQSmi0hohI+KisISISQnogKAB9LhxE82bNyMnJIRaLMW3iGO7+43heeX0+uXm5FBe257Ybh3Fgyxa8vWwFt9w5BgDH+ckVF3P2GafywUdruH7EHdXHXLN2HUOvvIRLvvetbN2W7IOiog5Mmngvh7Y9BHdn/PhH+cMfJ3DnHb/m3PN6U1FRwapVHzH4ymFs2bKVvLw87r/vTrp2PY5Ewhk2bASvvPoGAHl5eYy59zbOOOPrJBIJbhpxJ9Onz8zyHYZcRHrO5g38Nk3lxlXRKADVos+Fg3h8whhaF7Sqjr0+fzE9unYhNzfGPfdNAGDYTwazc9cu8nLzyM2NsWFjGRcO+glznnmU3NxY9b7xeJxeF1zClAdG06Fd20a/n8bStMPp2b6EBtOu3aG0b3coby15hxYtmrNg/iwu/M4VFBW2Z87c14nH49xx+40ADL/xdn78o0F07Xo8V141jDZtDubZvz7Cyaf0x925ecR1xGIxRtx8F2bGQQcVsGnT5izfYcOpqii1fT3GjruvTDvnNLt+fL3nM7MYsAgodffzzGwScAawJdjkMndfYmYG3Av0B3YE8TeDYwwCfh1sf5u7T67vvPX2nM3sK8AAoDAIlQIz3H15fft+UZ3ao2v15+OO
+Qqz574GQNMDDqiOl1dUgO35dTFv0RKKC9tHOjFH3fr1H7N+/ccAbNu2nffeW0lhh3bMfvHV6m3mzX+TC799LgBf/epRzH35dQA2bNjElk+20q3r8SxctITLBg3kmK99AwB3j3RizpjMj9b4GbAcODAl9nN3f3K37c4BOgetB3A/0MPMDgJuBroBDiw2sxnuXuf/zJy6VprZL4GpgAELgmbAFDO7Ic0bizQzY8j//oqLrvgfnnhmz183pz/3N0475aTq5aXvvseAi6/mW5f+mBE/H/pfvWaA5196hf5nn9Hg1y2N4/DDi+hy/LHMX/DWf8Uvv2wgs16YC8DSpcs4/7w+xGIxOnYs5sQTv0ZRcQdatUrmgpG3/IIF82cxdcpYDj30kEa/h/1OwtNv9TCzIuBcYHwaZx4APORJ84ACM2sP9AVmu3tZkJBnA/3qO1idyRkYDJzk7r9190eC9luge7CuthsaYmaLzGzR+IempHFP+6+H7r+bJx78I/f/7lamPPUsi5a8Xb1u7OQpxGIxzutzZnXsuGO+wjOPjmXq+HsZ//A0yssrqtdVVlby8mvz6dMrur/yf5E0b96MaY8/wLDrb+Y//9lWHR9+w0+pqqrisceeAuDBSVMpXbOO+fOe557f/YY33lhEPB4nNzdGcXEH/jFvEd179GPevMXcdeeIbN3OfsMTibRbaq4K2pDdDvd74BfA7t3xUWa21MxGm1l+ECsEVqdssyaI1RavU31ljQTQAfhot3j7Gi62mruPA8ZB9GvObdskezIHty7grG98nbeXraBbl6/x9HOzefX1BYwfcwdWQ/niiI6H0axpU1au+pBjv3oUAH+ft4ivHnUEhxzUulHvQTIvNzeXJx5/gClTpvP0089Xxy+95CLO7X82vfteVB2Lx+Nc9/Nbqpf//sozrFy5ik2bNrN9+47qB4BP/uVZLr98YKPdw35rL0ZrpOaq3ZnZecDH7r7YzHqmrBoOrAeaBPv+Ehj5Oa+2VvX1nK8FXjKz581sXNBmAS+RrMN8oe3YuYvt23dUf/7Hgjfp/KWOvDZvERMfe4I/3Hnzf9WZ16xdT1VV8gtn7fp/88FHqyls/1lteebsl+nfu2ej3oM0jAfG/Y7l75Xw+3s/+77v26cn11//Yy749mXs3LmrOt606QE0a9YUgLPPOp2qqiqWL18JwLPPzabnGV8HoNeZp1XHpQ6ZK2ucCnzTzD4kWd7tZWaPuPu6oHRRDjxIspIAyedxxSn7FwWx2uJ1qne0hpnlBCdPfSC40N3T+vEU5Z7z6tJ1/OzGWwGIV8Xp36cnVw/6PudcdAUVlZUUHJisGR53zFe4+Rf/w4xZLzHh4Wnk5uaSk2P86PIfcNY3kt94O3buove3L2XWEw/SskXzrN1TY4nyaI1Tv34Sr7z8NEvfXkYiSAA33fRbRt8zkvz8fDaVJZ8DzZ//JtcMvYHDDy9i5nOPkUgkWFu6nquuvo5//Sv5vXvYYYVMfnAMrQoOZOOGMgZf9b+sXr02a/fW0DIxWmP7Ld9PO+c0v2VKWucLes7XB6M12rv7umB0xmhgl7vfYGbnAkNJjtboAYxx9+7BA8HFwInB4d4Eurp7WZ3n1FA6yYYoJ2f5/DKSnEcMTD85j5z6eZLzHKANycERS4Afufu2IFn/keTDvh3A5e6+KNj/CuDG4HCj3P3B+s6pl1BEJFoaYOIjd38ZeDn43KuWbRy4ppZ1E4GJe3NOJWcRiRZNfCQiEj5epbk1RETCRz1nEZEQ0mT7IiIhpJ6ziEj4uJKziEgI6YGgiEgIqecsIhJCSs4iIuHT0FNSNBYlZxGJFvWcRURCSMlZRCR8vEovoYiIhE80crOSs4hEi15CEREJIyVnEZEQUllDRCR8VNYQEQkhr1JyFhEJH5U1RETCJyJz7Ss5i0jEKDmLiISPes4iIiHkVdm+gsxQchaRSFHPWUQkhJScRUTCyC3bV5ARSs4iEilR6TnnZPsCREQyyROW
dkuHmcXM7C0zezZY7mRm882sxMweN7MmQTw/WC4J1ndMOcbwIL7CzPqmc14lZxGJlETc0m5p+hmwPGX5TmC0ux8JbAYGB/HBwOYgPjrYDjM7GhgIHAP0A+4zs1h9J1VyFpFI8UT6rT5mVgScC4wPlg3oBTwZbDIZuCD4PCBYJlh/VrD9AGCqu5e7+wdACdC9vnMrOYtIpOxNWcPMhpjZopQ2ZLfD/R74BZ+9d3gw8Il79WjqNUBh8LkQWA0QrN8SbF8dr2GfWumBoIhEiu/FpHTuPg4YV9M6MzsP+NjdF5tZz0xc295QchaRSEn3QV8aTgW+aWb9gQOAA4F7gQIzyw16x0VAabB9KVAMrDGzXKAVsCkl/qnUfWqlsoaIREqmHgi6+3B3L3L3jiQf6M1x94uBucB3gs0GAc8En2cEywTr57i7B/GBwWiOTkBnYEF996Ges4hESgZ7zrX5JTDVzG4D3gImBPEJwMNmVgKUkUzouPu7ZjYNWAZUAde4e7y+k5jvTYHmc6jcuCoaf5ZAMqpph9OzfQkSQlUVpfucWd8/tm/aOeeId14I7euE6jmLSKRE5Q1BJWcRiZSE5tYQEQkfV3IWEQmfvXgtO9SUnEUkUhphtEajUHIWkUhRzVlEJIRUcxYRCaEGfnWj0Sg5i0ikqKwhIhJCCT0QFBEJH/Wc01T2nSsa+hSyH7qz3ZnZvgSJKD0QFBEJIfWcRURCKCKDNZScRSRa4olo/A0RJWcRiZSIzBiq5Cwi0eKo5iwiEjqJiBSdlZxFJFIS6jmLiISPyhoiIiEUV3IWEQkfjdYQEQkhJWcRkRBSzVlEJIQiMmOokrOIRIuG0omIhFA82xeQIUrOIhIpCYtGzzka0zeJiAR8L1pdzOwAM1tgZv80s3fN7DdBfJKZfWBmS4LWJYibmY0xsxIzW2pmJ6Yca5CZrQzaoHTuQz1nEYmUDA6lKwd6ufs2M8sDXjOz54N1P3f3J3fb/hygc9B6APcDPczsIOBmoBvJnwmLzWyGu2+u6+TqOYtIpCQs/VYXT9oWLOYFra4O9wDgoWC/eUCBmbUH+gKz3b0sSMizgX713YeSs4hEShxLu5nZEDNblNKGpB7LzGJmtgT4mGSCnR+sGhWULkabWX4QKwRWp+y+JojVFq+TyhoiEil7M87Z3ccB4+pYHwe6mFkBMN3MjgWGA+uBJsG+vwRGfv4rrpl6ziISKYm9aOly90+AuUA/d18XlC7KgQeB7sFmpUBxym5FQay2eJ2UnEUkUjI4WqNN0GPGzJoCvYH3gjoyZmbABcA7wS4zgEuDURsnA1vcfR3wAtDHzFqbWWugTxCrk8oaIhIpGXx9uz0w2cxiJDuy09z9WTObY2ZtAAOWAD8Ktp8J9AdKgB3A5QDuXmZmtwILg+1GuntZfSdXchaRSMnUUDp3XwqcUEO8Vy3bO3BNLesmAhP35vxKziISKfFovCCo5Cwi0aL5nEVEQkjJWUQkhOobhbG/UHIWkUjRZPsiIiGksoaISAhpsn0RkRBSWUNEJIRU1hARCSGN1hARCaFERNKzkrOIRIoeCIqIhJBqziIiIaTRGiIiIaSas4hICEUjNSs5i0jEqOYsIhJC8Yj0nZWcRSRS1HMWEQkhPRAUEQmhaKRmJWcRiRiVNUREQkgPBEVEQkg1Z/lMTg4HjR1LYuNGPhk+nJx27SgYMQJr1YqqFSvYcvvtUFUFQH7PnrS47DJwp/L999l6220AFNx1F3lHH03l22/zyfDhWbwZ2Vex/Dy+98SviTXJJSc3xsqZC/jHPU/R564raXtcJ8yMzR+sZ9awsVTuKOe4H/aiy6W98XiCyh27+NsNEyhbuZZ2x3+J3r8dnDyowRujp1PywqLs3tx+IBqpWck5I5pdeCFVH31ETvPmALS8+mq2P/kk5XPm0HLYMJr278/OGTOIFRbS/OKLKRs6FN+2DSsoqD7GjqlTIT+fZt/8ZpbuQjIlXl7JEwNvp3JHOTm5MQb+
5SY+mPtPXh75KBXbdgJwxk0Xc8JlfVhw31957+k3WPrIHACO6H0iPW/6IU9dehcbV6zhkfNuwuMJmh9awKWzRvH+i2/i8ahUVRtGVHrOOdm+gP1dTps2NDn5ZHY+91x1rMmJJ1L+yisA7Jo1i/zTTgOg6XnnsfPpp/Ft2wDwTz6p3qfizTfxnTsb78KlQVXuKAcgJzdGTm4u7lQnZoDcA/JwTyaR1Hhe03wI4lW7KqoTcSw/79Ow1COxFy3M1HPeRy2HDmXb2LFYs2YAWKtWJLZtg3hyVtn4hg3E2rQBIFZcDEDrP/wBYjG2T5pExYIF2blwaVCWY/zwudso6NiWJQ/NZv2S9wHoe/cQOp15PJtWlvLKrY9Vb9/l0rPpetU5xPJymTbw9up4uy5H0Pfuqziw8BCev/bP6jWnwb/oPWczu7yOdUPMbJGZLXp47drPe4rQa3LKKSQ2b6bq//4vre0tFiNWVMTma69ly8iRHHj99ViLFg18lZINnnAePudXjOvxU9odfwQHH1UEwAvXj2PsSUMpK1nLl88/uXr7JQ+9yITTr+PVO6Zy8k8vqI6vX/I+k8++gUfPH0H3a84nlp/X2Ley34njabe6mNkBZrbAzP5pZu+a2W+CeCczm29mJWb2uJk1CeL5wXJJsL5jyrGGB/EVZtY3nfvYl7LGb2pb4e7j3L2bu3e7pEOHfThFuDU59ljyTz2VQ6ZOpdWIETQ54QRaDh1KTosWEIsBEGvThviGDUCyF13++usQj5NYv56q1auJFRZm8xakgZVv3cHqN5bRqedx1TFPOO/NeIPO/U/aY/v3ZszjyD5d94iXlaylcvsuDvlyUYNebxRksKxRDvRy9+OBLkA/MzsZuBMY7e5HApuB4Kktg4HNQXx0sB1mdjQwEDgG6AfcZ2ax+k5eZ3I2s6W1tLeBtvXfW7Rte+ABNn73u2wcOJAtI0dS8dZbbB01ioq33iL/jDMAOKBfv2RCBspfe40mXboAyfJHbnEx8XXrsnX50kCaHtSS/AOTZa7c/DwOP/1rlK1aR8Hhn33LHNn7RDaXJH+rLOj4WfxLZ3Vh84frATiwuA0WS36Ltiw8mIOO7MDW1Rsa6zb2Wwn3tFtdPGlbsJgXNAd6AU8G8cnABcHnAcEywfqzzMyC+FR3L3f3D4ASoHt991Ffzbkt0JfkT4dUBvyjvoN/UW0bO5ZWI0bQYvBgqlauZOfMmQBULFhAk27dOHjSJDyR4D9//jO+dSsArceMIfeww7CmTTnkiSfYetddVCxcmM3bkM+p+aEFnHPP1VgsB8sxVjw7n1UvLWHgX26iSYummMGGZf/ixV9NAuCEy/pw2GnHkKiMs2vLdmYNGwtA4UlH0f0n55OojOMJ56VfTWLn5m11nFlg74bSmdkQYEhKaJy7j0tZHwMWA0cCfwLeBz5x96pgkzXAp7/+FgKrAdy9ysy2AAcH8Xkp50jdp/Zr8zp+epjZBOBBd3+thnWPufsP6jvBv3v2jEZ1XjLqkVX69Vz2dN2/HtnnPzL1g8O/lXbOeeyj6Wmdz8wKgOnATcCkoHSBmRUDz7v7sWb2DtDP3dcE694HegC3APPc/ZEgPiHY58k9TpSizp6zuw+uY129iVlEpLE1xGgNd//EzOYCpwAFZpYb9J6LgNJgs1KgGFhjZrlAK2BTSvxTqfvUSuOcRSRSqvC0W13MrE3QY8bMmgK9geXAXOA7wWaDgGeCzzOCZYL1czxZmpgBDAxGc3QCOgP1jqHVOGcRiZQM9pzbA5ODunMOMM3dnzWzZcBUM7sNeAuYEGw/AXjYzEqAMpIjNHD3d81sGrAMqAKucfd4fSdXchaRSMnUazruvhQ4oYb4KmoYbeHuu4Dv1nKsUcCovTm/krOIREpdgxz2J0rOIhIpUZn4SMlZRCJFk+2LiISQes4iIiGkmrOISAhFZVJVJWcRiZSozOes5CwikaKas4hICMU9GoUNJWcRiRSVNUREQqi+SfT3F0rO
IhIp0UjNSs4iEjF6ICgiEkJKziIiIaTRGiIiIaTRGiIiIaS5NUREQkg1ZxGREFLPWUQkhOIRmZdOyVlEIkVvCIqIhJBGa4iIhJB6ziIiIaSes4hICKnnLCISQnp9W0QkhFTWEBEJIY9Izzkn2xcgIpJJCTztVhczKzazuWa2zMzeNbOfBfFbzKzUzJYErX/KPsPNrMTMVphZ35R4vyBWYmY3pHMf6jmLSKRk8PXtKuA6d3/TzFoCi81sdrButLvfnbqxmR0NDASOAToAL5rZUcHqPwG9gTXAQjOb4e7L6jq5krOIREqmJj5y93XAuuDzf8xsOVBYxy4DgKnuXg58YGYlQPdgXYm7rwIws6nBtnUmZ5U1RCRS4olE2s3MhpjZopQ2pKZjmllH4ARgfhAaamZLzWyimbUOYoXA6pTd1gSx2uJ1UnIWkUjxvfnPfZy7d0tp43Y/npm1AP4CXOvuW4H7gSOALiR71r9riPtQWUNEIiWTU4aaWR7JxPyouz8VHP/fKesfAJ4NFkuB4pTdi4IYdcRrpZ6ziERKBkdrGDABWO7u96TE26ds9i3gneDzDGCgmeWbWSegM7AAWAh0NrNOZtaE5EPDGfXdh3rOIhIpGew5nwpcArxtZkuC2I3A982sC+DAh8DVwXnfNbNpJB/0VQHXuHscwMyGAi8AMWCiu79b38mVnEUkUuKJzLyE4u6vAVbDqpl17DMKGFVDfGZd+9VEyVlEIkV/Q1BEJIT0NwRFREJIU4aKiISQZqUTEQkh9ZxFREIoEZEpQ5WcRSRS9EBQRCSElJxFREIoGqkZLCo/ZfYHZjakplmv5ItNXxdSE0181LhqnCtWvvD0dSF7UHIWEQkhJWcRkRBScm5cqitKTfR1IXvQA0ERkRBSz1lEJISUnEVEQkjJuZGYWT8zW2FmJWZ2Q7avR7LPzCaa2cdm9k79W8sXjZJzIzCzGPAn4BzgaJJ/g+zo7F6VhMAkoF+2L0LCScm5cXQHStx9lbtXAFOBAVm+Jskyd38VKMv2dUg4KTk3jkJgdcrymiAmIlIjJWcRkRBScm4cpUBxynJREBMRqZGSc+NYCHQ2s05m1gQYCMzI8jWJSIgpOTcCd68ChgIvAMuBae7+bnavSrLNzKYAbwBfNrM1ZjY429ck4aHXt0VEQkg9ZxGREFJyFhEJISVnEZEQUnIWEQkhJWcRkRBSchYRCSElZxGREPp/6KQ9Wrwq474AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
551
552
553
    }
   ],
   "source": [
554
555
    "# absolute counts per (true label, predicted label) pair, drawn as an annotated grid\n",
    "cf_matrix = confusion_matrix(test_dataset.labels, out)\n",
    "sns.heatmap(cf_matrix, fmt='d', annot=True)"
anastasiaslobodyanik's avatar
anastasiaslobodyanik committed
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}