summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/theory/arith/normal_form.cpp37
-rw-r--r--src/theory/arith/normal_form.h71
2 files changed, 82 insertions, 26 deletions
diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp
index 3adb72f37..9ccf057b1 100644
--- a/src/theory/arith/normal_form.cpp
+++ b/src/theory/arith/normal_form.cpp
@@ -67,6 +67,12 @@ bool Variable::isLeafMember(Node n){
(Theory::isLeafOf(n, theory::THEORY_ARITH));
}
+VarList::VarList(Node n)
+ : NodeWrapper(n)
+{
+ Assert(isSorted(begin(), end()));
+}
+
bool Variable::isDivMember(Node n){
switch(n.getKind()){
case kind::DIVISION:
@@ -96,9 +102,15 @@ bool VarList::isMember(Node n) {
Node prev = *curr;
if(!Variable::isMember(prev)) return false;
+ Variable::VariableNodeCmp cmp;
+
while( (++curr) != end) {
if(!Variable::isMember(*curr)) return false;
- if(!(prev <= *curr)) return false;
+ // prev <= curr : accept
+ // !(prev <= curr) : reject
+ // !(!(prev > curr)) : reject
+ // curr < prev : reject
+ if((cmp(*curr, prev))) return false;
prev = *curr;
}
return true;
@@ -118,15 +130,16 @@ int VarList::cmp(const VarList& vl) const {
}
VarList VarList::parseVarList(Node n) {
- if(Variable::isMember(n)) {
- return VarList(Variable(n));
- } else {
- Assert(n.getKind() == kind::MULT);
- for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) {
- Assert(Variable::isMember(*i));
- }
- return VarList(n);
- }
+ return VarList(n);
+ // if(Variable::isMember(n)) {
+ // return VarList(Variable(n));
+ // } else {
+ // Assert(n.getKind() == kind::MULT);
+ // for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) {
+ // Assert(Variable::isMember(*i));
+ // }
+ // return VarList(n);
+ // }
}
VarList VarList::operator*(const VarList& other) const {
@@ -143,7 +156,9 @@ VarList VarList::operator*(const VarList& other) const {
otherBegin = other.internalBegin(),
otherEnd = other.internalEnd();
- merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result);
+ Variable::VariableNodeCmp cmp;
+
+ merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result, cmp);
Assert(result.size() >= 2);
Node mult = NodeManager::currentNM()->mkNode(kind::MULT, result);
diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h
index f098d8b54..7e8ff556d 100644
--- a/src/theory/arith/normal_form.h
+++ b/src/theory/arith/normal_form.h
@@ -268,21 +268,43 @@ public:
}
bool operator<(const Variable& v) const {
- bool thisIsVariable = isMetaKindVariable();
- bool vIsVariable = v.isMetaKindVariable();
-
- if(thisIsVariable == vIsVariable){
- bool thisIsInteger = isIntegral();
- bool vIsInteger = v.isIntegral();
- if(thisIsInteger == vIsInteger){
- return getNode() < v.getNode();
+ VariableNodeCmp cmp;
+ return cmp(this->getNode(), v.getNode());
+
+ // bool thisIsVariable = isMetaKindVariable();
+ // bool vIsVariable = v.isMetaKindVariable();
+
+ // if(thisIsVariable == vIsVariable){
+ // bool thisIsInteger = isIntegral();
+ // bool vIsInteger = v.isIntegral();
+ // if(thisIsInteger == vIsInteger){
+ // return getNode() < v.getNode();
+ // }else{
+ // return thisIsInteger && !vIsInteger;
+ // }
+ // }else{
+ // return thisIsVariable && !vIsVariable;
+ // }
+ }
+
+ struct VariableNodeCmp {
+ bool operator()(Node n, Node m) const {
+ bool nIsVariable = n.isVar();
+ bool mIsVariable = m.isVar();
+
+ if(nIsVariable == mIsVariable){
+ bool nIsInteger = n.getType().isInteger();
+ bool mIsInteger = m.getType().isInteger();
+ if(nIsInteger == mIsInteger){
+ return n < m;
+ }else{
+ return nIsInteger && !mIsInteger;
+ }
}else{
- return thisIsInteger && !vIsInteger;
+ return nIsVariable && !mIsVariable;
}
- }else{
- return thisIsVariable && !vIsVariable;
}
- }
+ };
bool operator==(const Variable& v) const { return getNode() == v.getNode();}
@@ -419,6 +441,27 @@ static void merge_ranges(GetNodeIterator first1,
copy_range(first2, last2, result);
}
+template <class GetNodeIterator, class T, class Cmp>
+static void merge_ranges(GetNodeIterator first1,
+ GetNodeIterator last1,
+ GetNodeIterator first2,
+ GetNodeIterator last2,
+ std::vector<T>& result,
+ const Cmp& cmp) {
+
+ while(first1 != last1 && first2 != last2){
+ if( cmp(*first1, *first2) ){
+ result.push_back(*first1);
+ ++ first1;
+ }else{
+ result.push_back(*first2);
+ ++ first2;
+ }
+ }
+ copy_range(first1, last1, result);
+ copy_range(first2, last2, result);
+}
+
/**
* A VarList is a sorted list of variables representing a product.
* If the VarList is empty, it represents an empty product or 1.
@@ -437,9 +480,7 @@ private:
VarList() : NodeWrapper(Node::null()) {}
- VarList(Node n) : NodeWrapper(n) {
- Assert(isSorted(begin(), end()));
- }
+ VarList(Node n);
typedef expr::NodeSelfIterator internal_iterator;
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback