// Z3FOEncoding.scala
// Original author (per VCS history): Christian Müller
package de.tum.workflows.toz3

import java.util

import com.microsoft.z3.{BoolExpr, Context, Expr, FuncDecl, Model, Sort, Symbol}
import com.typesafe.scalalogging.LazyLogging
import de.tum.workflows.foltl.FOLTL._

/**
 * Encodes FOLTL first-order formulas as Z3 boolean expressions, maps Z3 expressions back
 * into formulas, and renders Z3 models as text.
 */
object Z3FOEncoding extends LazyLogging {

  val TIMEOUT = 60000 // in milliseconds

  /**
   * Translates `f` into a Z3 boolean expression built in `ctx`.
   *
   * The free variables of `f` are encoded as Z3 constants. The shared function-declaration
   * cache `fun_ctx` is cleared first, so declarations from an earlier translation are not reused.
   *
   * @param f   formula to encode
   * @param ctx Z3 context that owns all created expressions
   * @return the encoded formula
   */
  def translate(f: Formula, ctx: Context): BoolExpr = {
    //    logger.info(s"Using formula:\n$f")
    fun_ctx.clear()
    val constctx = buildVarCtx(ctx, Map(), List(), f.freeVars.toList)
    val expr = toBoolZ3(ctx, f, constctx)
    //    logger.info(s"Built Z3 expression:\n$expr")
    expr
  }

  /**
   * Encodes a boolean-valued formula under the variable context `var_ctx`.
   *
   * A relation (`Fun`) becomes an application of a Bool-ranged Z3 function. The function is
   * declared on first use and cached in `fun_ctx` under its encoded name. Every variable
   * referenced must have an entry in `var_ctx`; otherwise `Map.apply` throws. Formula shapes
   * without a case below raise a `MatchError`.
   *
   * @param ctx     Z3 context
   * @param f       formula to encode
   * @param var_ctx Z3 expression (bound variable or constant) for every variable in scope
   * @return the encoded boolean expression
   */
  def toBoolZ3(ctx: Context, f: Formula, var_ctx: VarCtx): BoolExpr = {
    f match {
      case fun: Fun => {
        val name = fun.encodeName()
        val params = fun.params

        // Reuse one declaration per encoded name, so every occurrence of the same relation
        // refers to the same Z3 function symbol.
        val fdecl = Option(fun_ctx.get(name)).getOrElse {
          val sorts = params.map(x => var_ctx(x)._2.getSort())
          val decl = ctx.mkFuncDecl(name, sorts.toArray, ctx.getBoolSort())
          fun_ctx.put(name, decl)
          decl
        }

        // all variables are quantified or free, so they are part of var_ctx
        val all_args = params.map(x => var_ctx(x)._2)
        fdecl.apply(all_args: _*).asInstanceOf[BoolExpr]
      }

      case Exists(vars, f1) => {
        val (sorts, names, body) = encodeQuantified(ctx, vars, f1, var_ctx)
        ctx.mkExists(sorts, names, body, 0, null, null, null, null)
      }

      case Forall(vars, f1) => {
        val (sorts, names, body) = encodeQuantified(ctx, vars, f1, var_ctx)
        ctx.mkForall(sorts, names, body, 0, null, null, null, null)
      }

      case And(f1, f2) => {
        ctx.mkAnd(toBoolZ3(ctx, f1, var_ctx), toBoolZ3(ctx, f2, var_ctx))
      }

      case Equiv(f1, f2) => {
        logger.warn("Sending equiv <-> to Z3")
        // Boolean equivalence is sent to Z3 as equality of the two boolean terms.
        ctx.mkEq(toBoolZ3(ctx, f1, var_ctx), toBoolZ3(ctx, f2, var_ctx))
      }

      case Equal(f1, f2) => {
        // Operands may be variables of any sort, so encode them via toSortedZ3.
        ctx.mkEq(toSortedZ3(ctx, f1, var_ctx), toSortedZ3(ctx, f2, var_ctx))
      }

      case Or(f1, f2) => {
        ctx.mkOr(toBoolZ3(ctx, f1, var_ctx), toBoolZ3(ctx, f2, var_ctx))
      }

      case Implies(f1, f2) => {
        ctx.mkImplies(toBoolZ3(ctx, f1, var_ctx), toBoolZ3(ctx, f2, var_ctx))
      }

      case Neg(f1) => {
        ctx.mkNot(toBoolZ3(ctx, f1, var_ctx))
      }

      case True => ctx.mkTrue()
      case False => ctx.mkFalse()
    }
  }

