Commit 79743c47 authored by Chuanren Wu's avatar Chuanren Wu

consider induction

parent 7d236571
......@@ -9,6 +9,10 @@
#include <iterator>
#include <iostream>
enum ELIMINATION_ERROR : char {
E_SUCCESS = 0, E_1, E_2, E_3, E_4, E_5, E_6, E_7
};
template<class T> T atLeast1(T x)
{
return x >= 1 ? x : 1;
......@@ -59,46 +63,102 @@ struct LengthHelper
ConstraintExt *c;
bool isPlus;
double operator()()
// { return isPlus ? c->totalLengthPlus : c->totalLengthMinus; }
{ return isPlus ? c->totalLengthPlus - c->knownLengthPlus
: c->totalLengthMinus - c->knownLengthMinus; }
// { return isPlus ? c->totalLengthPlus : c->totalLengthMinus; }
// { return isPlus ? c->totalLengthPlus - c->knownLengthPlus : c->totalLengthMinus - c->knownLengthMinus; }
{ return isPlus ? c->plus.size() : c->minus.size(); }
};
static void insertNewValues(
const std::map<int, int> m,
std::vector<ConstraintExt> &vc
static ELIMINATION_ERROR insertNewValues(
int nTotal,
std::vector<ConstraintExt> &vc,
std::map<int, int> &m,
std::vector<int> &res
)
{
for (auto &c : vc) {
for (const auto &p : m) {
auto it = c.plus.find(p.first);
if (it != c.plus.cend()) {
c.plus.erase(it);
assert(c.knownPlus.find(p.first) == c.knownPlus.cend());
c.knownPlus.insert(p.first);
c.knownLengthPlus += p.second;
while (!m.empty()) {
auto np = *m.begin();
m.erase(np.first);
if (m.empty()) {
np.second = nTotal;
if (np.second < 1) {
return E_5;
}
}
nTotal -= np.second;
std::map<int, int> induct;
induct.insert(np);
while (!induct.empty()) {
auto p = *induct.begin();
induct.erase(p.first);
res[p.first] = p.second;
if (p.second < 1) {
return E_6;
}
it = c.minus.find(p.first);
if (it != c.minus.cend()) {
c.minus.erase(it);
assert(c.knownMinus.find(p.first) == c.knownMinus.cend());
c.knownMinus.insert(p.first);
c.knownLengthMinus += p.second;
for (auto &c : vc) {
auto it = c.plus.find(p.first);
if (it != c.plus.cend()) {
c.plus.erase(it);
c.knownPlus.insert(p.first);
c.knownLengthPlus += p.second;
}
it = c.minus.find(p.first);
if (it != c.minus.cend()) {
c.minus.erase(it);
c.knownMinus.insert(p.first);
c.knownLengthMinus += p.second;
}
// induction
if (c.plus.empty() && c.minus.size() == 1) {
int i = *c.minus.begin();
int v = c.knownLengthPlus - c.knownLengthMinus;
if (v < 1) {
return E_7;
}
induct[i] = v;
auto j = m.find(i);
if (j != m.end()) {
m.erase(j);
nTotal -= v;
}
} else if (c.minus.empty() && c.plus.size() == 1) {
int i = *c.plus.begin();
int v = c.knownLengthMinus - c.knownLengthPlus;
if (v < 1) {
return E_7;
}
induct[i] = v;
auto j = m.find(i);
if (j != m.end()) {
m.erase(j);
nTotal -= v;
}
}
}
}
}
return E_SUCCESS;
}
static std::map<int, int> eliminateSide(
/**
@param[in] vl
@param[in] lh
@param[out] nTotal
@param[out] m
*/
static ELIMINATION_ERROR eliminateSide(
const std::vector<double> &vl,
const LengthHelper &lh
const LengthHelper &lh,
int &nTotal,
std::map<int, int> &m
)
{
std::map<int, int> m;
nTotal = -1;
m.clear();
auto &c = *lh.c;
if (lh.isPlus) {
if (c.plus.empty()) {
return m;
return E_SUCCESS;
}
const double realTotal = c.minus.empty() ?
c.knownLengthMinus
......@@ -108,13 +168,10 @@ static std::map<int, int> eliminateSide(
c.totalLengthPlus,
atLeast1(c.totalLengthMinus - c.knownLengthMinus)*1.1
) - c.knownLengthPlus;
const int nTotal = c.minus.empty() ? c.knownLengthMinus
nTotal = c.minus.empty() ? c.knownLengthMinus
: std::ceil(realTotal);
if (nTotal < static_cast<int>(c.plus.size())) {
#ifndef NDEBUG
std::cerr << "failed to solve reason 0" << std::endl;
#endif
return m;
return E_2;
}
const double denominator = std::accumulate(
c.plus.cbegin(), c.plus.cend(), 0.0,
......@@ -129,15 +186,11 @@ static std::map<int, int> eliminateSide(
// since the "realTotal - size()" before several lines, we add 1
m[i] = 1 + std::floor(vl[i]*factor);
}
// make sure the sum is nTotal
m[*c.plus.cbegin()] = nTotal - std::accumulate(
std::next(c.plus.cbegin()), c.plus.cend(), 0,
[&m](int s, int i){ return s + m[i];} );
}
} else {
// the mirrored plus case
if (c.minus.empty()) {
return m;
return E_SUCCESS;
}
const double realTotal =
c.plus.empty() ?
......@@ -148,13 +201,10 @@ static std::map<int, int> eliminateSide(
c.totalLengthMinus,
atLeast1(c.totalLengthPlus - c.knownLengthPlus)*1.1
) - c.knownLengthMinus;
const int nTotal = c.plus.empty() ? c.knownLengthPlus
nTotal = c.plus.empty() ? c.knownLengthPlus
: std::ceil(realTotal);
if (nTotal < static_cast<int>(c.minus.size())) {
#ifndef NDEBUG
std::cerr << "failed to solve reason 1" << std::endl;
#endif
return m;
return E_2;
}
const double denominator = std::accumulate(
c.minus.cbegin(), c.minus.cend(), 0.0,
......@@ -169,48 +219,33 @@ static std::map<int, int> eliminateSide(
// since the "realTotal - size()" before several lines, we add 1
m[i] = 1 + std::floor(vl[i]*factor);
}
// make sure the sum is nTotal
m[*c.minus.cbegin()] = nTotal - std::accumulate(
std::next(c.minus.cbegin()), c.minus.cend(), 0,
[&m](int s, int i){ return s + m[i];} );
assert(m[*c.minus.cbegin()] > 0);
}
}
return m;
return E_SUCCESS;
}
static bool isValid(const std::vector<ConstraintExt> &vc)
static ELIMINATION_ERROR validate(const std::vector<ConstraintExt> &vc)
{
for (const auto &c : vc) {
if (c.plus.empty() && c.minus.empty()
&& c.knownLengthPlus != c.knownLengthMinus) {
#ifndef NDEBUG
std::cerr << "failed to solve reason 2" << std::endl;
#endif
return false;
return E_3;
}
// maybe a duplicated test, but can be used as short-circuit
if (c.plus.empty()
&& static_cast<int>(c.minus.size())
> (c.knownLengthPlus-c.knownLengthMinus)
) {
#ifndef NDEBUG
std::cerr << "failed to solve reason 3" << std::endl;
#endif
return false;
return E_4;
}
if (c.minus.empty()
&& static_cast<int>(c.plus.size())
> (c.knownLengthMinus-c.knownLengthPlus)
) {
#ifndef NDEBUG
std::cerr << "failed to solve reason 4" << std::endl;
#endif
return false;
return E_4;
}
}
return true;
return E_SUCCESS;
}
void sortLengthVector(std::vector<LengthHelper> &lh)
......@@ -233,16 +268,37 @@ static std::vector<int> eliminate(
std::fill(res.begin(), res.end(), -1989);
while (!lh.empty()) {
const auto newValues = eliminateSide(vl, *lh.rbegin());
insertNewValues(newValues, vc);
// check everytime ?
if (isValid(vc)) {
for (const auto &i : newValues) {
res[i.first] = i.second;
}
} else {
int nTotal;
std::map<int, int> m;
const auto errElimination = eliminateSide(vl, *lh.rbegin(), nTotal, m);
if (errElimination != E_SUCCESS) {
#ifndef NDEBUG
std::cerr << "failed to eliminate, reason "
<< errElimination << std::endl;
#endif
return std::vector<int>();
}
const auto errInsertion = insertNewValues(
nTotal, vc, /*std::move(m)*/m, res);
if (errInsertion != E_SUCCESS) {
#ifndef NDEBUG
std::cerr << "failed to insert, reason "
<< errInsertion << std::endl;
#endif
return std::vector<int>();
}
// check validate() everytime ?
const auto errValidation = validate(vc);
if (errValidation != E_SUCCESS) {
#ifndef NDEBUG
std::cerr << "failed to validate, reason "
<< errValidation << std::endl;
#endif
return std::vector<int>();
}
lh.pop_back();
// here to sort ?
sortLengthVector(lh);
......
......@@ -4,6 +4,26 @@
#include <cmath>
#include <gtest/gtest.h>
TEST(BC, SimpleLine1)
{
std::vector<Constraint> cs;
{
Constraint c;
c.plus(0);
c.minus(1);
cs.push_back(c);
}
std::vector<double> vl(2);
std::generate(vl.begin(), vl.end(), [](){ return 10.0;} );
auto v = discretizeImpl(vl, cs);
ASSERT_NE(0, v.size());
EXPECT_LE(10, v[0]);
EXPECT_LE(10, v[1]);
}
TEST(BC, SingleSegment1)
{
std::vector<Constraint> cs;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment