Commit 2c4c79d6 authored by steffen.schotthoefer's avatar steffen.schotthoefer
Browse files

change save format of training history to json from pickle


Former-commit-id: 6aeaad62
parent 5b2d9c47
This source diff could not be displayed because it is too large. You can view the blob instead.
# imports
import tensorflow as tf
import numpy as np
import math
# Custom Loss
def custom_loss1dMBPrime(): # (label,prediciton)
def loss(u_input, alpha_pred):
return 0.5*tf.square(4*math.pi*np.sqrt(1/(4*np.pi))*tf.math.exp(alpha_pred*np.sqrt(1/(4*np.pi))) - u_input)
return loss
def initialize_network():
# Load model
model = tf.keras.models.load_model('saved_model/my_model')
model = tf.keras.models.load_model('saved_model_GPU/_EntropyLoss_1_300_M_0', custom_objects={ 'loss':custom_loss1dMBPrime })
# Check its architecture
model.summary()
return model
# make the network a gobal variable here
model = initialize_network()
def call_network(input):
inputNP = np.asarray([input])
......@@ -46,10 +51,10 @@ def call_networkBatchwise(input):
def main():
input = [[[0, 1, 2, 3], [2, 3, 4, 5]]]
input = [[1], [2], [3], [2], [3], [4], [5]]
# initialize_network()
print(call_network([0, 1, 2, 3]))
print(call_network(1))
print("-----")
print(call_networkBatchwise(input))
......
......@@ -84,6 +84,7 @@ void MLOptimizer::init_numpy() {
void MLOptimizer::initialize_python() {
// Initialize the Python Interpreter
std::string pyPath = RTSN_PYTHON_PATH;
if( !Py_IsInitialized() ) {
Py_InitializeEx( 0 );
......@@ -92,6 +93,8 @@ void MLOptimizer::initialize_python() {
}
PyRun_SimpleString( ( "import sys\nsys.path.append('" + pyPath + "')" ).c_str() );
}
std::cout << "Python working directory is: " << pyPath << " \n";
init_numpy();
}
......
......@@ -232,7 +232,7 @@ void MNSolver::Save( int currEnergy ) const {
//}
std::vector<std::vector<double>> scalarField( 1, _solverOutput );
// std::vector<std::vector<std::vector<double>>> results{ _outputFields };
std::vector<std::vector<std::vector<double>>> results{ _outputFields };
ExportVTK( _settings->GetOutputFile() + "_" + std::to_string( currEnergy ), results, fieldNames, _mesh );
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment