summaryrefslogtreecommitdiff
path: root/books/workshops/2015/peng-greenstreet/support/z3_interface/RewriteExpt.py
diff options
context:
space:
mode:
Diffstat (limited to 'books/workshops/2015/peng-greenstreet/support/z3_interface/RewriteExpt.py')
-rw-r--r--books/workshops/2015/peng-greenstreet/support/z3_interface/RewriteExpt.py317
1 files changed, 317 insertions, 0 deletions
diff --git a/books/workshops/2015/peng-greenstreet/support/z3_interface/RewriteExpt.py b/books/workshops/2015/peng-greenstreet/support/z3_interface/RewriteExpt.py
new file mode 100644
index 0000000..0f478ac
--- /dev/null
+++ b/books/workshops/2015/peng-greenstreet/support/z3_interface/RewriteExpt.py
@@ -0,0 +1,317 @@
+# Copyright (C) 2015, University of British Columbia
+# Written (originally) by Mark Greenstreet (13th March, 2014)
+#
+# License: A 3-clause BSD license.
+# See the LICENSE file distributed with this software
+
+
+import collections
+import ACL2_to_Z3
+import z3
+
+def prod(stuff):
+ """ prod(stuff):
+ compute the product (i.e. reduce with '*') of the elements of 'stuff'.
+ 'stuff' must be iterable."""
+ return reduce(lambda x, y: x*y, stuff)
+
+def longVal(x):
+ """ longVal(x):
+ if 'x' is a z3 constant (i.e. function of arity 0) whose value is an integer,
+ then return that integer as an python long
+ else return 'None'"""
+ if(hasattr(x, 'as_long')): return x.as_long()
+ elif(hasattr(x, 'numerator_as_long')):
+ if(x.denominator_as_long() == 1L): return x.numerator_as_long()
+ return None
+# end longVal
+
+class to_smt_w_expt(ACL2_to_Z3.ACL22SMT):
+ class ExptRewriteFailure(Exception): pass
+
+ def __init__(self, *args):
+ super(to_smt_w_expt, self).__init__(*args)
+ # I'm making the exponent have sort Real instead of Int because
+ # the translator turns integerp to isReal! That's because the z3
+ # solver (understandably) chokes on mixed integer/real polynomials.
+ self.expt = z3.Function('expt', z3.RealSort(), z3.RealSort(), z3.RealSort())
+ self.b_sum = z3.Function('b_sum', z3.RealSort(), z3.RealSort(), z3.RealSort(), z3.RealSort(), z3.RealSort(), z3.RealSort(), z3.RealSort())
+ self.b_expt = z3.Function('b_expt', z3.RealSort(), z3.RealSort(), z3.RealSort())
+ self.maxPowExpand = 10
+
+ def simplify(self, expr, **kwargs):
+ if(z3.is_expr(expr)): return z3.simplify(expr, **kwargs)
+ else: # assume that expr has already been 'simplified' to a constant.
+ return expr
+
+ def reportFun(self, report=None):
+ def print_msg(*args):
+ print ''.join([str(a) for a in args])
+ return None
+ def dont_print_msg(*args):
+ return None
+ if((report is None) or (report is False)): return dont_print_msg
+ elif(report is True): return print_msg
+ else: return report
+
+ def get_expt_rules(self, expr_list, report=None):
+ if(len(expr_list) == 0): return []
+ else: hyps = expr_list[0]
+ workQ = collections.deque() # expt calls we still need to examine
+ allQ = collections.deque() # all expt calls that we've seen
+ report = self.reportFun(report)
+
+ def enqueue(v):
+ # z3 ASTs are unhashable; so we'll use a brute-force
+ # list for now -- beware of the quadratic time to build the
+ # allQ and workQ lists if we ever work on big examples.
+ report('enque(', v, ')')
+ for w in allQ:
+ if(v.eq(w)): # have we already seen v ?
+ report(' already seen, no work to do')
+ return
+ report(' appending ', v, ' to allQ and workQ')
+ allQ.append(v)
+ workQ.append(v)
+
+ def xpt(x, n):
+ v = self.expt(x, n)
+ enqueue(v)
+ return v
+
+ def lookfor_expt(v):
+ if(v is None): return
+ elif(hasattr(v, "decl") and hasattr(v, "children")):
+ # hopefully, v is a z3 expression
+ if(v.decl().eq(self.expt)):
+ x = v.children()[0]
+ n = v.children()[1]
+ enqueue(self.expt(x, self.simplify(n, som=True)))
+ for nu in v.children(): lookfor_expt(nu)
+
+ def expt_rules():
+ rules = collections.deque()
+ solver = z3.Solver()
+ solver.set('arith.nl', False)
+ solver.add(hyps)
+
+ def show(p):
+ report('trying to show(', p, '):')
+ report(' hypotheses = ', solver)
+ solver.push()
+ solver.add(z3.Not(p))
+ outcome = solver.check()
+ s1 = ' the negation is ' + str(outcome)
+ if(outcome == z3.unsat):
+ report(s1, "; therefore the original claim is valid")
+ elif(outcome == z3.sat):
+ report(s1, "\n here's a counter-example to ", p, "\n ", solver.model())
+ elif(outcome == z3.unknown):
+ report(s1, "; therefore, the original claim is undecided")
+ else:
+ report(s1, "; how'd that happen?")
+ solver.pop()
+ return outcome == z3.unsat
+
+ def add_rule(p):
+ report('add_rule(', p, ')')
+ rules.append(p)
+ solver.add(p)
+
+ while(len(workQ) > 0):
+ v = workQ.pop()
+ x = v.children()[0]
+ n = v.children()[1]
+
+ report('rewriting expt(', x, ', ', n, ')')
+
+ # Many of the rules below should have guards to ensure that we don't
+ # accidentally say expt(x, n) is defined when x==0 and n < 0.
+ # Rather that figuring out # all of the corner cases, I first check to
+ # see if (x == 0) and (n < 0) is satisfiable. If so, this code just
+ # throws an exception. I could probably work out a better error message
+ # later.
+ # Now that we know that expt(x, n) is well-defined, we still need to be careful.
+ # Consider expt(x, n+m) where x==0, n==3, and m==(-2). In this case, expt(x, n+m)
+ # is well-defined, but we can't conclude:
+ # expt(x, n+m) == expt(x, n) * expt(x, m)
+ # Rather than working out lots of side conditions (and probably making a mistake),
+ # I just check to see if implies(hyps, x > 0), and then plunge ahead without fear.
+ # Of course, this means I don't generate all of the rules that I could, but I'll
+ # do that later if this simple version turns out to be useful.
+
+ def expt_rewrite_const(x2, n2):
+ if(n2 == 0): return z3.intVal(1)
+ elif((0 < n2) and (n2 <= self.maxPowExpand)):
+ add_rule(v == prod(map(lambda _: x2, range(n2))))
+ elif((-self.maxPowExpand <= n2) and (n2 < 0)):
+ add_rule(v*prod(map(lambda _: x2, range(-n2))) == 1)
+ if(not show(z3.Or(x != 0, n >= 0))):
+ raise ExptRewriteFailure('possible attempt to raise 0 to a negative power')
+
+ x_is_pos = show(x > 0)
+ x_is_nz = x_is_pos or show(x != 0)
+ x_is_z = (not x_is_nz) and show(x == 0)
+
+ n_is_pos = show(n > 0)
+ n_is_neg = (not n_is_pos) and show(n < 0)
+ n_is_z = (not n_is_pos) and (not n_is_neg) and show(n == 0)
+
+ if(n_is_z or x_is_z):
+ if(n_is_z): add_rule(v == 1)
+ elif(n_is_pos): add_rule(v == 0)
+ else: add_rule(v == z3.If(n == 0, 1, 0))
+ continue
+ elif(x_is_pos):
+ x_lt_1 = show(x < 1)
+ x_gt_1 = (not x_lt_1) and show(x > 1)
+ if((not x_lt_1) and (not x_gt_1) and show(x == 1)):
+ add_rule(v == 1)
+ continue
+ add_rule(v > 0)
+ else:
+ add_rule(z3.Implies(x > 0, v > 0))
+ if(x_is_nz): add_rule(z != 0)
+ else: add_rule(z3.Implies(z3.Or(x != 0, n==0), v != 0))
+
+ if((x.decl().name() == '*') and (len(x.children()) > 1)): # expt(x0*x1*..., n)
+ add_rule(v == prod(map(lambda y: xpt(y, n), x.children())))
+ elif((n.decl().name() == '+') and (len(n.children()) > 1)): # expt(x, n0+n1+...)
+ add_rule(v == prod(map(lambda m: xpt(x, m), n.children())))
+ elif(n.decl().name() == '-'):
+ nn = n.children()
+ if(len(nn) == 0): pass # a variable named '-'
+ elif(len(nn) == 1): # expt(x, -n)
+ add_rule(z3.Implies(x != 0, v*xpt(x, nn[0]) == 1))
+ elif(len(nn) == 2): # expt(x, n-m)
+ add_rule(z3.Implies(x != 0, v*xpt(x, nn[1]) == xpt(x, nn[0])))
+ else: RewriteExptFailure("unexpected: '-' expression with more than two children")
+ elif(n.decl().name() == '*'): # expt(x, n0*n1*...)
+ # check to see if n0 is integer constants and not "too big".
+ # if so, replace it with repeated multiplication
+ nn = n.children()
+ if((len(nn) > 0) and not (longVal(nn[0]) is None)):
+ if(len(nn) == 1): ex = x
+ else: ex = xpt(x, prod(nn[1:]))
+ expt_rewrite_const(ex, longVal(nn[0]))
+ elif(not (longVal(n) is None)):
+ expt_rewrite_const(x, longVal(n))
+ else: # we can't think of a way to simplify it
+ if(x_lt_1 or x_gt_1):
+ if(n_is_pos or n_is_neg): pass
+ else: add_rule(z3.Implies(n == 0, v == 1))
+ else:
+ if(n_is_pos or n_is_neg): add_rule(z3.Implies(x==1, v == 1))
+ else: add_rule(z3.Implies(z3.Or(x==1, n == 0), v == 1))
+ if(x_is_pos):
+ if(x_lt_1):
+ if(n_is_pos): add_rule(v <= x)
+ elif(n_is_neg): add_rule(v*x >= 1)
+ else: add_rule(z3.And(
+ z3.Implies(n > 0, v <= x),
+ z3.Implies(n < 0, v*x >= 1)))
+ elif(x_gt_1):
+ if(n_is_pos): add_rule(v >= x)
+ elif(n_is_neg): add_rule(v*x <= 1)
+ else: add_rule(z3.And(
+ z3.Implies(n > 0, v >= x),
+ z3.Implies(n < 0, v*x <= 1)))
+ else: add_rule(z3.And(
+ z3.Implies(z3.And(x < 1, n > 0), v <= x),
+ z3.Implies(z3.And(x < 1, n < 0), v*x >= 1),
+ z3.Implies(z3.And(x > 1, n > 0), v >= x),
+ z3.Implies(z3.And(x > 1, n < 0), v*x <= 1)))
+ return rules
+ # end expt_rules
+
+ for x in expr_list: lookfor_expt(x)
+ return expt_rules()
+
+ # using z3's If function is simpler, and probably more efficient
+ # than introducing a new variable as is done in ACL2_translator
+ def ifx(self, condx, thenx, elsex):
+ return z3.If(condx, thenx, elsex)
+
+ # The ACL2 code should access Q as a method of the to_smt object and not
+ # as a separate method. I'm creating the method here so this will work
+ # right when the ACL2 code is modified. OTOH, ACL2_translator will probably
+ # get updated as well, in which case this methods will be redundant
+ def Q(self, numerator, denominator): return z3.Q(numerator, denominator)
+
+ def analyse_expt(self, hypotheses, conclusion=None, report=None):
+ report = self.reportFun(report)
+ expt_hyps = self.get_expt_rules([hypotheses, conclusion], report)
+ if(len(expt_hyps) == 0):
+ hyps = hypotheses
+ concl = conclusion
+ elif(conclusion is None):
+ hyps = z3.And(*expt_hyps)
+ concl = hypotheses
+ else:
+ hyps = z3.And(hypotheses, *expt_hyps)
+ concl = conclusion
+ simple_hyps = self.simplify(hyps)
+ simple_concl = self.simplify(concl)
+ return simple_hyps, simple_concl
+
+ # is x uninterpreted function instance
+ def is_uninterpreted_fun(self, x):
+ d = x.decl()
+ return(
+ all([hasattr(d, a) for a in ('__call__', 'arity', 'domain', 'kind', 'range')]) and
+ (d.kind() == z3.Z3_OP_UNINTERPRETED) and
+ d.arity() > 0)
+
+ # I'll assume that all arguments are z3 expressions except for possibly the
+ # last one. If the last one is a function, then it's the 'report' function
+ # for debugging.
+ def fun_to_var(self, exprs, report=None):
+ report = self.reportFun(report)
+ report('fun_to_var(', exprs, ', ', report, ')')
+
+ funQ = collections.deque() # uninterpreted functions we've seen
+
+ def helper(x):
+ if(self.is_uninterpreted_fun(x)):
+ match = [f[1] for f in funQ if f[0] is x]
+ if(len(match) == 1): # found a match
+ return match[0]
+ else:
+ rangeSort = x.decl().range()
+ varName = '$' + str(x)
+ if(rangeSort == z3.RealSort()): newVar = z3.Real(varName)
+ elif(rangeSort == z3.IntSort()): newVar = z3.Int(varName)
+ elif(rangeSort == z3.BoolSort()): newVar = z3.Bool(varName)
+ else:
+ raise ExptRewriteFailure(
+ 'unknown sort for range of uninterpreted function -- ' +
+ varName + ' returns a ' + rangeSort + ' ?')
+ funQ.append((x, newVar))
+ return newVar
+ else:
+ ch = x.children()
+ newch = self.fun_to_var(ch, report)
+ if(len(ch) != len(newch)):
+ raise ExptRewriteFailure('Internal error')
+ elif(len(newch) == x.decl().arity()):
+ return x.decl().__call__(*newch)
+ elif((x.decl().arity() == 2) and (len(newch) > 2)):
+ return reduce(x.decl(), newch)
+ else:
+ raise ExptRewriteFailure('Internal error')
+
+ newExprs = [helper(x) for x in exprs]
+ report('fun_to_var(', exprs, ') -> ', newExprs)
+ return newExprs
+
+ def prove(self, hypotheses, conclusion=None, report=None):
+ report = self.reportFun(report)
+
+ x_hyps, x_concl = self.analyse_expt(hypotheses, conclusion, report)
+ f_hyps, f_concl = self.fun_to_var([x_hyps, x_concl], report)[:]
+ hyps = z3.simplify(f_hyps); concl = z3.simplify(f_concl)
+
+ report('to_smt_w_expt.prove:')
+ report(' hypotheses = ', hyps)
+ report(' conclusion = ', concl)
+ return super(to_smt_w_expt, self).prove(hyps, concl)