  /**
   * Encodes `body` inside a new quantifier block that binds `vars`.
   *
   * @return the bound-variable sorts, the bound-variable names (both in `vars` order), and the
   *         encoded body. Together these are the arguments Z3's quantifier constructors need.
   */
  private def encodeQuantified(ctx: Context, vars: List[Var], body: Formula,
                               var_ctx: VarCtx): (Array[Sort], Array[Symbol], BoolExpr) = {
    val names: Array[Symbol] = vars.map(v => ctx.mkSymbol(v.name)).toArray
    val newctx = buildVarCtx(ctx, var_ctx, vars, List())
    val sorts: Array[Sort] = vars.map(newctx(_)._2.getSort()).toArray
    (sorts, names, toBoolZ3(ctx, body, newctx))
  }

  /**
   * Encodes a term that may be non-boolean.
   *
   * A variable is looked up in `var_ctx`. Anything else is encoded as a boolean formula.
   */
  def toSortedZ3(ctx: Context, f: Formula, var_ctx: VarCtx): Expr = {
    f match {
      case v: Var => var_ctx(v)._2
      case _ => toBoolZ3(ctx, f, var_ctx)
    }
  }

  /**
   * Variable context: maps each variable in scope to a Z3 expression, plus an optional index.
   * The index is `Some(deBruijnIndex)` for quantifier-bound variables and `None` for
   * free constants.
   */
  type VarCtx = Map[Var, (Option[Int], Expr)]

  // Function declarations keyed by encoded relation name. This is a plain java.util.HashMap
  // shared by every caller of this object, and translate() clears it.
  // TODO: how to not make this static?
  val fun_ctx = new util.HashMap[String, FuncDecl]()

  /**
   * Extends `var_ctx` with newly bound variables `vars` and free constants `const`.
   *
   * Variables in `vars` get de Bruijn indices: the last one gets 0, the first gets
   * `vars.size - 1`. Bound entries already in `var_ctx` now sit one quantifier further out,
   * so their indices are shifted up by `vars.size`. Constants keep their expression.
   * Type name "Int" maps to Z3's Int sort and "Bool" to its Bool sort. Any other type name
   * becomes an uninterpreted sort of that name.
   *
   * When a variable appears in more than one source, the new binding shadows the outer entry.
   */
  def buildVarCtx(ctx: Context, var_ctx: VarCtx, vars: List[Var], const: List[Var]): VarCtx = {

    val z3sorts = (vars ++ const).map(_.typ).toSet.map { (s: String) =>
      val sort = s match {
        case "Int" => ctx.getIntSort()
        case "Bool" => ctx.getBoolSort()
        case _ => ctx.mkUninterpretedSort(s) // TODO: finite domain? sort size?
      }
      s -> sort
    }.toMap

    val bound: VarCtx = vars.zipWithIndex.map { case (v, pos) =>
      val i = vars.size - 1 - pos
      v -> (Some(i), ctx.mkBound(i, z3sorts(v.typ)))
    }.toMap

    // constants
    val consts: VarCtx = const.map { c =>
      c -> (None, ctx.mkConst(c.name, z3sorts(c.typ)))
    }.toMap

    // if the index is defined, increment, otherwise use the expr (which is a constant f.e.)
    val shifted: VarCtx = var_ctx.map {
      case (v, (Some(i), e)) =>
        val newi = i + vars.size
        v -> (Some(newi), ctx.mkBound(newi, e.getSort()))
      case entry => entry
    }

    // With Map.++ the right-hand side wins on key collisions, so the innermost binding
    // goes last and shadows outer bindings and free constants.
    shifted ++ consts ++ bound
  }

  /**
   * Maps a Z3 bound variable or constant back to a `Var` named after it and typed with
   * the name of its sort.
   *
   * @param bindings names of the enclosing quantifiers' variables, outermost first. A bound
   *                 variable with de Bruijn index i names the (size - 1 - i)-th entry.
   */
  def mapbackToVar(e:Expr)(implicit bindings:List[com.microsoft.z3.Symbol]): Var = {
    e match {
      case f2 if f2.isVar() => {
        Var(bindings(bindings.size - 1 - e.getIndex).toString, f2.getSort.getName.toString)
      }
      case f2 if f2.isConst() => {
        Var(f2.getFuncDecl().getName().toString, f2.getSort.getName.toString)
      }
    }
  }

