Commit afb91428 authored by niklas.baumgarten's avatar niklas.baumgarten
Browse files

test refactoring

parent 1e81b8b4
......@@ -22,7 +22,8 @@ protected:
MonteCarlo mc;
TestMonteCarlo(const std::string &meshName, bool parallel, int commSplit = 0) :
TestMonteCarlo(const std::string &meshName, bool parallel,
bool onlyFine = true, int commSplit = 0) :
meshesCreator(MeshesCreator("Interval").
WithCommSplit(commSplit).
WithDistribute("RCB").
......@@ -33,13 +34,12 @@ protected:
WithProblem(GetParam()).
WithQuantity("GeneratorValue").
WithModel("DummyPDESolver")),
mc(MonteCarlo(level, dM, true, parallel, meshesCreator, pdeSolverCreator)) {
mc(MonteCarlo(level, dM, onlyFine, parallel, meshesCreator, pdeSolverCreator)) {
mc.Method();
}
void TearDown() {
PPM->Barrier(0);
PPM->ClearCommunicators(false);
}
};
......
......@@ -10,126 +10,114 @@
class TestWelfordAggregate : public Test {
protected:
int commSplit;
int commSplit;
int numSamples;
int numSamples;
WelfordAggregate aggregate;
WelfordAggregate aggregate;
Meshes *meshes;
Meshes *meshes;
TestWelfordAggregate() {
meshes = MeshesCreator("Interval").Create();
PPM->FullSplit();
}
TestWelfordAggregate() {
meshes = MeshesCreator("Interval").Create();
}
void TearDown() {
PPM->Barrier(0);
PPM->ClearCommunicators(false);
delete meshes;
}
void TearDown() {
PPM->Barrier(0);
delete meshes;
}
};
TEST_F(TestWelfordAggregate, TestWith4Samples) {
aggregate.UpdateSampleCounter(4);
if (PPM->Size(0) == 8) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit
<< endl;
EXPECT_EQ(aggregate.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2);
}
if (PPM->Size(0) == 4) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit
<< endl;
EXPECT_EQ(aggregate.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2);
}
if (PPM->Size(0) == 2) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit
<< endl;
EXPECT_EQ(aggregate.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 1);
}
if (PPM->Size(0) == 1) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit
<< endl;
EXPECT_EQ(aggregate.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 0);
}
pout << "Mcomm before update " << aggregate.Mcomm << endl;
while (aggregate.dMcomm != 0)
aggregate.Update(1.0, 1.0, 1.0);
pout << "Mcomm after update " << aggregate.Mcomm << endl;
aggregate.UpdateParallel();
pout << "M after update " << aggregate.M << endl;
EXPECT_EQ(aggregate.M, 4);
aggregate.UpdateSampleCounter(4);
if (PPM->Size(0) == 8) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2);
}
if (PPM->Size(0) == 4) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2);
}
if (PPM->Size(0) == 2) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 1);
}
if (PPM->Size(0) == 1) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit
<< endl;
EXPECT_EQ(aggregate.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 0);
}
pout << "Mcomm before update " << aggregate.Mcomm << endl;
while (aggregate.dMcomm != 0)
aggregate.Update(1.0, 1.0, 1.0);
pout << "Mcomm after update " << aggregate.Mcomm << endl;
aggregate.UpdateParallel();
pout << "M after update " << aggregate.M << endl;
EXPECT_EQ(aggregate.M, 4);
}
TEST_F(TestWelfordAggregate, TestWith8Samples) {
aggregate.UpdateSampleCounter(8);
if (PPM->Size(0) == 8) {
pout << "dMComm " << aggregate.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 3);
}
if (PPM->Size(0) == 4) {
pout << "dMComm " << aggregate.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 2);
}
if (PPM->Size(0) == 2) {
pout << "dMComm " << aggregate.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 1);
}
if (PPM->Size(0) == 1) {
pout << "dMComm " << aggregate.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 8);
EXPECT_EQ(aggregate.commSplit, 0);
}
pout << "Mcomm before update " << aggregate.Mcomm << endl;
while (aggregate.dMcomm != 0)
aggregate.Update(1.0, 1.0, 1.0);
pout << "Mcomm after update " << aggregate.Mcomm << endl;
aggregate.UpdateParallel();
pout << "M after update " << aggregate.M << endl;
EXPECT_EQ(aggregate.M, 8);
aggregate.UpdateSampleCounter(8);
if (PPM->Size(0) == 8) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 3);
}
if (PPM->Size(0) == 4) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 2);
}
if (PPM->Size(0) == 2) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 1);
}
if (PPM->Size(0) == 1) {
pout << "dMComm " << aggregate.dMcomm << " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.dMcomm, 8);
EXPECT_EQ(aggregate.commSplit, 0);
}
pout << "Mcomm before update " << aggregate.Mcomm << endl;
while (aggregate.dMcomm != 0)
aggregate.Update(1.0, 1.0, 1.0);
pout << "Mcomm after update " << aggregate.Mcomm << endl;
aggregate.UpdateParallel();
pout << "M after update " << aggregate.M << endl;
EXPECT_EQ(aggregate.M, 8);
}
TEST_F(TestWelfordAggregate, TestWith1e6Samples) {
numSamples = 1e6;
NormalDistributionReal normalDist(*meshes);
aggregate.UpdateSampleCounter(numSamples);
if (PPM->Size(0) == 8) {
EXPECT_EQ(aggregate.commSplit, 3);
} else if (PPM->Size(0) == 4) {
EXPECT_EQ(aggregate.commSplit, 2);
} else if (PPM->Size(0) == 2) {
EXPECT_EQ(aggregate.commSplit, 1);
} else if (PPM->Size(0) == 1) {
EXPECT_EQ(aggregate.commSplit, 0);
}
EXPECT_EQ(aggregate.dMcomm, numSamples / PPM->Size(0));
pout << "MeanQcomm before update " << aggregate.MeanQcomm << endl;
while (aggregate.dMcomm != 0) {
normalDist.DrawSample(SampleID(0, aggregate.index(), false));
double val = normalDist.EvalSample();
aggregate.Update(val, val, val);
}
EXPECT_NEAR(aggregate.MeanQ, 0.0, sqrt(1.0 / numSamples));
EXPECT_NEAR(aggregate.MeanY, 0.0, sqrt(1.0 / numSamples));
numSamples = 1e6;
NormalDistributionReal normalDist(*meshes);
aggregate.UpdateSampleCounter(numSamples);
if (PPM->Size(0) == 8) EXPECT_EQ(aggregate.commSplit, 3);
else if (PPM->Size(0) == 4) EXPECT_EQ(aggregate.commSplit, 2);
else if (PPM->Size(0) == 2) EXPECT_EQ(aggregate.commSplit, 1);
else if (PPM->Size(0) == 1) EXPECT_EQ(aggregate.commSplit, 0);
EXPECT_EQ(aggregate.dMcomm, numSamples / PPM->Size(0));
pout << "MeanQcomm before update " << aggregate.MeanQcomm << endl;
while (aggregate.dMcomm != 0) {
normalDist.DrawSample(SampleID(0, aggregate.index(), false));
double val = normalDist.EvalSample();
aggregate.Update(val, val, val);
}
EXPECT_NEAR(aggregate.MeanQ, 0.0, sqrt(1.0 / numSamples));
EXPECT_NEAR(aggregate.MeanY, 0.0, sqrt(1.0 / numSamples));
// EXPECT_NEAR(aggregate.SVarQ, 1.0, sqrt(10.0 / numSamples));
// EXPECT_NEAR(aggregate.SVarY, 1.0, sqrt(10.0 / numSamples));
pout << "MeanQcomm after update " << aggregate.MeanQcomm << endl;
aggregate.UpdateParallel();
pout << "MeanQ after update " << aggregate.MeanQ << endl;
pout << "MeanQcomm after update " << aggregate.MeanQcomm << endl;
aggregate.UpdateParallel();
pout << "MeanQ after update " << aggregate.MeanQ << endl;
EXPECT_EQ(aggregate.M, numSamples);
EXPECT_EQ(aggregate.M, numSamples);
}
#endif //TESTWELFORDAGGREGATE_HPP
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment