// mloptimizer.cpp
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#define PY_ARRAY_UNIQUE_SYMBOL KITRT_MLOPT_ARRAY_API
#include <numpy/arrayobject.h>

#include <algorithm>
#include <cstring>
#include <vector>

#include "common/config.h"
#include "optimizers/mloptimizer.h"
#include "toolboxes/errormessages.h"

MLOptimizer::MLOptimizer( Config* settings ) : OptimizerBase( settings ) {
    // The interpreter (and numpy's C-API, see initialize_python) must be up before any Python object is used.
    initialize_python();

    // Import the Python module that provides the network entry points used by callNetwork*.
    const std::string moduleName = "callNN_MK3";
    _pModule                     = PyImport_ImportModule( moduleName.c_str() );
    if( _pModule != nullptr ) {
        return;
    }

    // Import failed: print the Python traceback, then report through the project's error handler.
    PyErr_Print();
    ErrorMessages::Error( "'" + moduleName + "' can not be imported!", CURRENT_FUNCTION );
}

MLOptimizer::~MLOptimizer() {
    // Release the network module and shut down the embedded interpreter.
    finalize_python();
}

void MLOptimizer::Solve( Vector& lambda, Vector& u, const VectorVector& /*moments*/, unsigned /*idx_cell*/ ) {
    // Copy the entries of u into a contiguous double buffer that callNetwork can hand to Python.
    const unsigned nEntries = u.size();
    double* inputBuffer     = new double[u.size()];
    for( unsigned k = 0; k < nEntries; ++k ) {
        inputBuffer[k] = u[k];
    }

    // Evaluate the network and store its first nEntries output values in lambda.
    // Only inputBuffer is owned (and freed) here; the pointer returned by callNetwork is not released.
    const double* networkOutput = callNetwork( nEntries, inputBuffer );
    for( unsigned k = 0; k < nEntries; ++k ) {
        lambda[k] = networkOutput[k];
    }

    delete[] inputBuffer;
}

/**
 * Evaluates the network for all cells at once.
 *
 * @param lambda output: lambda[idx_cell][idx_sys] receives the network output for cell idx_cell
 * @param u      input: one vector per cell; all vectors are assumed to have the length of u[0]
 */
void MLOptimizer::SolveMultiCell( VectorVector& lambda, VectorVector& u, const VectorVector& /*moments*/ ) {
    // Nothing to evaluate; this also avoids reading u[0] of an empty batch.
    if( u.size() == 0 ) {
        return;
    }

    const unsigned batch_size = u.size();       // batch size = number of cells
    const unsigned sol_dim    = u[0].size();    // dimension of input vector = nTotalEntries

    const std::size_t n_size = static_cast<std::size_t>( batch_size ) * sol_dim;    // length of input array

    // Flatten u cell by cell (row-major: batch_size x sol_dim). std::vector releases the buffer on
    // every exit path; the previous new[] buffer was never deleted.
    std::vector<double> nn_input( n_size );
    std::size_t idx_input = 0;
    for( unsigned idx_cell = 0; idx_cell < batch_size; idx_cell++ ) {
        for( unsigned idx_sys = 0; idx_sys < sol_dim; idx_sys++ ) {
            nn_input[idx_input++] = u[idx_cell][idx_sys];
        }
    }

    // The result is read in the same row-major layout and released with delete[] below,
    // i.e. callNetworkMultiCell has to hand over a new[]-allocated buffer.
    double* nn_output = callNetworkMultiCell( batch_size, sol_dim, nn_input.data() );

    std::size_t idx_output = 0;
    for( unsigned idx_cell = 0; idx_cell < batch_size; idx_cell++ ) {
        for( unsigned idx_sys = 0; idx_sys < sol_dim; idx_sys++ ) {
            lambda[idx_cell][idx_sys] = nn_output[idx_output++];
        }
    }

    delete[] nn_output;
}

void MLOptimizer::init_numpy() {
    _import_array();    // Check, if this gives a mem Leak!
}

