Commit c7a2f347 authored by Axel Simon's avatar Axel Simon
Browse files

infer types by following SCCs

parent 24609920
......@@ -99,7 +99,113 @@ end = struct
List.map (fn e => AST.MARKexp {tree=e, span=s}) (decToSeqBitpat t)
| decToSeqBitpat (AST.NAMEDbitpat sym) = [AST.IDexp sym]
| decToSeqBitpat _ = []
(* define a traversal that returns a list of all top-level decls *)
fun topDecl (AST.MARKdecl {span, tree=t}) = topDecl t
| topDecl (AST.DECODEdecl dd) = topDecodeDecl dd
| topDecl (AST.LETRECdecl vd) = topLetrecDecl vd
| topDecl _ = []
and topDecodeDecl (v, _, _) = [(v, true)]
and topLetrecDecl (v, _, _) = [(v,false)]
fun toplevelDecls ast = List.concat
(List.map topDecl (ast : SpecAbstractTree.specification))
(*calculate gather all identifiers of a binding group in order to calculate SCCs*)
structure SCC = GraphSCCFn(ord_symid)
fun sccsSpecification ast =
let
val dom = SymSet.fromList (List.map #1 (toplevelDecls ast))
val sm = List.foldl (calleesDecl dom) SymMap.empty ast
fun getCallees s = SymSet.listItems (SymMap.lookup (sm,s))
in
SCC.topOrder' {roots = SymMap.listKeys sm,
follow = getCallees }
end
and calleesDecl dom (AST.MARKdecl {span, tree=t},sm) = calleesDecl dom (t,sm)
| calleesDecl dom (AST.DECODEdecl dd,sm) = calleesDecodeDecl (sm,dom) dd
| calleesDecl dom (AST.LETRECdecl vd,sm) = calleesLetrecDecl (sm,dom) vd
| calleesDecl dom (_,sm) = sm
and calleesDecodeDecl (sm,dom) (v, dps, Sum.INL e) =
let
val ids = List.foldl
(fn (dp, ids) => SymSet.union (calleesDecodepat dp, ids))
SymSet.empty dps
val ids = SymSet.union (calleesExp e, ids)
val ids = SymSet.intersection (ids, dom)
in
case SymMap.find (sm, v) of
NONE => SymMap.insert (sm, v, ids)
| SOME ids' => SymMap.insert (sm, v, SymSet.union (ids,ids'))
end
| calleesDecodeDecl (sm,dom) (v, dps, Sum.INR el) =
let
val ids = List.foldl
(fn (dp, ids) => SymSet.union (calleesDecodepat dp, ids))
SymSet.empty dps
val ids = List.foldl
(fn ((g,e), ids) => SymSet.union (
SymSet.union (calleesExp g, calleesExp e), ids))
ids el
val ids = SymSet.intersection (ids, dom)
in
case SymMap.find (sm, v) of
NONE => SymMap.insert (sm, v, ids)
| SOME ids' => SymMap.insert (sm, v, SymSet.union (ids,ids'))
end
and calleesLetrecDecl (sm,dom) (v, _, e) =
SymMap.insert (sm,v, SymSet.intersection (calleesExp e,dom))
and calleesExp (AST.MARKexp {span, tree=t}) = calleesExp t
| calleesExp (AST.LETRECexp (bs,e)) =
List.foldl
(fn ((_,_,body), ids) => SymSet.union (calleesExp body, ids))
(calleesExp e) bs
| calleesExp (AST.IFexp (e1,e2,e3)) = SymSet.union
(SymSet.union (calleesExp e1, calleesExp e2), calleesExp e3)
| calleesExp (AST.CASEexp (e,cs)) =
List.foldl
(fn ((_,e),ids) => SymSet.union (calleesExp e,ids))
(calleesExp e) cs
| calleesExp (AST.BINARYexp (e1,v,e2)) = SymSet.union
(calleesExp e1, SymSet.union (calleesExp e2,calleesInfixop v))
| calleesExp (AST.APPLYexp (e,es)) =
List.foldl (fn (e,ids) => SymSet.union (calleesExp e,ids))
(calleesExp e) es
| calleesExp (AST.RECORDexp fs) =
List.foldl (fn ((_,e),ids) => SymSet.union (calleesExp e, ids))
SymSet.empty fs
| calleesExp (AST.UPDATEexp fs) =
List.foldl (fn ((_,eOpt),ids) => case eOpt of
SOME e => SymSet.union (calleesExp e, ids)
| NONE => ids)
SymSet.empty fs
| calleesExp (AST.SEQexp ss) =
List.foldl (fn (s,ids) => SymSet.union (calleesSeqexp s, ids))
SymSet.empty ss
| calleesExp (AST.IDexp v) = SymSet.singleton v
| calleesExp (AST.FNexp (_,e)) = calleesExp e
| calleesExp _ = SymSet.empty
and calleesInfixop (AST.MARKinfixop {span, tree=t}) = calleesInfixop t
| calleesInfixop (AST.OPinfixop id) = SymSet.singleton id
and calleesSeqexp (AST.MARKseqexp {span, tree=t}) = calleesSeqexp t
| calleesSeqexp (AST.ACTIONseqexp e) = calleesExp e
| calleesSeqexp (AST.BINDseqexp (_,e)) = calleesExp e
and calleesDecodepat (AST.MARKdecodepat {span, tree=t}) = calleesDecodepat t
| calleesDecodepat (AST.TOKENdecodepat tp) = calleesTokpat tp
| calleesDecodepat (AST.BITdecodepat bps) =
List.foldl (fn (bp,ids) => SymSet.union (calleesBitpat bp, ids))
SymSet.empty bps
and calleesBitpat (AST.MARKbitpat {span, tree=t}) = calleesBitpat t
| calleesBitpat (AST.NAMEDbitpat v) = SymSet.singleton v
| calleesBitpat _ = SymSet.empty
and calleesTokpat (AST.MARKtokpat {span, tree=t}) = calleesTokpat t
| calleesTokpat (AST.NAMEDtokpat v) = SymSet.singleton v
| calleesTokpat _ = SymSet.empty
fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
val sm = ref ([] : symbol_types)
val { tsynDefs, typeDefs, conParents} = ti
......@@ -113,25 +219,19 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
val bindSymId = SymbolTable.lookup(!SymbolTables.varTable, Atom.atom ">>")
val bindASymId = SymbolTable.lookup(!SymbolTables.varTable, Atom.atom ">>=")
fun reportError conv (_, env) {span=s, tree=t} =
conv ({span=s},env) t
fun reportError conv ({span,component=comp}, env) {span=s, tree=t} =
conv ({span=s,component=comp},env) t
handle (S.UnificationFailure str) =>
(Error.errorAt (errStrm, s, [str]); raise TypeError)
val reportBadSizes = List.app (fn (s,str) => Error.errorAt (errStrm, s, [str]))
fun getSpan {span = s} = s
(* define a first traversal that creates a group of all top-level decls *)
fun topDecl (AST.MARKdecl {span, tree=t}) = topDecl t
| topDecl (AST.DECODEdecl dd) = topDecodeDecl dd
| topDecl (AST.LETRECdecl vd) = topLetrecDecl vd
| topDecl _ = []
and topDecodeDecl (v, _, _) = [(v, true)]
and topLetrecDecl (v, _, _) = [(v,false)]
fun getSpan {span = s,component} = s
fun hasSymbol ({span, component = SCC.SIMPLE n},s) = SymbolTable.eq_symid (s,n)
| hasSymbol ({span, component = SCC.RECURSIVE ns},s) =
List.exists (fn n => SymbolTable.eq_symid (s,n)) ns
(* define a second traversal that is a full inference of the tree *)
(* define a traversal that is a full inference of the tree *)
(*local helper function to infer types for a binding group*)
val maxIter = 2
val maxIter = 1
fun calcSubset (printWarn,sym,env) =
let
......@@ -141,7 +241,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
val oldCtxt = E.getCtxt env
val env = List.foldl E.pushFunction env fs
val envFun = E.pushSymbol (sym, s, env)
val envFun = E.pushSymbol (sym, s, false, env)
(*val _ = TextIO.print ("pushed instance " ^ SymbolTable.getString(!SymbolTables.varTable, sym) ^ " symbol:\n" ^ E.topToString envFun)*)
val envCall = E.pushUsage (sym, s, !sm, env)
(*val _ = TextIO.print ("pushed usage:\n" ^ E.topToString envCall)*)
......@@ -188,13 +288,13 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
List.all (checkUsage sym) usages
end
fun calcSubsets printWarn env =
fun calcSubsets printWarn (syms,env) =
List.foldl (fn (sym,unstable) =>
(if calcSubset (printWarn,sym,env) then unstable else
E.SymbolSet.add (unstable, sym))
handle (S.UnificationFailure str) =>
E.SymbolSet.add (unstable, sym)
) E.SymbolSet.empty (E.getGroupSyms env)
) E.SymbolSet.empty syms
fun calcIteration (sym, env) =
let
......@@ -204,7 +304,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
val oldCtxt = E.getCtxt env
val env = List.foldl E.pushFunction env fs
val envFun = E.pushSymbol (sym, s, env)
val envFun = E.pushSymbol (sym, s, false, env)
(*val _ = if SymbolTable.toInt sym = 95 then TextIO.print ("pushed instance " ^ SymbolTable.getString(!SymbolTables.varTable, sym) ^ " symbol:\n" ^ E.kappaToString envFun)
else ()*)
val envCall = E.pushUsage (sym, s, !sm, env)
......@@ -244,11 +344,11 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
List.foldl checkUsage env usages
end
fun calcFixpoints curIter env =
case calcSubsets (curIter>0) env of unstable =>
fun calcFixpoints curIter (syms,env) =
case calcSubsets (curIter>0) (syms,env) of unstable =>
if E.SymbolSet.isEmpty unstable then env else
if curIter<maxIter then
calcFixpoints (curIter+1) (
calcFixpoints (curIter+1) (syms,
List.foldl calcIteration env (E.SymbolSet.listItems unstable)
)
else
......@@ -258,7 +358,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
let
val sStr = SymbolTable.getString
(!SymbolTables.varTable, sym)
val env = E.pushSymbol (sym, SymbolTable.noSpan, env)
val env = E.pushSymbol (sym, SymbolTable.noSpan, false, env)
val (sType, si) = E.kappaToStringSI (env, si)
in
(res ^ pre ^ sStr ^ " : " ^ sType, ", ", si)
......@@ -285,7 +385,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
let
val sStr = SymbolTable.getString (!SymbolTables.varTable, sym)
val s = SymbolTable.getSpan(!SymbolTables.varTable, sym)
val envT = E.pushSymbol (sym, SymbolTable.noSpan, env)
val envT = E.pushSymbol (sym, SymbolTable.noSpan, false, env)
val sType = E.kappaToString envT
in
(Error.errorAt (errStrm, s, [
......@@ -297,6 +397,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
end
val calcFixpoint = calcFixpoint 0
(*local helper function to infer types for a binding group*)
fun infRhs (st,env) (sym, dec, guard, args, rhs) =
let
(*val _ = TextIO.print ("checking binding " ^ SymbolTable.getString(!SymbolTables.varTable, sym) ^ "\n")*)
......@@ -353,10 +454,12 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
env
end
| infDecl stenv (AST.DECODEdecl dd) = infDecodedecl stenv dd
| infDecl (st,env) (AST.LETRECdecl (v,l,e)) = infBinding (st,env) (v, l, e)
| infDecl (st,env) (AST.LETRECdecl (v,l,e)) =
if hasSymbol (st,v) then infBinding (st,env) (v, l, e) else env
| infDecl (st,env) _ = env
and infDecodedecl (st,env) (v, l, Sum.INL e) =
if not (hasSymbol (st,v)) then env else
let
val env = E.pushFunctionOrTop (v,env)
val envRhs = E.popKappa env
......@@ -373,6 +476,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
env
end
| infDecodedecl (st,env) (v, l, Sum.INR el) =
if not (hasSymbol (st,v)) then env else
let
val env = E.pushFunctionOrTop (v,env)
val env = List.foldl
......@@ -565,7 +669,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
| infExp stenv (AST.SEQexp l) = infSeqexp stenv l
| infExp (st,env) (AST.IDexp v) =
let
val env = E.pushSymbol (v, getSpan st, env)
val env = E.pushSymbol (v, getSpan st, hasSymbol (st,v), env)
(*val _ = TextIO.print ("**** after pushing symbol " ^ SymbolTable.getString(!SymbolTables.varTable, v) ^ ":\n" ^ E.topToString env)*)
in
env
......@@ -609,7 +713,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
AST.ACTIONseqexp e => (bindSymId, NONE, e)
| AST.BINDseqexp (v,e) => (bindASymId, SOME v, e)
| _ => raise TypeError
val envFun = E.pushSymbol (bind, getSpan st, env)
val envFun = E.pushSymbol (bind, getSpan st, hasSymbol (st,bind), env)
val envArg = infExp (st,env) e
val envArgRes = E.pushTop envArg
val envArgRes = E.reduceToFunction (envArgRes,1)
......@@ -683,11 +787,11 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
and infBitpat stenv (AST.MARKbitpat m) = reportError infBitpat stenv m
| infBitpat (st,env) (AST.BITSTRbitpat str) = (0,env)
| infBitpat (st,env) (AST.NAMEDbitpat v) =
(1, E.pushSymbol (v, getSpan st, env))
(1, E.pushSymbol (v, getSpan st, hasSymbol (st,v), env))
| infBitpat (st,env) (AST.BITVECbitpat (v,s)) =
let
val env = E.pushLambdaVar (v,env)
val envVar = E.pushSymbol (v, getSpan st, env)
val envVar = E.pushSymbol (v, getSpan st, hasSymbol (st,v), env)
val envWidth = E.pushType (false, VEC (CONST (getBitpatLitLength s)), env)
val env = E.meet (envVar, envWidth)
val env = E.popKappa env
......@@ -702,7 +806,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
val (n,env) = infPat (st,env) p
(*val _ = TextIO.print ("**** after pat:\n" ^ E.topToString env)*)
val envScru = E.popKappa env
val envScru = E.pushSymbol (caseExpSymId, SymbolTable.noSpan, envScru)
val envScru = E.pushSymbol (caseExpSymId, SymbolTable.noSpan, false, envScru)
(*val _ = TextIO.print ("**** after case dup:\n" ^ E.topToString envScru)*)
val env = E.meetFlow (env, envScru)
handle S.UnificationFailure str =>
......@@ -774,11 +878,36 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
val primEnv = E.primitiveEnvironment (Primitives.getSymbolTypes (),
SizeConstraint.fromList Primitives.primitiveSizeConstraints)
val toplevelEnv = E.pushGroup
(List.concat
(List.map topDecl (ast : SpecAbstractTree.specification))
, primEnv)
val toplevelEnv = E.pushGroup (toplevelDecls ast, primEnv)
val sccs = List.rev (sccsSpecification ast)
(*fun prComp (SCC.SIMPLE s) =
SymbolTable.getString(!SymbolTables.varTable, s) ^ "\n"
| prComp (SCC.RECURSIVE ss) = "(" ^ #1 (List.foldl (fn (s,(str,sep)) =>
(str ^ sep ^
SymbolTable.getString(!SymbolTables.varTable, s), ","))
("","") ss) ^ ")\n"
val _ = TextIO.print ("SCCs:\n" ^ List.foldl (fn (c,str) => str ^ prComp c) "" sccs)*)
(*val _ = TextIO.print ("toplevel environment:\n" ^ E.toString toplevelEnv)*)
val toplevelEnv = List.foldl (fn (comp,env) =>
let
(*val _ = TextIO.print ("checking component " ^ prComp comp)*)
val env = List.foldl (fn (d,env) =>
infDecl ({span = SymbolTable.noSpan,
component = comp},env) d
handle TypeError => env) env ast
val env = case comp of
SCC.SIMPLE _ => env
| SCC.RECURSIVE syms => calcFixpoints (syms, env)
handle TypeError => env
in
env
end
) toplevelEnv sccs
(*add usage sites for all exported functions*)
fun checkExports _ (AST.MARKdecl {span=s, tree=t},env) = checkExports s (t,env)
| checkExports s (AST.EXPORTdecl vs,env) =
......@@ -786,15 +915,7 @@ fun typeInferencePass (errStrm, ti : TI.type_info, ast) = let
| checkExports s (_,env) = env
val toplevelEnv = List.foldl (checkExports SymbolTable.noSpan) toplevelEnv ast
(*val _ = TextIO.print ("toplevel environment:\n" ^ E.toString toplevelEnv)*)
val toplevelEnv = List.foldl (fn (d,env) =>
infDecl ({span = SymbolTable.noSpan},env) d
handle TypeError => env
) toplevelEnv (ast : SpecAbstractTree.specification)
(*val toplevelEnv = calcFixpoints toplevelEnv
handle TypeError => toplevelEnv*)
val _ = TextIO.print ("toplevel environment:\n" ^ E.topToString toplevelEnv)
(*val _ = TextIO.print ("toplevel environment:\n" ^ E.topToString toplevelEnv)*)
val (badSizes, primEnv) = E.popGroup (toplevelEnv, false)
val _ = reportBadSizes badSizes
......
......@@ -51,11 +51,10 @@ structure Environment : sig
val genConstructorFlow : (bool * environment) -> environment
(*given an occurrence of a symbol at a position, push its type onto the
stack and return if an instance of this type must be used; arguments are
the symbol to look up, the position it occurred and a list of symbols that
denote the current context/function (the latter is ignored if the symbol
already has a type) *)
val pushSymbol : VarInfo.symid * Error.span * environment -> environment
stack; arguments are the symbol to look up, the position it occurred and a
Boolean flag indicating if this usage should be recorded (True) or if an
existing type should be used (False) *)
val pushSymbol : VarInfo.symid * Error.span * bool * environment -> environment
val getUsages : VarInfo.symid * environment -> Error.span list
......@@ -100,8 +99,8 @@ structure Environment : sig
(*stack: [...,t] -> [...] and type of function f is set to t*)
val popToFunction : VarInfo.symid * environment -> environment
(*the type of function f is unset*)
val clearFunction : VarInfo.symid * environment -> environment
(*unset the type of function f, returns false if the type was already unset*)
val clearFunction : VarInfo.symid * environment -> (bool * environment)
(*add the given function symbol to the current context*)
val pushFunction : VarInfo.symid * environment -> environment
......@@ -825,7 +824,7 @@ end = struct
end
| _ => raise InferenceBug
fun pushSymbol (sym, span, env) =
fun pushSymbol (sym, span, recordUsage, env) =
(case Scope.lookup (sym,env) of
(_, SIMPLE {ty = t}) =>
let
......@@ -861,13 +860,17 @@ end = struct
uses = SpanMap.insert (uses, span, (ctxt, t))}, cons)
| action _ = raise InferenceBug
val env =
if TVar.isEmpty (TVar.intersection (decVars, SC.getVarset (Scope.getSize state)))
if not recordUsage andalso TVar.isEmpty (TVar.intersection (decVars, SC.getVarset (Scope.getSize state)))
then env
else Scope.update (sym, action, env)
in
Scope.wrap (KAPPA {ty = t}, env)
end
| (_, COMPOUND {ty = NONE, width, uses}) =>
if not recordUsage then
(TextIO.print ("need to push usage for " ^
SymbolTable.getString(!SymbolTables.varTable, sym) ^ "\n");
raise InferenceBug) else
(case SpanMap.find (uses, span) of
SOME (_,t) => Scope.wrap (KAPPA {ty = t}, env)
| NONE =>
......@@ -1229,11 +1232,15 @@ end = struct
fun clearFunction (sym, env) =
let
val unsetRef = ref false
fun resetType (COMPOUND {ty = SOME _, width, uses}, cons) =
(unsetRef := true;
(COMPOUND {ty = NONE, width = width, uses = uses}, cons))
| resetType (COMPOUND {ty = NONE, width, uses}, cons) =
(COMPOUND {ty = NONE, width = width, uses = uses}, cons)
| resetType _ = raise InferenceBug
in
Scope.update (sym, resetType, env)
(!unsetRef, Scope.update (sym, resetType, env))
end
fun pushFunction (sym, (scs,state)) =
......@@ -1269,7 +1276,7 @@ end = struct
(COMPOUND {ty = SOME (freshVar (),BD.empty), width = width, uses = uses}, cons)
| setType _ = raise InferenceBug
val env = Scope.update (sym, setType, env)
val env = pushSymbol (sym, SymbolTable.noSpan, env)
val env = pushSymbol (sym, SymbolTable.noSpan, false, env)
val env = pushFunction (sym,env)
in
env
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment