/*
 *  ezSAT -- A simple and easy to use CNF generator for SAT solvers
 *
 *  Copyright (C) 2013  Clifford Wolf
 *
 *  Permission to use, copy, modify, and/or distribute this software for any
 *  purpose with or without fee is hereby granted, provided that the above
 *  copyright notice and this permission notice appear in all copies.
 *
 *  THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 *  WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 *  MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 *  ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 *  WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 *  ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 *  OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#include "ezminisat.h"

// Include what this testbench uses directly instead of relying on
// ezminisat.h to pull these in transitively.
#include <stdio.h>   // printf, fprintf, stdout, stderr
#include <stdlib.h>  // abort
#include <stdint.h>  // uint32_t, int8_t, uint8_t

// Marsaglia's xorshift128 generator with its canonical fixed seed. Every run
// of the testbench therefore draws the same sequence of "random" test vectors.
struct xorshift128 {
	uint32_t x, y, z, w;

	xorshift128() : x(123456789), y(362436069), z(521288629), w(88675123) { }

	// Advance the state and return the next 32-bit value.
	uint32_t operator()() {
		uint32_t t = x ^ (x << 11);
		x = y;
		y = z;
		z = w;
		w ^= (w >> 19) ^ t ^ (t >> 8);
		return w;
	}
};

// Solve `sat` and print the result.
//
// For every literal id in 1..numLiterals() for which sat.bound(id) holds, the
// model value is printed as "name=value". `assumption` is forwarded to
// sat.solve(); presumably 0 means "no extra assumption" (TODO: confirm
// against ezSAT::solve).
//
// Returns true if the problem was satisfiable, false otherwise.
bool test(ezSAT &sat, int assumption = 0)
{
	std::vector<int> modelExpressions;
	std::vector<bool> modelValues;

	for (int id = 1; id <= sat.numLiterals(); id++)
		if (sat.bound(id))
			modelExpressions.push_back(id);

	if (sat.solve(modelExpressions, modelValues, assumption)) {
		printf("satisfiable:");
		for (int i = 0; i < int(modelExpressions.size()); i++)
			printf(" %s=%d", sat.to_string(modelExpressions[i]).c_str(), int(modelValues[i]));
		printf("\n\n");
		return true;
	} else {
		printf("not satisfiable.\n\n");
		return false;
	}
}

// ------------------------------------------------------------------------------------------------------------

// Smoke test: "A xor B" expressed as (A | B) & !(A & B); prints one model.
void test_simple()
{
	printf("==== %s ====\n\n", __PRETTY_FUNCTION__);

	ezMiniSAT sat;
	sat.non_incremental();
	sat.assume(sat.OR("A", "B"));
	sat.assume(sat.NOT(sat.AND("A", "B")));
	test(sat);
}

------------------------------------------------------------------------------------------------------------ void test_xorshift32_try(ezSAT &sat, uint32_t input_pattern) { uint32_t output_pattern = input_pattern; output_pattern ^= output_pattern << 13; output_pattern ^= output_pattern >> 17; output_pattern ^= output_pattern << 5; std::vector modelExpressions; std::vector forwardAssumptions, backwardAssumptions; std::vector forwardModel, backwardModel; sat.vec_append(modelExpressions, sat.vec_var("i", 32)); sat.vec_append(modelExpressions, sat.vec_var("o", 32)); sat.vec_append_unsigned(forwardAssumptions, sat.vec_var("i", 32), input_pattern); sat.vec_append_unsigned(backwardAssumptions, sat.vec_var("o", 32), output_pattern); if (!sat.solve(modelExpressions, backwardModel, backwardAssumptions)) { printf("backward solving failed!\n"); abort(); } if (!sat.solve(modelExpressions, forwardModel, forwardAssumptions)) { printf("forward solving failed!\n"); abort(); } printf("xorshift32 test with input pattern 0x%08x:\n", input_pattern); printf("forward solution: input=0x%08x output=0x%08x\n", (unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, sat.vec_var("i", 32)), (unsigned int)sat.vec_model_get_unsigned(modelExpressions, forwardModel, sat.vec_var("o", 32))); printf("backward solution: input=0x%08x output=0x%08x\n", (unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, sat.vec_var("i", 32)), (unsigned int)sat.vec_model_get_unsigned(modelExpressions, backwardModel, sat.vec_var("o", 32))); if (forwardModel != backwardModel) { printf("forward and backward results are inconsistend!\n"); abort(); } printf("passed.\n\n"); } void test_xorshift32() { printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT sat; sat.keep_cnf(); xorshift128 rng; std::vector bits = sat.vec_var("i", 32); bits = sat.vec_xor(bits, sat.vec_shl(bits, 13)); bits = sat.vec_xor(bits, sat.vec_shr(bits, 17)); bits = sat.vec_xor(bits, sat.vec_shl(bits, 5)); 
sat.vec_set(bits, sat.vec_var("o", 32)); test_xorshift32_try(sat, 0); test_xorshift32_try(sat, 314159265); test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); sat.printDIMACS(stdout, true); printf("\n"); } // ------------------------------------------------------------------------------------------------------------ #define CHECK(_expr1, _expr2) check(#_expr1, _expr1, #_expr2, _expr2) void check(const char *expr1_str, bool expr1, const char *expr2_str, bool expr2) { if (expr1 == expr2) { printf("[ %s ] == [ %s ] .. ok (%s == %s)\n", expr1_str, expr2_str, expr1 ? "true" : "false", expr2 ? "true" : "false"); } else { printf("[ %s ] != [ %s ] .. ERROR (%s != %s)\n", expr1_str, expr2_str, expr1 ? "true" : "false", expr2 ? "true" : "false"); abort(); } } void test_signed(int8_t a, int8_t b, int8_t c) { ezMiniSAT sat; std::vector av = sat.vec_const_signed(a, 8); std::vector bv = sat.vec_const_signed(b, 8); std::vector cv = sat.vec_const_signed(c, 8); printf("Testing signed arithmetic using: a=%+d, b=%+d, c=%+d\n", int(a), int(b), int(c)); CHECK(a < b+c, sat.solve(sat.vec_lt_signed(av, sat.vec_add(bv, cv)))); CHECK(a <= b-c, sat.solve(sat.vec_le_signed(av, sat.vec_sub(bv, cv)))); CHECK(a > b+c, sat.solve(sat.vec_gt_signed(av, sat.vec_add(bv, cv)))); CHECK(a >= b-c, sat.solve(sat.vec_ge_signed(av, sat.vec_sub(bv, cv)))); printf("\n"); } void test_unsigned(uint8_t a, uint8_t b, uint8_t c) { ezMiniSAT sat; if (b < c) b ^= c, c ^= b, b ^= c; std::vector av = sat.vec_const_unsigned(a, 8); std::vector bv = sat.vec_const_unsigned(b, 8); std::vector cv = sat.vec_const_unsigned(c, 8); printf("Testing unsigned arithmetic using: a=%d, b=%d, c=%d\n", int(a), int(b), int(c)); CHECK(a < b+c, sat.solve(sat.vec_lt_unsigned(av, sat.vec_add(bv, cv)))); CHECK(a <= b-c, sat.solve(sat.vec_le_unsigned(av, sat.vec_sub(bv, cv)))); CHECK(a > b+c, sat.solve(sat.vec_gt_unsigned(av, sat.vec_add(bv, cv)))); CHECK(a >= 
b-c, sat.solve(sat.vec_ge_unsigned(av, sat.vec_sub(bv, cv)))); printf("\n"); } void test_count(uint32_t x) { ezMiniSAT sat; int count = 0; for (int i = 0; i < 32; i++) if (((x >> i) & 1) != 0) count++; printf("Testing bit counting using x=0x%08x (%d set bits) .. ", x, count); std::vector v = sat.vec_const_unsigned(x, 32); std::vector cv6 = sat.vec_const_unsigned(count, 6); std::vector cv4 = sat.vec_const_unsigned(count <= 15 ? count : 15, 4); if (cv6 != sat.vec_count(v, 6, false)) { fprintf(stderr, "FAILED 6bit-no-clipping test!\n"); abort(); } if (cv4 != sat.vec_count(v, 4, true)) { fprintf(stderr, "FAILED 4bit-clipping test!\n"); abort(); } printf("ok.\n"); } void test_arith() { printf("==== %s ====\n\n", __PRETTY_FUNCTION__); xorshift128 rng; for (int i = 0; i < 100; i++) test_signed(rng() % 19 - 10, rng() % 19 - 10, rng() % 19 - 10); for (int i = 0; i < 100; i++) test_unsigned(rng() % 10, rng() % 10, rng() % 10); test_count(0x00000000); test_count(0xffffffff); for (int i = 0; i < 30; i++) test_count(rng()); printf("\n"); } // ------------------------------------------------------------------------------------------------------------ void test_onehot() { printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT ez; int a = ez.frozen_literal("a"); int b = ez.frozen_literal("b"); int c = ez.frozen_literal("c"); int d = ez.frozen_literal("d"); std::vector abcd; abcd.push_back(a); abcd.push_back(b); abcd.push_back(c); abcd.push_back(d); ez.assume(ez.onehot(abcd)); int solution_counter = 0; while (1) { std::vector modelValues; bool ok = ez.solve(abcd, modelValues); if (!ok) break; printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3])); int count_hot = 0; std::vector sol; for (int i = 0; i < 4; i++) { if (modelValues[i]) count_hot++; sol.push_back(modelValues[i] ? 
abcd[i] : ez.NOT(abcd[i])); } ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol))); if (count_hot != 1) { fprintf(stderr, "Wrong number of hot bits!\n"); abort(); } solution_counter++; } if (solution_counter != 4) { fprintf(stderr, "Wrong number of one-hot solutions!\n"); abort(); } printf("\n"); } void test_manyhot() { printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT ez; int a = ez.frozen_literal("a"); int b = ez.frozen_literal("b"); int c = ez.frozen_literal("c"); int d = ez.frozen_literal("d"); std::vector abcd; abcd.push_back(a); abcd.push_back(b); abcd.push_back(c); abcd.push_back(d); ez.assume(ez.manyhot(abcd, 1, 2)); int solution_counter = 0; while (1) { std::vector modelValues; bool ok = ez.solve(abcd, modelValues); if (!ok) break; printf("Solution: %d %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3])); int count_hot = 0; std::vector sol; for (int i = 0; i < 4; i++) { if (modelValues[i]) count_hot++; sol.push_back(modelValues[i] ? 
abcd[i] : ez.NOT(abcd[i])); } ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol))); if (count_hot != 1 && count_hot != 2) { fprintf(stderr, "Wrong number of hot bits!\n"); abort(); } solution_counter++; } if (solution_counter != 4 + 4*3/2) { fprintf(stderr, "Wrong number of one-hot solutions!\n"); abort(); } printf("\n"); } void test_ordered() { printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT ez; int a = ez.frozen_literal("a"); int b = ez.frozen_literal("b"); int c = ez.frozen_literal("c"); int x = ez.frozen_literal("x"); int y = ez.frozen_literal("y"); int z = ez.frozen_literal("z"); std::vector abc; abc.push_back(a); abc.push_back(b); abc.push_back(c); std::vector xyz; xyz.push_back(x); xyz.push_back(y); xyz.push_back(z); ez.assume(ez.ordered(abc, xyz)); int solution_counter = 0; while (1) { std::vector modelVariables; std::vector modelValues; modelVariables.push_back(a); modelVariables.push_back(b); modelVariables.push_back(c); modelVariables.push_back(x); modelVariables.push_back(y); modelVariables.push_back(z); bool ok = ez.solve(modelVariables, modelValues); if (!ok) break; printf("Solution: %d %d %d | %d %d %d\n", int(modelValues[0]), int(modelValues[1]), int(modelValues[2]), int(modelValues[3]), int(modelValues[4]), int(modelValues[5])); std::vector sol; for (size_t i = 0; i < modelVariables.size(); i++) sol.push_back(modelValues[i] ? modelVariables[i] : ez.NOT(modelVariables[i])); ez.assume(ez.NOT(ez.expression(ezSAT::OpAnd, sol))); solution_counter++; } if (solution_counter != 8+7+6+5+4+3+2+1) { fprintf(stderr, "Wrong number of solutions!\n"); abort(); } printf("\n"); } // ------------------------------------------------------------------------------------------------------------ int main() { test_simple(); test_xorshift32(); test_arith(); test_onehot(); test_manyhot(); test_ordered(); printf("Passed all tests.\n\n"); return 0; }