  /**
   * Converts a Z3 expression back into a `Formula`.
   *
   * Boolean connectives, equalities, quantifiers, variables/constants and relation
   * applications are mapped back. Any other expression is logged as an error and replaced
   * by `Var("unknown")`. A function application whose name cannot be decoded by
   * `FunNameFromVar` is logged and raises a `NoSuchElementException`.
   *
   * @param bindings names of the enclosing quantifiers' variables, outermost first
   */
  def mapback(e: Expr)(implicit bindings:List[com.microsoft.z3.Symbol] = Nil): Formula = {
    e match {
      case f2 if f2.isAnd() => And.make(f2.getArgs.map(mapback).toList)
      case f2 if f2.isOr() => Or.make(f2.getArgs.map(mapback).toList)
      case f2 if f2.isImplies() => Implies.make(f2.getArgs.map(mapback).toList)
      case f2 if f2.isEq() => {
        // Equality over booleans is equivalence; over other sorts it is term equality.
        if (f2.getArgs.head.isBool) {
          Equiv.make(f2.getArgs.map(mapback).toList)
        } else {
          Equal.make(f2.getArgs.map(mapback).toList)
        }
      }
      case f2 if f2.isNot() => Neg(mapback(f2.getArgs().head))
      case f2 if f2.isTrue() => True
      case f2 if f2.isFalse() => False
      case f2 if f2.isQuantifier() => {
        val q = f2.asInstanceOf[com.microsoft.z3.Quantifier]
        val varnames = q.getBoundVariableNames
        val vartypes = q.getBoundVariableSorts

        val vars = varnames.zip(vartypes).map({
          case (name, typ) => Var(name.toString, typ.getName.toString)
        }).toList

        // This quantifier's names are appended innermost-last to resolve de Bruijn indices.
        val inner = mapback(q.getBody)(bindings ++ q.getBoundVariableNames.toList)
        if (q.isExistential) {
          Exists(vars, inner)
        } else {
          Forall(vars, inner)
        }
      }
      // There are no boolean constants, these are all function applications
      case f2 if f2.isVar || (f2.isConst && !f2.isBool) => {
        mapbackToVar(f2)
      }
      // Z3 constants are nullary function applications, so isApp also holds for them. The
      // non-boolean ones are already handled above; what remains are relation applications.
      case f2 if f2.isApp() => {
        val rawName = f2.getFuncDecl().getName.toString
        val name = FunNameFromVar.unapply(rawName).getOrElse {
          logger.error(s"Failed parsing function name ${f2.getFuncDecl().getName}")
          throw new NoSuchElementException(s"Cannot decode relation name '$rawName' from Z3 expression $f2")
        }
        val params = f2.getArgs().map(mapbackToVar).toList
        Fun(name._1, name._2, params)
      }

      case x => {
        logger.error(s"Error mapping expression back from Z3: Could not parse $x")
        Var("unknown")
      }
    }
  }

  /**
   * Renders `model` as human-readable text.
   *
   * Sections, in order:
   *  - all constants grouped by sort;
   *  - relation-encoding constants (names matched by `FunFromVar`), grouped by relation
   *    name and `ind`, then sorted by value and arguments;
   *  - the remaining constants with their values;
   *  - every function declaration, with its explicit entries and its else-value.
   */
  def printModel(model: Model): String = {

    val sb = new StringBuilder()

    val typedvals = model.getConstDecls().map(_.apply()).groupBy(_.getSort)

    sb ++= "Universe:\n"
    for ((k, v) <- typedvals) {
      sb ++= s"Type $k: ${v.mkString(",")}\n"
    }

    sb ++= "Relations:\n"
    val sortedConsts = model.getConstDecls().sortBy(_.getName.toString())

    val (l1, l2) = sortedConsts.partition(s => {
      s.getName.toString() match {
        case FunFromVar(_) => true
        case _ => false
      }
    })

    // Every name in l1 matched FunFromVar in the partition above, so .get cannot fail.
    val funs = l1.map { s =>
      (FunFromVar.unapply(s.getName.toString).get, model.getConstInterp(s).toString)
    }.toList

    val grouped = funs.groupBy(f => (f._1.name, f._1.ind))

    // sort by name, path
    val entries = grouped.toList.sortBy(_._1)
    for (((fname, find), members) <- entries) {
      sb ++= s"$fname($find):\n"

      // sort by value, variables
      val tuples = members.sortBy(e => (e._2, e._1.params.mkString(",")))
      for ((fun, v) <- tuples) {
        sb ++= s"${fun.params.mkString("(", ",", ")")} = $v\n"
      }
      sb ++= "\n"
    }

    sb ++= "\nNon-Relations:\n"

    // Rest of the consts
    for (f <- l2) {
      sb ++= s"${f.getName} = ${model.getConstInterp(f).toString()}\n"
    }

    val sortedFuns = model.getFuncDecls().sortBy(_.getName.toString())
    for (f <- sortedFuns) {
      sb ++= s"${f.getName}${f.getDomain.mkString("(", ", ", ")")}\n"

      val interp = model.getFuncInterp(f)
      for (e <- interp.getEntries) {
        val args = e.getArgs.toList.map(arg => s"${arg.getSort} $arg")
        sb ++= s"${f.getName}${args.mkString("(", ", ", ")")} = ${e.getValue()}\n"
      }

      val emptyargs = List.fill(f.getArity)("_")
      sb ++= s"${f.getName}${emptyargs.mkString("(", ", ", ")")} = ${interp.getElse()}\n\n"
    }
    sb.toString()
  }
}