summaryrefslogtreecommitdiff
path: root/libs/ezsat/testbench.cc
diff options
context:
space:
mode:
Diffstat (limited to 'libs/ezsat/testbench.cc')
-rw-r--r--libs/ezsat/testbench.cc524
1 files changed, 524 insertions, 0 deletions
diff --git a/libs/ezsat/testbench.cc b/libs/ezsat/testbench.cc
new file mode 100644
index 00000000..cc0fe573
--- /dev/null
+++ b/libs/ezsat/testbench.cc
@@ -0,0 +1,524 @@
+/*
+ * ezSAT -- A simple and easy to use CNF generator for SAT solvers
+ *
+ * Copyright (C) 2013 Clifford Wolf <clifford@clifford.at>
+ *
+ * Permission to use, copy, modify, and/or distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ *
+ */
+
+#include "ezminisat.h"
+#include <stdio.h>
+
+struct xorshift128 {
+ uint32_t x, y, z, w;
+ xorshift128() {
+ x = 123456789;
+ y = 362436069;
+ z = 521288629;
+ w = 88675123;
+ }
+ uint32_t operator()() {
+ uint32_t t = x ^ (x << 11);
+ x = y; y = z; z = w;
+ w ^= (w >> 19) ^ t ^ (t >> 8);
+ return w;
+ }
+};
+
+bool test(ezSAT &sat, int assumption = 0)
+{
+ for (auto id : sat.assumed())
+ printf("%s\n", sat.to_string(id).c_str());
+ if (assumption)
+ printf("%s\n", sat.to_string(assumption).c_str());
+
+ std::vector<int> modelExpressions;
+ std::vector<bool> modelValues;
+
+ for (int id = 1; id <= sat.numLiterals(); id++)
+ if (sat.bound(id))
+ modelExpressions.push_back(id);
+
+ if (sat.solve(modelExpressions, modelValues, assumption)) {
+ printf("satisfiable:");
+ for (int i = 0; i < int(modelExpressions.size()); i++)
+ printf(" %s=%d", sat.to_string(modelExpressions[i]).c_str(), int(modelValues[i]));
+ printf("\n\n");
+ return true;
+ } else {
+ printf("not satisfiable.\n\n");
+ return false;
+ }
+}
+
+// ------------------------------------------------------------------------------------------------------------
+
+void test_simple()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+
+ ezSAT sat;
+ sat.assume(sat.OR("A", "B"));
+ sat.assume(sat.NOT(sat.AND("A", "B")));
+ test(sat);
+}
+
+// ------------------------------------------------------------------------------------------------------------
+
+void test_basic_operators(ezSAT &sat, xorshift128 &rng, int iter, bool buildTrees, bool buildClusters, std::vector<bool> &log)
+{
+ int vars[6] = {
+ sat.VAR("A"), sat.VAR("B"), sat.VAR("C"),
+ sat.NOT("A"), sat.NOT("B"), sat.NOT("C")
+ };
+ for (int i = 0; i < iter; i++) {
+ int assumption = 0, op = rng() % 6, to = rng() % 6;
+ int a = vars[rng() % 6], b = vars[rng() % 6], c = vars[rng() % 6];
+ // printf("--> %d %d:%s %d:%s %d:%s\n", op, a, sat.to_string(a).c_str(), b, sat.to_string(b).c_str(), c, sat.to_string(c).c_str());
+ switch (op)
+ {
+ case 0:
+ assumption = sat.NOT(a);
+ break;
+ case 1:
+ assumption = sat.AND(a, b);
+ break;
+ case 2:
+ assumption = sat.OR(a, b);
+ break;
+ case 3:
+ assumption = sat.XOR(a, b);
+ break;
+ case 4:
+ assumption = sat.IFF(a, b);
+ break;
+ case 5:
+ assumption = sat.ITE(a, b, c);
+ break;
+ }
+ // printf(" --> %d:%s\n", to, sat.to_string(assumption).c_str());
+ if (buildTrees)
+ vars[to] = assumption;
+ if (!buildClusters)
+ sat.clear();
+ sat.assume(assumption);
+ if (sat.numCnfVariables() < 15) {
+ printf("%d:\n", int(log.size()));
+ log.push_back(test(sat));
+ } else {
+ // printf("** skipping large problem **\n");
+ }
+ }
+}
+
+void test_basic_operators(ezSAT &sat, std::vector<bool> &log)
+{
+ printf("-- %s --\n\n", __PRETTY_FUNCTION__);
+
+ xorshift128 rng;
+ test_basic_operators(sat, rng, 1000, false, false, log);
+ for (int i = 0; i < 100; i++)
+ test_basic_operators(sat, rng, 10, true, false, log);
+ for (int i = 0; i < 100; i++)
+ test_basic_operators(sat, rng, 10, false, true, log);
+}
+
+void test_basic_operators()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+
+ ezSAT sat;
+ ezMiniSAT miniSat;
+ std::vector<bool> logSat, logMiniSat;
+
+ test_basic_operators(sat, logSat);
+ test_basic_operators(miniSat, logMiniSat);
+
+ if (logSat != logMiniSat) {
+ printf("Differences between logSat and logMiniSat:");
+ for (int i = 0; i < int(std::max(logSat.size(), logMiniSat.size())); i++)
+ if (i >= int(logSat.size()) || i >= int(logMiniSat.size()) || logSat[i] != logMiniSat[i])
+ printf(" %d", i);
+ printf("\n");
+ abort();
+ } else {
+ printf("Completed %d tests with identical results with ezSAT and ezMiniSAT.\n\n", int(logSat.size()));
+ }
+}
+
+// ------------------------------------------------------------------------------------------------------------
+
+void test_xorshift32_try(ezSAT &sat, uint32_t input_pattern)
+{
+ uint32_t output_pattern = input_pattern;
+ output_pattern ^= output_pattern << 13;
+ output_pattern ^= output_pattern >> 17;
+ output_pattern ^= output_pattern << 5;
+
+ std::vector<int> modelExpressions;
+ std::vector<int> forwardAssumptions, backwardAssumptions;
+ std::vector<bool> forwardModel, backwardModel;
+
+ sat.vec_append(modelExpressions, sat.vec_var("i", 32));
+ sat.vec_append(modelExpressions, sat.vec_var("o", 32));
+
+ sat.vec_append_unsigned(forwardAssumptions, sat.vec_var("i", 32), input_pattern);
+ sat.vec_append_unsigned(backwardAssumptions, sat.vec_var("o", 32), output_pattern);
+
+ if (!sat.solve(modelExpressions, backwardModel, backwardAssumptions)) {
+ printf("backward solving failed!\n");
+ abort();
+ }
+
+ if (!sat.solve(modelExpressions, forwardModel, forwardAssumptions)) {
+ printf("forward solving failed!\n");
+ abort();
+ }
+
+ printf("xorshift32 test with input pattern 0x%08x:\n", input_pattern);
+
+ printf("forward solution: input=0x%08x output=0x%08x\n",
+ (unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, sat.vec_var("i", 32)),
+ (unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, sat.vec_var("o", 32)));
+
+ printf("backward solution: input=0x%08x output=0x%08x\n",
+ (unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, sat.vec_var("i", 32)),
+ (unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, sat.vec_var("o", 32)));
+
+ if (forwardModel != backwardModel) {
+ printf("forward and backward results are inconsistend!\n");
+ abort();
+ }
+
+ printf("passed.\n\n");
+}
+
+void test_xorshift32()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+
+ ezMiniSAT sat;
+ xorshift128 rng;
+
+ std::vector<int> bits = sat.vec_var("i", 32);
+
+ bits = sat.vec_xor(bits, sat.vec_shl(bits, 13));
+ bits = sat.vec_xor(bits, sat.vec_shr(bits, 17));
+ bits = sat.vec_xor(bits, sat.vec_shl(bits, 5));
+
+ sat.vec_set(bits, sat.vec_var("o", 32));
+
+ test_xorshift32_try(sat, 0);
+ test_xorshift32_try(sat, 314159265);
+ test_xorshift32_try(sat, rng());
+ test_xorshift32_try(sat, rng());
+ test_xorshift32_try(sat, rng());
+ test_xorshift32_try(sat, rng());
+}
+
+// ------------------------------------------------------------------------------------------------------------
+
+#define CHECK(_expr1, _expr2) check(#_expr1, _expr1, #_expr2, _expr2)
+
+void check(const char *expr1_str, bool expr1, const char *expr2_str, bool expr2)
+{
+ if (expr1 == expr2) {
+ printf("[ %s ] == [ %s ] .. ok (%s == %s)\n", expr1_str, expr2_str, expr1 ? "true" : "false", expr2 ? "true" : "false");
+ } else {
+ printf("[ %s ] != [ %s ] .. ERROR (%s != %s)\n", expr1_str, expr2_str, expr1 ? "true" : "false", expr2 ? "true" : "false");
+ abort();
+ }
+}
+
+void test_signed(int8_t a, int8_t b, int8_t c)
+{
+ ezSAT sat;
+
+ std::vector<int> av = sat.vec_const_signed(a, 8);
+ std::vector<int> bv = sat.vec_const_signed(b, 8);
+ std::vector<int> cv = sat.vec_const_signed(c, 8);
+
+ printf("Testing signed arithmetic using: a=%+d, b=%+d, c=%+d\n", int(a), int(b), int(c));
+
+ CHECK(a < b+c, sat.solve(sat.vec_lt_signed(av, sat.vec_add(bv, cv))));
+ CHECK(a <= b-c, sat.solve(sat.vec_le_signed(av, sat.vec_sub(bv, cv))));
+
+ CHECK(a > b+c, sat.solve(sat.vec_gt_signed(av, sat.vec_add(bv, cv))));
+ CHECK(a >= b-c, sat.solve(sat.vec_ge_signed(av, sat.vec_sub(bv, cv))));
+
+ printf("\n");
+}
+
+void test_unsigned(uint8_t a, uint8_t b, uint8_t c)
+{
+ ezSAT sat;
+
+ if (b < c)
+ b ^= c, c ^= b, b ^= c;
+
+ std::vector<int> av = sat.vec_const_unsigned(a, 8);
+ std::vector<int> bv = sat.vec_const_unsigned(b, 8);
+ std::vector<int> cv = sat.vec_const_unsigned(c, 8);
+
+ printf("Testing unsigned arithmetic using: a=%d, b=%d, c=%d\n", int(a), int(b), int(c));
+
+ CHECK(a < b+c, sat.solve(sat.vec_lt_unsigned(av, sat.vec_add(bv, cv))));
+ CHECK(a <= b-c, sat.solve(sat.vec_le_unsigned(av, sat.vec_sub(bv, cv))));
+
+ CHECK(a > b+c, sat.solve(sat.vec_gt_unsigned(av, sat.vec_add(bv, cv))));
+ CHECK(a >= b-c, sat.solve(sat.vec_ge_unsigned(av, sat.vec_sub(bv, cv))));
+
+ printf("\n");
+}
+
+void test_count(uint32_t x)
+{
+ ezSAT sat;
+
+ int count = 0;
+ for (int i = 0; i < 32; i++)
+ if (((x >> i) & 1) != 0)
+ count++;
+
+ printf("Testing bit counting using x=0x%08x (%d set bits) .. ", x, count);
+
+ std::vector<int> v = sat.vec_const_unsigned(x, 32);
+
+ std::vector<int> cv6 = sat.vec_const_unsigned(count, 6);
+ std::vector<int> cv4 = sat.vec_const_unsigned(count <= 15 ? count : 15, 4);
+
+ if (cv6 != sat.vec_count(v, 6, false)) {
+ fprintf(stderr, "FAILED 6bit-no-clipping test!\n");
+ abort();
+ }
+
+ if (cv4 != sat.vec_count(v, 4, true)) {
+ fprintf(stderr, "FAILED 4bit-clipping test!\n");
+ abort();
+ }
+
+ printf("ok.\n");
+}
+
+void test_arith()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+
+ xorshift128 rng;
+
+ for (int i = 0; i < 100; i++)
+ test_signed(rng() % 19 - 10, rng() % 19 - 10, rng() % 19 - 10);
+
+ for (int i = 0; i < 100; i++)
+ test_unsigned(rng() % 10, rng() % 10, rng() % 10);
+
+ test_count(0x00000000);
+ test_count(0xffffffff);
+ for (int i = 0; i < 30; i++)
+ test_count(rng());
+
+ printf("\n");
+}
+
+// ------------------------------------------------------------------------------------------------------------
+
+void test_onehot()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+ ezMiniSAT ez;
+
+ int a = ez.literal("a");
+ int b = ez.literal("b");
+ int c = ez.literal("c");
+ int d = ez.literal("d");
+
+ std::vector<int> abcd;
+ abcd.push_back(a);
+ abcd.push_back(b);
+ abcd.push_back(c);
+ abcd.push_back(d);
+
+ ez.assume(ez.onehot(abcd));
+
+ int solution_counter = 0;
+ while (1)
+ {
+ std::vector<bool> modelValues;
+ bool ok = ez.solve(abcd, modelValues);
+
+ if (!ok)
+ break;
+
+ printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]));
+
+ int count_hot = 0;
+ std::vector<int> sol;
+ for (int i = 0; i < 4; i++) {
+ if (modelValues[i])
+ count_hot++;
+ sol.push_back(modelValues[i] ? abcd[i] : ez.NOT(abcd[i]));
+ }
+ ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));
+
+ if (count_hot != 1) {
+ fprintf(stderr, "Wrong number of hot bits!\n");
+ abort();
+ }
+
+ solution_counter++;
+ }
+
+ if (solution_counter != 4) {
+ fprintf(stderr, "Wrong number of one-hot solutions!\n");
+ abort();
+ }
+
+ printf("\n");
+}
+
+void test_manyhot()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+ ezMiniSAT ez;
+
+ int a = ez.literal("a");
+ int b = ez.literal("b");
+ int c = ez.literal("c");
+ int d = ez.literal("d");
+
+ std::vector<int> abcd;
+ abcd.push_back(a);
+ abcd.push_back(b);
+ abcd.push_back(c);
+ abcd.push_back(d);
+
+ ez.assume(ez.manyhot(abcd, 1, 2));
+
+ int solution_counter = 0;
+ while (1)
+ {
+ std::vector<bool> modelValues;
+ bool ok = ez.solve(abcd, modelValues);
+
+ if (!ok)
+ break;
+
+ printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]));
+
+ int count_hot = 0;
+ std::vector<int> sol;
+ for (int i = 0; i < 4; i++) {
+ if (modelValues[i])
+ count_hot++;
+ sol.push_back(modelValues[i] ? abcd[i] : ez.NOT(abcd[i]));
+ }
+ ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));
+
+ if (count_hot != 1 && count_hot != 2) {
+ fprintf(stderr, "Wrong number of hot bits!\n");
+ abort();
+ }
+
+ solution_counter++;
+ }
+
+ if (solution_counter != 4 + 4*3/2) {
+ fprintf(stderr, "Wrong number of one-hot solutions!\n");
+ abort();
+ }
+
+ printf("\n");
+}
+
+void test_ordered()
+{
+ printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
+ ezMiniSAT ez;
+
+ int a = ez.literal("a");
+ int b = ez.literal("b");
+ int c = ez.literal("c");
+
+ int x = ez.literal("x");
+ int y = ez.literal("y");
+ int z = ez.literal("z");
+
+ std::vector<int> abc;
+ abc.push_back(a);
+ abc.push_back(b);
+ abc.push_back(c);
+
+ std::vector<int> xyz;
+ xyz.push_back(x);
+ xyz.push_back(y);
+ xyz.push_back(z);
+
+ ez.assume(ez.ordered(abc, xyz));
+
+ int solution_counter = 0;
+
+ while (1)
+ {
+ std::vector<int> modelVariables;
+ std::vector<bool> modelValues;
+
+ modelVariables.push_back(a);
+ modelVariables.push_back(b);
+ modelVariables.push_back(c);
+
+ modelVariables.push_back(x);
+ modelVariables.push_back(y);
+ modelVariables.push_back(z);
+
+ bool ok = ez.solve(modelVariables, modelValues);
+
+ if (!ok)
+ break;
+
+ printf("Solution: %d %d %d | %d %d %d\n",
+ int(modelValues[0]), int(modelValues[1]), int(modelValues[2]),
+ int(modelValues[3]), int(modelValues[4]), int(modelValues[5]));
+
+ std::vector<int> sol;
+ for (size_t i = 0; i < modelVariables.size(); i++)
+ sol.push_back(modelValues[i] ? modelVariables[i] : ez.NOT(modelVariables[i]));
+ ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol)));
+
+ solution_counter++;
+ }
+
+ if (solution_counter != 8+7+6+5+4+3+2+1) {
+ fprintf(stderr, "Wrong number of solutions!\n");
+ abort();
+ }
+
+ printf("\n");
+}
+
+// ------------------------------------------------------------------------------------------------------------
+
+
+int main()
+{
+ test_simple();
+ test_basic_operators();
+ test_xorshift32();
+ test_arith();
+ test_onehot();
+ test_manyhot();
+ test_ordered();
+ printf("Passed all tests.\n\n");
+ return 0;
+}
+