void MLOptimizer::initialize_python() {
88
    // Initialize the Python Interpreter
89
    std::string pyPath = KITRT_PYTHON_PATH;
90

91
    if( !Py_IsInitialized() ) {
92

93
94
95
96
97
98
        Py_InitializeEx( 0 );
        if( !Py_IsInitialized() ) {
            ErrorMessages::Error( "Python init failed!", CURRENT_FUNCTION );
        }
        PyRun_SimpleString( ( "import sys\nsys.path.append('" + pyPath + "')" ).c_str() );
    }
99

100
    // std::cout << "Python working directory is: " << pyPath << " \n";
101
102
    init_numpy();
}
void MLOptimizer::finalize_python() {
    Py_DecRef( _pModule );
    Py_Finalize();
}
double* MLOptimizer::callNetwork( const unsigned input_size, double* nn_input ) {
110

111
112
    PyObject *pArgs, *pReturn, *pFunc;    // *pModule,
    PyArrayObject* np_ret;
113

114
    pFunc = PyObject_GetAttrString( _pModule, "call_network" );
115
116
    if( !pFunc || !PyCallable_Check( pFunc ) ) {
        PyErr_Print();
117
        Py_DecRef( _pModule );
118
        Py_DecRef( pFunc );
119
        ErrorMessages::Error( "'call_network' is null or not callable!", CURRENT_FUNCTION );
120
121
    }

Steffen Schotthöfer's avatar
Steffen Schotthöfer committed
122
    long int dims[1] = { input_size };    // Why was this const?
123

124
    PyObject* inputArray = PyArray_SimpleNewFromData( 1, dims, NPY_DOUBLE, (void*)nn_input );
125

126
127
    pArgs = PyTuple_New( 1 );
    PyTuple_SetItem( pArgs, 0, reinterpret_cast<PyObject*>( inputArray ) );
128

129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
    // Call Python function
    pReturn = PyObject_CallObject( pFunc, pArgs );    // PyObject

    np_ret = reinterpret_cast<PyArrayObject*>( pReturn );    // Cast from PyObject to PyArrayObject

    double* nn_output = reinterpret_cast<double*>( PyArray_DATA( np_ret ) );    // Get Output

    // Finalizing
    Py_DecRef( pFunc );
    Py_DECREF( np_ret );

    return nn_output;
}

double* MLOptimizer::callNetworkMultiCell( const unsigned batch_size, const unsigned input_dim, double* nn_input ) {

    PyObject *pArgs, *pReturn, *pFunc;
    PyArrayObject* np_ret;

    pFunc = PyObject_GetAttrString( _pModule, "call_networkBatchwise" );
149
150
    if( !pFunc || !PyCallable_Check( pFunc ) ) {
        PyErr_Print();
151
        Py_DecRef( _pModule );
152
153
154
        Py_DecRef( pFunc );
        ErrorMessages::Error( "'call_network' is null or not callable!", CURRENT_FUNCTION );
    }
155

Steffen Schotthöfer's avatar
Steffen Schotthöfer committed
156
    long int dims[2] = { batch_size, input_dim };    // Why was this const?
157

158
    PyObject* inputArray = PyArray_SimpleNewFromData( 2, dims, NPY_DOUBLE, (void*)nn_input );
159
160
161
162

    pArgs = PyTuple_New( 1 );
    PyTuple_SetItem( pArgs, 0, reinterpret_cast<PyObject*>( inputArray ) );

163
    // Call Python function
164
165
166
167
    pReturn = PyObject_CallObject( pFunc, pArgs );    // PyObject

    np_ret = reinterpret_cast<PyArrayObject*>( pReturn );    // Cast from PyObject to PyArrayObject

168
    double* nn_output = reinterpret_cast<double*>( PyArray_DATA( np_ret ) );    // Get Output
169
170
171
172
173

    // Finalizing
    Py_DecRef( pFunc );
    Py_DECREF( np_ret );

174
    return nn_output;
175
}