Commit ca6aa4ce authored by niklas.baumgarten's avatar niklas.baumgarten
Browse files

extended tests for WelfordAggregate.hpp

parent 06235e7b
Pipeline #148884 passed with stages
in 48 minutes and 23 seconds
...@@ -20,8 +20,8 @@ struct SampleCounter { ...@@ -20,8 +20,8 @@ struct SampleCounter {
} }
void UpdateParallel(int commSplit) { void UpdateParallel(int commSplit) {
M = PPM->SumOnCommSplit(Mcomm, 0) / PPM->Size(commSplit); M = PPM->SumAcrossComm(Mcomm, commSplit);
dM = PPM->SumOnCommSplit(dMcomm, 0) / PPM->Size(commSplit); dM = PPM->SumAcrossComm(dMcomm, commSplit);;
} }
friend Logging &operator<<(Logging &s, const SampleCounter &ctr) { friend Logging &operator<<(Logging &s, const SampleCounter &ctr) {
...@@ -237,7 +237,7 @@ public: ...@@ -237,7 +237,7 @@ public:
sVar.Y = Y2 / (ctr.M - 1); sVar.Y = Y2 / (ctr.M - 1);
// def parallel_variance(n_a, avg_a, M2_a, n_b, avg_b, M2_b): // def parallel_variance(n_a, avg_a, M2_a, n_b, avg_b, M2_b):
// n = n_a + n_b <- ctr.UpdateParallel // n = n_a + n_b
// delta = avg_b - avg_a // delta = avg_b - avg_a
// M2 = M2_a + M2_b + delta ** 2 * n_a * n_b / n // M2 = M2_a + M2_b + delta ** 2 * n_a * n_b / n
// var_ab = M2 / (n - 1) // var_ab = M2 / (n - 1)
......
...@@ -24,21 +24,36 @@ protected: ...@@ -24,21 +24,36 @@ protected:
} }
void TestSampleCounterParallel() { void TestSampleCounterParallel() {
pout << "Mcomm before update " << aggregate.ctr.Mcomm << endl; pout << "Before update: " << aggregate.ctr << endl;
while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, 0.0, 0.0); while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, 0.0, 0.0);
pout << "Mcomm after update " << aggregate.ctr.Mcomm << endl; pout << "After update: " << aggregate.ctr << endl;
aggregate.UpdateParallel(); aggregate.UpdateParallel();
pout << "M after update " << aggregate.ctr.M << endl; pout << "After parallel update: " << aggregate.ctr << endl;
EXPECT_EQ(aggregate.ctr.M, numSamples); EXPECT_EQ(aggregate.ctr.M, numSamples);
} }
/*
* Todo maybe use color
*/
void TestMeanParallel() { void TestMeanParallel() {
pout << "Mcomm before update " << aggregate.ctr.Mcomm << endl; pout << "Before update: " << aggregate.mean << endl;
while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, 0.0, 0.0); while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, PPM->Proc(0), 0.0);
pout << "Mcomm after update " << aggregate.ctr.Mcomm << endl; pout << "After update: " << aggregate.mean << endl;
aggregate.UpdateParallel(); aggregate.UpdateParallel();
pout << "M after update " << aggregate.ctr.M << endl; pout << "After parallel update: " << aggregate.mean << endl;
EXPECT_EQ(aggregate.ctr.M, numSamples); double sumGauss = (pow((PPM->Size(0) - 1), 2) + (PPM->Size(0) - 1)) / 2.0;
EXPECT_EQ(aggregate.mean.Q, sumGauss / PPM->Size(0));
}
void TestSVarParallel() {
// pout << "Before update: " << aggregate.sVar << endl;
// while (aggregate.ctr.dMcomm != 0) aggregate.Update(0.0, PPM->Proc(0), 0.0);
// pout << "After update: " << aggregate.sVar << endl;
// aggregate.UpdateParallel();
// pout << "After parallel update: " << aggregate.sVar << endl;
// double sumGauss = (pow((PPM->Size(0) - 1), 2) + (PPM->Size(0) - 1)) / 2.0;
// EXPECT_EQ(aggregate.sVar.Q, sumGauss / PPM->Size(0));
} }
void TearDown() { void TearDown() {
...@@ -56,26 +71,22 @@ public: ...@@ -56,26 +71,22 @@ public:
TEST_F(TestWelfordAggregate4Samples, TestSampleCounter) { TEST_F(TestWelfordAggregate4Samples, TestSampleCounter) {
switch(PPM->Size(0)) { switch(PPM->Size(0)) {
case 8: case 8:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 1); EXPECT_EQ(aggregate.ctr.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2); EXPECT_EQ(aggregate.commSplit, 2);
break; break;
case 4: case 4:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 1); EXPECT_EQ(aggregate.ctr.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2); EXPECT_EQ(aggregate.commSplit, 2);
break; break;
case 2: case 2:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 2); EXPECT_EQ(aggregate.ctr.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 1); EXPECT_EQ(aggregate.commSplit, 1);
break; break;
case 1: case 1:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 4); EXPECT_EQ(aggregate.ctr.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 0); EXPECT_EQ(aggregate.commSplit, 0);
break; break;
...@@ -85,39 +96,9 @@ TEST_F(TestWelfordAggregate4Samples, TestSampleCounter) { ...@@ -85,39 +96,9 @@ TEST_F(TestWelfordAggregate4Samples, TestSampleCounter) {
TestSampleCounterParallel(); TestSampleCounterParallel();
} }
TEST_F(TestWelfordAggregate4Samples, TestAverage) { TEST_F(TestWelfordAggregate4Samples, TestMean) { TestMeanParallel(); }
switch(PPM->Size(0)) {
case 8:
pout << "dMComm " << aggregate.ctr.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2);
break;
case 4:
pout << "dMComm " << aggregate.ctr.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 2);
break;
case 2:
pout << "dMComm " << aggregate.ctr.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 1);
break;
case 1:
pout << "dMComm " << aggregate.ctr.dMcomm
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 0);
break;
default: Warning("No test case for this amount of processes")
break;
}
TestMeanParallel();
}
TEST_F(TestWelfordAggregate4Samples, TestSVar) { TestSVarParallel(); }
class TestWelfordAggregate8Samples : public TestWelfordAggregate { class TestWelfordAggregate8Samples : public TestWelfordAggregate {
public: public:
...@@ -129,26 +110,22 @@ public: ...@@ -129,26 +110,22 @@ public:
TEST_F(TestWelfordAggregate8Samples, TestSampleCounter) { TEST_F(TestWelfordAggregate8Samples, TestSampleCounter) {
switch(PPM->Size(0)) { switch(PPM->Size(0)) {
case 8: case 8:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 1); EXPECT_EQ(aggregate.ctr.dMcomm, 1);
EXPECT_EQ(aggregate.commSplit, 3); EXPECT_EQ(aggregate.commSplit, 3);
break; break;
case 4: case 4:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 2); EXPECT_EQ(aggregate.ctr.dMcomm, 2);
EXPECT_EQ(aggregate.commSplit, 2); EXPECT_EQ(aggregate.commSplit, 2);
break; break;
case 2: case 2:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 4); EXPECT_EQ(aggregate.ctr.dMcomm, 4);
EXPECT_EQ(aggregate.commSplit, 1); EXPECT_EQ(aggregate.commSplit, 1);
break; break;
case 1: case 1:
pout << "dMComm " << aggregate.ctr.dMcomm pout << aggregate.ctr << " commSplit " << aggregate.commSplit << endl;
<< " commSplit " << aggregate.commSplit << endl;
EXPECT_EQ(aggregate.ctr.dMcomm, 8); EXPECT_EQ(aggregate.ctr.dMcomm, 8);
EXPECT_EQ(aggregate.commSplit, 0); EXPECT_EQ(aggregate.commSplit, 0);
break; break;
...@@ -158,6 +135,8 @@ TEST_F(TestWelfordAggregate8Samples, TestSampleCounter) { ...@@ -158,6 +135,8 @@ TEST_F(TestWelfordAggregate8Samples, TestSampleCounter) {
TestSampleCounterParallel(); TestSampleCounterParallel();
} }
TEST_F(TestWelfordAggregate8Samples, TestMean) { TestMeanParallel(); }
class TestWelfordAggregate1e6Samples : public TestWelfordAggregate { class TestWelfordAggregate1e6Samples : public TestWelfordAggregate {
public: public:
TestWelfordAggregate1e6Samples() : TestWelfordAggregate(1e6) { TestWelfordAggregate1e6Samples() : TestWelfordAggregate(1e6) {
...@@ -185,6 +164,7 @@ TEST_F(TestWelfordAggregate1e6Samples, TestSampleCounter) { ...@@ -185,6 +164,7 @@ TEST_F(TestWelfordAggregate1e6Samples, TestSampleCounter) {
TestSampleCounterParallel(); TestSampleCounterParallel();
} }
TEST_F(TestWelfordAggregate1e6Samples, TestMean) { TestMeanParallel(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment