From bf1073a5150383a54b05dcabcc00f9da2956a07d Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Thu, 20 Jun 2024 16:51:04 +0200 Subject: [PATCH] refactor: port below construction to Lean MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit this is the simplest of the constructions to be ported from C++ to Lean, so I’ll PR this one first. For validation I developed this in a separate repository at https://github.com/nomeata/lean-constructions/tree/fad715e and checked that all `.recOn` declarations found in Lean and Mathlib are equivalent, up to def canon (e : Expr) : CoreM Expr := do Core.transform (← Core.betaReduce e) (pre := fun | .const n ls => return .done (.const n (ls.map (·.normalize))) | .sort l => return .done (.sort l.normalize) | _ => return .continue) It was not feasible to make them completely equal, because the kernel's type inference code seem to optimize level expressions a bit less aggressively, and beta-reduces less in inference. The private helper functions about `PProd` can later move into their own file, used by these constructions as well as the structural recursion module. --- src/Lean/Meta/Constructions.lean | 17 +-- src/Lean/Meta/Constructions/Below.lean | 157 +++++++++++++++++++++++++ src/library/constructions/brec_on.cpp | 111 ----------------- src/library/constructions/brec_on.h | 6 - 4 files changed, 158 insertions(+), 133 deletions(-) create mode 100644 src/Lean/Meta/Constructions/Below.lean diff --git a/src/Lean/Meta/Constructions.lean b/src/Lean/Meta/Constructions.lean index f40e0bbe04dc9..29e8ee5d32816 100644 --- a/src/Lean/Meta/Constructions.lean +++ b/src/Lean/Meta/Constructions.lean @@ -8,6 +8,7 @@ import Lean.AuxRecursor import Lean.AddDecl import Lean.Meta.AppBuilder import Lean.Meta.CompletionName +import Lean.Constructions.Below namespace Lean @@ -15,7 +16,6 @@ namespace Lean @[extern "lean_mk_cases_on"] opaque mkCasesOnImp (env : Environment) (declName : @& Name) : Except KernelException Declaration @[extern "lean_mk_no_confusion_type"] opaque mkNoConfusionTypeCoreImp (env : Environment) (declName : @& Name) : Except KernelException Declaration @[extern "lean_mk_no_confusion"] opaque mkNoConfusionCoreImp (env : Environment) (declName : @& Name) : Except KernelException Declaration -@[extern "lean_mk_below"] opaque mkBelowImp (env : Environment) (declName : @& Name) (ibelow : Bool) : Except KernelException Declaration @[extern "lean_mk_brec_on"] opaque mkBRecOnImp (env : Environment) (declName : @& Name) (ind : Bool) : Except KernelException Declaration open Meta @@ -36,21 +36,6 @@ def mkCasesOn (declName : Name) : MetaM Unit := do modifyEnv fun env => markAuxRecursor env name modifyEnv fun env => addProtected env name -private def mkBelowOrIBelow (declName : Name) (ibelow : Bool) : MetaM Unit := do - let .inductInfo indVal ← getConstInfo declName | return - unless indVal.isRec do return - if ← isPropFormerType indVal.type then return - - let decl ← ofExceptKernelException (mkBelowImp (← getEnv) declName ibelow) - let name := decl.definitionVal!.name - addDecl decl - setReducibleAttribute name - modifyEnv fun env => addToCompletionBlackList env name - modifyEnv fun env => addProtected env name - -def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true -def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false - private def mkBRecOrBInductionOn (declName : Name) (ind : Bool) : MetaM Unit := do let .inductInfo indVal ← getConstInfo declName | return unless indVal.isRec do return diff --git a/src/Lean/Meta/Constructions/Below.lean b/src/Lean/Meta/Constructions/Below.lean new file mode 100644 index 0000000000000..50e4ff98ea82f --- /dev/null +++ b/src/Lean/Meta/Constructions/Below.lean @@ -0,0 +1,157 @@ +/- +Copyright (c) 2024 Lean FRO. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura, Joachim Breitner +-/ +prelude +import Lean.Meta.InferType +import Lean.AuxRecursor +import Lean.AddDecl +import Lean.Meta.CompletionName + +namespace Lean +open Meta + +private def mkPUnit : Level → Expr + | .zero => .const ``True [] + | lvl => .const ``PUnit [lvl] + +private def mkPProd (e1 e2 : Expr) : MetaM Expr := do + let lvl1 ← getLevel e1 + let lvl2 ← getLevel e2 + if lvl1 matches .zero && lvl2 matches .zero then + return mkApp2 (.const `And []) e1 e2 + else + return mkApp2 (.const ``PProd [lvl1, lvl2]) e1 e2 + +private def mkNProd (lvl : Level) (es : Array Expr) : MetaM Expr := + es.foldrM (init := mkPUnit lvl) mkPProd + +/-- +If `minorType` is the type of a minor premies of a recursor, such as +``` + (cons : (head : α) → (tail : List α) → (tail_hs : motive tail) → motive (head :: tail)) +``` +of `List.rec`, constructs the corresponding argument to `List.rec` in the construction +of `.below` definition; in this case +``` +fun head tail tail_ih => + PProd (PProd (motive tail) tail_ih) PUnit +``` +of type +``` +α → List α → Sort (max 1 u_1) → Sort (max 1 u_1) +``` +The parameter `typeFormers` are the `motive`s. +-/ +private def buildMinorPremise (rlvl : Level) (typeFormers : Array Expr) (minorType : Expr) : MetaM Expr := + forallTelescope minorType fun minor_args _ => do go #[] minor_args.toList +where + ibelow := rlvl matches .zero + go (prods : Array Expr) : List Expr → MetaM Expr + | [] => mkNProd rlvl prods + | arg::args => do + let argType ← inferType arg + forallTelescope argType fun arg_args arg_type => do + if typeFormers.contains arg_type.getAppFn then + let name ← arg.fvarId!.getUserName + let type' ← forallTelescope argType fun args _ => mkForallFVars args (.sort rlvl) + withLocalDeclD name type' fun arg' => do + let snd ← mkForallFVars arg_args (mkAppN arg' arg_args) + let e' ← mkPProd argType snd + mkLambdaFVars #[arg'] (← go (prods.push e') args) + else + mkLambdaFVars #[arg] (← go prods args) + +/-- +Constructs the `.below` or `.ibelow` definition for a inductive predicate. + +For example for the `List` type, it constructs, +``` +@[reducible] protected def List.below.{u_1, u} : {α : Type u} → + {motive : List α → Sort u_1} → List α → Sort (max 1 u_1) := +fun {α} {motive} t => + List.rec PUnit (fun head tail tail_ih => PProd (PProd (motive tail) tail_ih) PUnit) t +``` +and +``` +@[reducible] protected def List.ibelow.{u} : {α : Type u} → + {motive : List α → Prop} → List α → Prop := +fun {α} {motive} t => + List.rec True (fun head tail tail_ih => (motive tail ∧ tail_ih) ∧ True) t +``` +-/ +private def mkBelowOrIBelow (indName : Name) (ibelow : Bool) : MetaM Unit := do + let indVal ← getConstInfoInduct indName + let recName := mkRecName indName + -- The construction follows the type of `ind.rec` + let .recInfo recVal ← getConstInfo recName + | throwError "{recName} not a .recInfo" + let lvl::lvls := recVal.levelParams.map (Level.param ·) + | throwError "recursor {recName} has no levelParams" + let lvlParam := recVal.levelParams.head! + -- universe parameter names of ibelow/below + let blvls := + -- For ibelow we instantiate the first universe parameter of `.rec` to `.zero` + if ibelow then recVal.levelParams.tail! + else recVal.levelParams + let .some ilvl ← typeFormerTypeLevel indVal.type + | throwError "type {indVal.type} of inductive {indVal.name} not a type former?" + + -- universe level of the resultant type + let rlvl : Level := + if ibelow then + 0 + else if indVal.isReflexive then + if let .max 1 lvl := ilvl then + mkLevelMax' (.succ lvl) lvl + else + mkLevelMax' (.succ lvl) ilvl + else + mkLevelMax' 1 lvl + + let refType := + if ibelow then + recVal.type.instantiateLevelParams [lvlParam] [0] + else if indVal.isReflexive then + recVal.type.instantiateLevelParams [lvlParam] [lvl.succ] + else + recVal.type + + let decl ← forallTelescope refType fun refArgs _ => do + assert! refArgs.size == indVal.numParams + recVal.numMotives + recVal.numMinors + indVal.numIndices + 1 + let params : Array Expr := refArgs[:indVal.numParams] + let typeFormers : Array Expr := refArgs[indVal.numParams:indVal.numParams + recVal.numMotives] + let minors : Array Expr := refArgs[indVal.numParams + recVal.numMotives:indVal.numParams + recVal.numMotives + recVal.numMinors] + let remaining : Array Expr := refArgs[indVal.numParams + recVal.numMotives + recVal.numMinors:] + + let mut val := .const recName (rlvl.succ :: lvls) + -- add parameters + val := mkAppN val params + -- add type formers + for typeFormer in typeFormers do + let arg ← forallTelescope (← inferType typeFormer) fun targs _ => + mkLambdaFVars targs (.sort rlvl) + val := .app val arg + -- add minor premises + for minor in minors do + let arg ← buildMinorPremise rlvl typeFormers (← inferType minor) + val := .app val arg + -- add indices and major premise + val := mkAppN val remaining + + -- All paramaters of `.rec` besides the `minors` become parameters of `.below` + let below_params := params ++ typeFormers ++ remaining + let type ← mkForallFVars below_params (.sort rlvl) + val ← mkLambdaFVars below_params val + + let name := if ibelow then mkIBelowName indName else mkBelowName indName + mkDefinitionValInferrringUnsafe name blvls type val .abbrev + + addDecl (.defnDecl decl) + setReducibleAttribute decl.name + modifyEnv fun env => markAuxRecursor env decl.name + modifyEnv fun env => addProtected env decl.name + +def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true +def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false diff --git a/src/library/constructions/brec_on.cpp b/src/library/constructions/brec_on.cpp index 912e803f52141..3211576336c58 100644 --- a/src/library/constructions/brec_on.cpp +++ b/src/library/constructions/brec_on.cpp @@ -32,113 +32,6 @@ static optional is_typeformer_app(buffer const & typeformer_name return optional(); } -static declaration mk_below(environment const & env, name const & n, bool ibelow) { - local_ctx lctx; - constant_info ind_info = env.get(n); - inductive_val ind_val = ind_info.to_inductive_val(); - name_generator ngen = mk_constructions_name_generator(); - unsigned nparams = ind_val.get_nparams(); - constant_info rec_info = env.get(mk_rec_name(n)); - recursor_val rec_val = rec_info.to_recursor_val(); - unsigned nminors = rec_val.get_nminors(); - unsigned ntypeformers = rec_val.get_nmotives(); - names lps = rec_info.get_lparams(); - bool is_reflexive = ind_val.is_reflexive(); - level lvl = mk_univ_param(head(lps)); - levels lvls = lparams_to_levels(tail(lps)); - names blvls; // universe parameter names of ibelow/below - level rlvl; // universe level of the resultant type - // The arguments of below (ibelow) are the ones in the recursor - minor premises. - // The universe we map to is also different (l+1 for below of reflexive types) and (0 fo ibelow). - expr ref_type; - expr Type_result; - if (ibelow) { - // we are eliminating to Prop - blvls = tail(lps); - rlvl = mk_level_zero(); - ref_type = instantiate_lparam(rec_info.get_type(), param_id(lvl), mk_level_zero()); - } else if (is_reflexive) { - blvls = lps; - rlvl = get_datatype_level(env, ind_info.get_type()); - // if rlvl is of the form (max 1 l), then rlvl <- l - if (is_max(rlvl) && is_one(max_lhs(rlvl))) - rlvl = max_rhs(rlvl); - rlvl = mk_max(mk_succ(lvl), rlvl); - ref_type = instantiate_lparam(rec_info.get_type(), param_id(lvl), mk_succ(lvl)); - } else { - // we can simplify the universe levels for non-reflexive datatypes - blvls = lps; - rlvl = mk_max(mk_level_one(), lvl); - ref_type = rec_info.get_type(); - } - Type_result = mk_sort(rlvl); - buffer ref_args; - to_telescope(lctx, ngen, ref_type, ref_args); - lean_assert(ref_args.size() == nparams + ntypeformers + nminors + ind_val.get_nindices() + 1); - - // args contains the below/ibelow arguments - buffer args; - buffer typeformer_names; - // add parameters and typeformers - for (unsigned i = 0; i < nparams; i++) - args.push_back(ref_args[i]); - for (unsigned i = nparams; i < nparams + ntypeformers; i++) { - args.push_back(ref_args[i]); - typeformer_names.push_back(fvar_name(ref_args[i])); - } - // we ignore minor premises in below/ibelow - for (unsigned i = nparams + ntypeformers + nminors; i < ref_args.size(); i++) - args.push_back(ref_args[i]); - - // We define below/ibelow using the recursor for this type - levels rec_lvls = cons(mk_succ(rlvl), lvls); - expr rec = mk_constant(rec_info.get_name(), rec_lvls); - for (unsigned i = 0; i < nparams; i++) - rec = mk_app(rec, args[i]); - // add type formers - for (unsigned i = nparams; i < nparams + ntypeformers; i++) { - buffer targs; - to_telescope(lctx, ngen, lctx.get_type(args[i]), targs); - rec = mk_app(rec, lctx.mk_lambda(targs, Type_result)); - } - // add minor premises - for (unsigned i = nparams + ntypeformers; i < nparams + ntypeformers + nminors; i++) { - expr minor = ref_args[i]; - expr minor_type = lctx.get_type(minor); - buffer minor_args; - minor_type = to_telescope(lctx, ngen, minor_type, minor_args); - buffer prod_pairs; - for (expr & minor_arg : minor_args) { - buffer minor_arg_args; - expr minor_arg_type = to_telescope(env, lctx, ngen, lctx.get_type(minor_arg), minor_arg_args); - if (is_typeformer_app(typeformer_names, minor_arg_type)) { - expr fst = lctx.get_type(minor_arg); - minor_arg = lctx.mk_local_decl(ngen, lctx.get_local_decl(minor_arg).get_user_name(), lctx.mk_pi(minor_arg_args, Type_result)); - expr snd = lctx.mk_pi(minor_arg_args, mk_app(minor_arg, minor_arg_args)); - type_checker tc(env, lctx); - prod_pairs.push_back(mk_pprod(tc, fst, snd, ibelow)); - } - } - type_checker tc(env, lctx); - expr new_arg = foldr([&](expr const & a, expr const & b) { return mk_pprod(tc, a, b, ibelow); }, - [&]() { return mk_unit(rlvl, ibelow); }, - prod_pairs.size(), prod_pairs.data()); - rec = mk_app(rec, lctx.mk_lambda(minor_args, new_arg)); - } - - // add indices and major premise - for (unsigned i = nparams + ntypeformers; i < args.size(); i++) { - rec = mk_app(rec, args[i]); - } - - name below_name = ibelow ? name{n, "ibelow"} : name{n, "below"}; - expr below_type = lctx.mk_pi(args, Type_result); - expr below_value = lctx.mk_lambda(args, rec); - - return mk_definition_inferring_unsafe(env, below_name, blvls, below_type, below_value, - reducibility_hints::mk_abbreviation()); -} - static declaration mk_brec_on(environment const & env, name const & n, bool ind) { local_ctx lctx; constant_info ind_info = env.get(n); @@ -308,10 +201,6 @@ static declaration mk_brec_on(environment const & env, name const & n, bool ind) reducibility_hints::mk_abbreviation()); } -extern "C" LEAN_EXPORT object * lean_mk_below(object * env, object * n, uint8 ibelow) { - return catch_kernel_exceptions([&]() { return mk_below(environment(env), name(n, true), ibelow); }); -} - extern "C" LEAN_EXPORT object * lean_mk_brec_on(object * env, object * n, uint8 ind) { return catch_kernel_exceptions([&]() { return mk_brec_on(environment(env), name(n, true), ind); }); } diff --git a/src/library/constructions/brec_on.h b/src/library/constructions/brec_on.h index 4eed3fc22ccae..5e0229453643d 100644 --- a/src/library/constructions/brec_on.h +++ b/src/library/constructions/brec_on.h @@ -8,12 +8,6 @@ Author: Leonardo de Moura #include "kernel/environment.h" namespace lean { -/** \brief Given an inductive datatype \c n in \c env, return declaration for - n.below or .nibelow auxiliary construction for n.brec_on - (aka below recursion on). -*/ -declaration mk_below(environment const & env, name const & n, bool ibelow); - /** \brief Given an inductive datatype \c n in \c env, return declaration for n.brec_on or n.binduction_on. */