Commit cdfd7ab1 authored by niklas.baumgarten's avatar niklas.baumgarten
Browse files

Merge branch '35-add-weight-to-samplesolution-2' into 'feature'

Resolve "Add Weight to SampleSolution"

Closes #35

See merge request !49
parents 1706b108 bad8dc1e
Pipeline #158380 passed with stages
in 19 minutes and 39 seconds
...@@ -32,6 +32,36 @@ struct SampleCounter { ...@@ -32,6 +32,36 @@ struct SampleCounter {
} }
}; };
struct Weights {
double W = 0.0;
double W2 = 0.0;
double newW = 0.0;
double Wcomm = 0.0;
double W2comm = 0.0;
void Update(double _newW) {
newW = _newW;
Wcomm += newW;
W2comm += newW * newW;
}
void UpdateParallel(int commSplit) {
W = PPM->SumAcrossComm(Wcomm, commSplit);
W2 = PPM->SumAcrossComm(W2comm, commSplit);
}
friend Logging &operator<<(Logging &s, const Weights &wgt) {
return s << "W=" << wgt.W
<< " W2=" << wgt.W2
<< " WComm=" << wgt.Wcomm
<< " W2Comm=" << wgt.W2comm << endl;
}
};
struct Mean { struct Mean {
double C = 0.0; double C = 0.0;
...@@ -45,16 +75,21 @@ struct Mean { ...@@ -45,16 +75,21 @@ struct Mean {
double Ccomm = 0.0; double Ccomm = 0.0;
void Update(double dC, double dQ, double dY, SampleCounter ctr) { void Update(double dC, double dQ, double dY, SampleCounter ctr, Weights wgt) {
Ccomm += dC / ctr.Mcomm; Ccomm += dC * (1.0 / ctr.Mcomm);
Qcomm += dQ / ctr.Mcomm; Qcomm += dQ * (wgt.newW / wgt.Wcomm);
Ycomm += dY / ctr.Mcomm; Ycomm += dY * (wgt.newW / wgt.Wcomm);
} }
void UpdateParallel(SampleCounter ctr, int commSplit) { void UpdateParallel(SampleCounter ctr, Weights wgt, int commSplit) {
C = abs(PPM->SumAcrossComm(ctr.Mcomm * Ccomm, commSplit) / ctr.M); C = abs(PPM->SumAcrossComm(ctr.Mcomm * Ccomm, commSplit) / ctr.M);
Q = abs(PPM->SumAcrossComm(ctr.Mcomm * Qcomm, commSplit) / ctr.M); if (wgt.W == ctr.M) { // Monte Carlo Case
Y = abs(PPM->SumAcrossComm(ctr.Mcomm * Ycomm, commSplit) / ctr.M); Q = abs(PPM->SumAcrossComm(ctr.Mcomm * Qcomm, commSplit) / ctr.M);
Y = abs(PPM->SumAcrossComm(ctr.Mcomm * Ycomm, commSplit) / ctr.M);
} else { // Stochastic Collocation Case
Q = abs(PPM->SumAcrossComm(wgt.Wcomm * Qcomm, commSplit));
Y = abs(PPM->SumAcrossComm(wgt.Wcomm * Ycomm, commSplit));
}
} }
friend Logging &operator<<(Logging &s, const Mean &mean) { friend Logging &operator<<(Logging &s, const Mean &mean) {
...@@ -179,6 +214,8 @@ private: ...@@ -179,6 +214,8 @@ private:
public: public:
SampleCounter ctr; SampleCounter ctr;
Weights wgt;
Mean mean; Mean mean;
SVar sVar; SVar sVar;
...@@ -187,6 +224,8 @@ public: ...@@ -187,6 +224,8 @@ public:
Kurtosis kurtosis; Kurtosis kurtosis;
double W2 = 0.0;
double C2 = 0.0; double C2 = 0.0;
double Q2 = 0.0; double Q2 = 0.0;
...@@ -202,21 +241,23 @@ public: ...@@ -202,21 +241,23 @@ public:
UpdateSampleCounter(dM); UpdateSampleCounter(dM);
} }
void Update(const SampleSolution &fineSolution, const SampleSolution &coarseSolution) { void Update(const SampleSolution &fSol, const SampleSolution &cSol) {
double newC = fineSolution.Cost; double newW = fSol.W;
double newQ = fineSolution.Q; double newC = fSol.C;
double newY = fineSolution.Q - coarseSolution.Q; double newQ = fSol.Q;
Update(newC, newQ, newY); double newY = fSol.Q - cSol.Q;
Update(newW, newC, newQ, newY);
} }
void Update(double newC, double newQ, double newY) { void Update(double newW, double newC, double newQ, double newY) {
ctr.Update(); ctr.Update();
wgt.Update(newW);
double dC = newC - mean.Ccomm; double dC = newC - mean.Ccomm;
double dQ = newQ - mean.Qcomm; double dQ = newQ - mean.Qcomm;
double dY = newY - mean.Ycomm; double dY = newY - mean.Ycomm;
mean.Update(dC, dQ, dY, ctr); mean.Update(dC, dQ, dY, ctr, wgt);
double dC2 = newC - mean.Ccomm; double dC2 = newC - mean.Ccomm;
double dQ2 = newQ - mean.Qcomm; double dQ2 = newQ - mean.Qcomm;
...@@ -231,7 +272,8 @@ public: ...@@ -231,7 +272,8 @@ public:
void UpdateParallel() { void UpdateParallel() {
ctr.UpdateParallel(commSplit); ctr.UpdateParallel(commSplit);
mean.UpdateParallel(ctr, commSplit); wgt.UpdateParallel(commSplit);
mean.UpdateParallel(ctr, wgt, commSplit);
C2 = PPM->SumAcrossComm(C2comm, commSplit); C2 = PPM->SumAcrossComm(C2comm, commSplit);
Q2 = PPM->SumAcrossComm(Q2comm, commSplit); Q2 = PPM->SumAcrossComm(Q2comm, commSplit);
......
...@@ -40,15 +40,18 @@ public: ...@@ -40,15 +40,18 @@ public:
} }
}; };
// Todo remove init
struct SampleSolution { struct SampleSolution {
SampleID id; SampleID id;
public: public:
double Q; double Q; // Quantity of interest of sample solution
double Cost; double C; // Cost to compute sample solution
Vector U; double W; // Weight of the sample solution
Vector U; // Finite element coefficient vector of sample solution
SampleSolution(IDiscretization *disc, const std::string &name = "U") : SampleSolution(IDiscretization *disc, const std::string &name = "U") :
U(Vector((*disc))) { U(Vector((*disc))) {
...@@ -78,8 +81,9 @@ public: ...@@ -78,8 +81,9 @@ public:
void Init() { void Init() {
Q = 0.0; Q = 0.0;
Cost = 0.0; C = 0.0;
U = 0.0; U = 0.0;
W = 0.0;
} }
}; };
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include "SampleGenerator.hpp" #include "SampleGenerator.hpp"
//template<typename T>
class SparseGridGenerator : public SampleGenerator<RVector> { class SparseGridGenerator : public SampleGenerator<RVector> {
protected: protected:
TasGrid::TasmanianSparseGrid grid; TasGrid::TasmanianSparseGrid grid;
...@@ -20,7 +19,7 @@ protected: ...@@ -20,7 +19,7 @@ protected:
double weight = 0.0; double weight = 0.0;
RVector sample {}; RVector sample{};
std::vector<double> weights{}; std::vector<double> weights{};
...@@ -68,6 +67,10 @@ public: ...@@ -68,6 +67,10 @@ public:
return grid.getNumPoints(); return grid.getNumPoints();
} }
double SumOfWeights() const {
return 2.0 * dimension;
}
int GetStochDimension() { int GetStochDimension() {
return grid.getNumDimensions(); return grid.getNumDimensions();
} }
...@@ -76,20 +79,6 @@ public: ...@@ -76,20 +79,6 @@ public:
return sample; return sample;
} }
// todo move to stochastic collocation class
double Quadrature(double func(double, double)) {
double I = 0.0;
std::vector<double> points = grid.getPoints();
std::vector<double> weights = GetWeights();
int num_points = GetNumPoints();
for (int i = 0; i < num_points; i++) {
double x = points[i * dimension];
double y = points[i * dimension + 1];
I += weights[i] * func(x, y);
}
return I;
}
string Name() const override { return "SparseGridGenerator"; }; string Name() const override { return "SparseGridGenerator"; };
}; };
......
...@@ -23,7 +23,7 @@ void EllipticPDESolver::computeQ(SampleSolution &solution) { ...@@ -23,7 +23,7 @@ void EllipticPDESolver::computeQ(SampleSolution &solution) {
} }
void EllipticPDESolver::computeCost(SampleSolution &solution) { void EllipticPDESolver::computeCost(SampleSolution &solution) {
if (costMeasure == "size") solution.Cost = solution.U.size(); if (costMeasure == "size") solution.C = solution.U.size();
// else if (costMeasure == "time") solution.Cost = solution.U.size(); // Todo // else if (costMeasure == "time") solution.Cost = solution.U.size(); // Todo
else Exit("Cost measure not implemented") else Exit("Cost measure not implemented")
} }
......
...@@ -33,18 +33,19 @@ protected: ...@@ -33,18 +33,19 @@ protected:
virtual void plotSolution(SampleSolution &solution) = 0; virtual void plotSolution(SampleSolution &solution) = 0;
void weightSolution(SampleSolution &solution) const { void weightSolution(SampleSolution &solution) const {
solution.Q = GetProblem()->SampleWeight(solution.id) * solution.Q; solution.W = GetProblem()->SampleWeight(solution.id);
} }
public: public:
PDESolver(const Meshes &meshes, const std::string &quantity, PDESolver(const Meshes &meshes,
const std::string &quantity,
const std::string &costMeasure) : const std::string &costMeasure) :
meshes(meshes), quantity(quantity), costMeasure(costMeasure) { meshes(meshes), quantity(quantity), costMeasure(costMeasure) {
config.get("PDESolverVerbose", verbose); config.get("PDESolverVerbose", verbose);
config.get("PDESolverPlotting", plotting); config.get("PDESolverPlotting", plotting);
} }
virtual ~PDESolver() {}; virtual ~PDESolver() = default;
virtual void PrintInfo() const { virtual void PrintInfo() const {
if (verbose > 0) if (verbose > 0)
...@@ -61,8 +62,7 @@ public: ...@@ -61,8 +62,7 @@ public:
computeCost(solution); computeCost(solution);
plotSolution(solution); plotSolution(solution);
weightSolution(solution); weightSolution(solution);
// Todo other idea: Add weight as class member of solution vout(2) << "Q=" << solution.Q << " cost=" << solution.C << endl;
vout(2) << "Q=" << solution.Q << " cost=" << solution.Cost << endl;
mout.EndBlock(verbose <= 1); mout.EndBlock(verbose <= 1);
} }
...@@ -85,26 +85,19 @@ protected: ...@@ -85,26 +85,19 @@ protected:
void run(SampleSolution &solution) override {} void run(SampleSolution &solution) override {}
void computeQ(SampleSolution &solution) override { void computeQ(SampleSolution &solution) override {
if (quantity == "FunctionEvaluation") solution.Q = assemble->FunctionEvaluation(); solution.Q = assemble->FunctionEvaluation();
else Exit("Quantity of interest not implemented")
} }
void computeCost(SampleSolution &solution) override { void computeCost(SampleSolution &solution) override {
if (costMeasure == "size") solution.Cost = solution.U.size(); solution.C = 1.0; // Cost corresponds to one function evaluation
// else if (costMeasure == "time") solution.Cost = solution.U.size(); // Todo
else Exit("Cost measure not implemented")
} }
void plotSolution(SampleSolution &solution) override { void plotSolution(SampleSolution &solution) override {}
mpp::plot_mesh(solution.U.GetMesh());
}
public: public:
DummyPDESolver(IStochasticDummyAssemble *assemble, // Todo remove costMeasure and quantity
const Meshes &meshes, DummyPDESolver(IStochasticDummyAssemble *assemble, const Meshes &meshes) :
const std::string &quantity = "L2", PDESolver(meshes, "", ""), assemble(assemble) {}
const std::string &costMeasure = "size") :
PDESolver(meshes, quantity, costMeasure), assemble(assemble) {}
IAssemble *GetAssemble() const override { return assemble; } IAssemble *GetAssemble() const override { return assemble; }
......
...@@ -68,7 +68,7 @@ PDESolver *PDESolverCreator::Create(const Meshes &meshes) { ...@@ -68,7 +68,7 @@ PDESolver *PDESolverCreator::Create(const Meshes &meshes) {
new IStochasticDummyAssemble( new IStochasticDummyAssemble(
new LagrangeDiscretization(meshes, _degree), new LagrangeDiscretization(meshes, _degree),
CreateStochasticDummyProblem(_problem, meshes) CreateStochasticDummyProblem(_problem, meshes)
), meshes, _quantity, _costMeasure ), meshes
); );
Exit(_model + " not found") Exit(_model + " not found")
......
#ifndef ISTOCHASTICPROBLEM_HPP #ifndef ISTOCHASTICPROBLEM_HPP
#define ISTOCHASTICPROBLEM_HPP #define ISTOCHASTICPROBLEM_HPP
#include <utility>
#include "NormalDistribution.hpp" #include "NormalDistribution.hpp"
#include "UniformDistribution.hpp" #include "UniformDistribution.hpp"
#include "SparseGridGenerator.hpp" #include "SparseGridGenerator.hpp"
...@@ -19,9 +21,11 @@ public: ...@@ -19,9 +21,11 @@ public:
config.get("ProblemVerbose", verbose); config.get("ProblemVerbose", verbose);
} }
virtual double SampleWeight(const SampleID &id) { return 1.0; }
virtual void DrawSample(const SampleID &id) = 0; virtual void DrawSample(const SampleID &id) = 0;
virtual double SampleWeight(const SampleID &id) { return 1.0; } virtual double SumOfWeights() { return 1.0; }
virtual int NumOfSamples() { return 0; }; virtual int NumOfSamples() { return 0; };
...@@ -190,44 +194,48 @@ public: ...@@ -190,44 +194,48 @@ public:
} }
}; };
class SparseGrid1DGeneratorProblem : public StochasticDummyProblem { class SparseGridGeneratorProblem : public StochasticDummyProblem {
protected:
SparseGridGenerator generator; SparseGridGenerator generator;
public:
explicit SparseGrid1DGeneratorProblem(const Meshes &meshes) : SparseGridGeneratorProblem(const Meshes &meshes, SparseGridGenerator generator) :
StochasticDummyProblem(meshes), StochasticDummyProblem(meshes), generator(std::move(generator)) {}
generator(SparseGridGenerator(meshes, 1, 0, 6)) {}
void DrawSample(const SampleID &id) override { void DrawSample(const SampleID &id) override {
generator.DrawSample(id); generator.DrawSample(id);
} }
double FunctionEvaluation() override { double SampleWeight(const SampleID &id) override {
return std::exp(this->generator.EvalSample() * this->generator.EvalSample()); return generator.SampleWeight(id);
} }
string Name() const override { double SumOfWeights() override {
return "SparseGrid1DGeneratorProblem"; return generator.SumOfWeights();
}
int NumOfSamples() override {
return generator.GetNumPoints();
} }
}; };
class SparseGrid2DGeneratorProblem : public StochasticDummyProblem { class SparseGrid1DGeneratorProblem : public SparseGridGeneratorProblem {
SparseGridGenerator generator;
public: public:
explicit SparseGrid2DGeneratorProblem(const Meshes &meshes) : explicit SparseGrid1DGeneratorProblem(const Meshes &meshes) :
StochasticDummyProblem(meshes), SparseGridGeneratorProblem(meshes, SparseGridGenerator(meshes, 1, 0, 6)) {}
generator(SparseGridGenerator(meshes, 2, 0, 6)) {}
void DrawSample(const SampleID &id) override { double FunctionEvaluation() override {
generator.DrawSample(id); return std::exp(this->generator.EvalSample() * this->generator.EvalSample());
} }
double SampleWeight(const SampleID &id) override { string Name() const override {
return generator.SampleWeight(id); return "SparseGrid1DGeneratorProblem";
} }
};
int NumOfSamples() override { class SparseGrid2DGeneratorProblem : public SparseGridGeneratorProblem {
return generator.GetNumPoints(); public:
} explicit SparseGrid2DGeneratorProblem(const Meshes &meshes) :
SparseGridGeneratorProblem(meshes, SparseGridGenerator(meshes, 2, 0, 6)) {}
double FunctionEvaluation() override { double FunctionEvaluation() override {
RVector sample = this->generator.EvalSample(); RVector sample = this->generator.EvalSample();
......
...@@ -17,8 +17,7 @@ TEST_P(TestStochasticCollocationWithoutEpsilon, TestSeriellAgainstParallel) { ...@@ -17,8 +17,7 @@ TEST_P(TestStochasticCollocationWithoutEpsilon, TestSeriellAgainstParallel) {
mout.EndBlock(); mout.EndBlock();
mout << endl; mout << endl;
EXPECT_NEAR(scSeriell->aggregate.mean.Q * scSeriell->aggregate.ctr.M, EXPECT_NEAR(scSeriell->aggregate.mean.Q, GetParam().refValue, SC_TEST_TOLERANCE);
GetParam().refValue, SC_TEST_TOLERANCE);
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
......
...@@ -26,7 +26,7 @@ protected: ...@@ -26,7 +26,7 @@ protected:
void TestSampleCounterParallel() { void TestSampleCounterParallel() {
pout << "Before update: " << aggregate.ctr << endl; pout << "Before update: " << aggregate.ctr << endl;
while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, 0.0, 0.0); while (aggregate.ctr.dMcomm != 0) aggregate.Update(1.0, 0.0, 0.0, 0.0);
pout << "After update: " << aggregate.ctr << endl; pout << "After update: " << aggregate.ctr << endl;
aggregate.UpdateParallel(); aggregate.UpdateParallel();
pout << "After parallel update: " << aggregate.ctr << endl; pout << "After parallel update: " << aggregate.ctr << endl;
...@@ -39,7 +39,7 @@ protected: ...@@ -39,7 +39,7 @@ protected:
void TestMeanParallel() { void TestMeanParallel() {
pout << "Before update: " << aggregate.mean << endl; pout << "Before update: " << aggregate.mean << endl;
while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, PPM->Proc(0), 0.0); while (aggregate.ctr.dMcomm != 0) aggregate.Update(1.0, 0.0, PPM->Proc(0), 0.0);
pout << "After update: " << aggregate.mean << endl; pout << "After update: " << aggregate.mean << endl;
aggregate.UpdateParallel(); aggregate.UpdateParallel();
pout << "After parallel update: " << aggregate.mean << endl; pout << "After parallel update: " << aggregate.mean << endl;
......
#include "TestSparseGridGenerator.hpp" #include "TestSparseGridGenerator.hpp"
#include "SparseGridGenerator.hpp" #include "SparseGridGenerator.hpp"
#include "MeshesCreator.hpp" #include "MeshesCreator.hpp"
...@@ -70,7 +69,14 @@ TEST_P(TestClenshawCurtis, TestSumOfWeights) { ...@@ -70,7 +69,14 @@ TEST_P(TestClenshawCurtis, TestSumOfWeights) {
double sum = 0.0; double sum = 0.0;
for (auto &weight : generator.GetWeights()) for (auto &weight : generator.GetWeights())
sum += weight; sum += weight;
EXPECT_NEAR(sum, 2.0 * dimension, TEST_TOLERANCE); EXPECT_NEAR(sum, generator.SumOfWeights(), TEST_TOLERANCE);
}
TEST_P(TestClenshawCurtis, TestSumOfDrawnWeights) {
double sum = 0.0;
for (int i = 0; i < generator.GetNumPoints(); i++)
sum += generator.SampleWeight(SampleID(0, i, false));
EXPECT_NEAR(sum, generator.SumOfWeights(), TEST_TOLERANCE);
} }
int main(int argc, char **argv) { int main(int argc, char **argv) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment