Skip to content

Commit

Permalink
refactor: port below construction to Lean
Browse files Browse the repository at this point in the history
this is the simplest of the constructions to be ported from C++ to Lean,
so I’ll PR this one first.

For validation I developed this in a separate repository at
https://github.com/nomeata/lean-constructions/tree/fad715e
and checked that all `.recOn` declarations found in Lean and Mathlib are
equivalent, up to

    def canon (e : Expr) : CoreM Expr := do
      Core.transform (← Core.betaReduce e) (pre := fun
        | .const n ls  => return .done (.const n (ls.map (·.normalize)))
        | .sort l => return .done (.sort l.normalize)
        | _ => return .continue)

It was not feasible to make them completely equal, because the kernel's
type inference code seem to optimize level expressions a bit less
aggressively, and beta-reduces less in inference.

The private helper functions about `PProd` can later move into their own
file, used by these constructions as well as the structural recursion
module.
  • Loading branch information
nomeata committed Jun 20, 2024
1 parent 1f732bb commit bf1073a
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 133 deletions.
17 changes: 1 addition & 16 deletions src/Lean/Meta/Constructions.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ import Lean.AuxRecursor
import Lean.AddDecl
import Lean.Meta.AppBuilder
import Lean.Meta.CompletionName
import Lean.Constructions.Below

namespace Lean

@[extern "lean_mk_rec_on"] opaque mkRecOnImp (env : Environment) (declName : @& Name) : Except KernelException Declaration
@[extern "lean_mk_cases_on"] opaque mkCasesOnImp (env : Environment) (declName : @& Name) : Except KernelException Declaration
@[extern "lean_mk_no_confusion_type"] opaque mkNoConfusionTypeCoreImp (env : Environment) (declName : @& Name) : Except KernelException Declaration
@[extern "lean_mk_no_confusion"] opaque mkNoConfusionCoreImp (env : Environment) (declName : @& Name) : Except KernelException Declaration
@[extern "lean_mk_below"] opaque mkBelowImp (env : Environment) (declName : @& Name) (ibelow : Bool) : Except KernelException Declaration
@[extern "lean_mk_brec_on"] opaque mkBRecOnImp (env : Environment) (declName : @& Name) (ind : Bool) : Except KernelException Declaration

open Meta
Expand All @@ -36,21 +36,6 @@ def mkCasesOn (declName : Name) : MetaM Unit := do
modifyEnv fun env => markAuxRecursor env name
modifyEnv fun env => addProtected env name

private def mkBelowOrIBelow (declName : Name) (ibelow : Bool) : MetaM Unit := do
let .inductInfo indVal ← getConstInfo declName | return
unless indVal.isRec do return
if ← isPropFormerType indVal.type then return

let decl ← ofExceptKernelException (mkBelowImp (← getEnv) declName ibelow)
let name := decl.definitionVal!.name
addDecl decl
setReducibleAttribute name
modifyEnv fun env => addToCompletionBlackList env name
modifyEnv fun env => addProtected env name

def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true
def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false

private def mkBRecOrBInductionOn (declName : Name) (ind : Bool) : MetaM Unit := do
let .inductInfo indVal ← getConstInfo declName | return
unless indVal.isRec do return
Expand Down
157 changes: 157 additions & 0 deletions src/Lean/Meta/Constructions/Below.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/-
Copyright (c) 2024 Lean FRO. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Joachim Breitner
-/
prelude
import Lean.Meta.InferType
import Lean.AuxRecursor
import Lean.AddDecl
import Lean.Meta.CompletionName

namespace Lean
open Meta

private def mkPUnit : Level → Expr
| .zero => .const ``True []
| lvl => .const ``PUnit [lvl]

private def mkPProd (e1 e2 : Expr) : MetaM Expr := do
let lvl1 ← getLevel e1
let lvl2 ← getLevel e2
if lvl1 matches .zero && lvl2 matches .zero then
return mkApp2 (.const `And []) e1 e2
else
return mkApp2 (.const ``PProd [lvl1, lvl2]) e1 e2

private def mkNProd (lvl : Level) (es : Array Expr) : MetaM Expr :=
es.foldrM (init := mkPUnit lvl) mkPProd

/--
If `minorType` is the type of a minor premies of a recursor, such as
```
(cons : (head : α) → (tail : List α) → (tail_hs : motive tail) → motive (head :: tail))
```
of `List.rec`, constructs the corresponding argument to `List.rec` in the construction
of `.below` definition; in this case
```
fun head tail tail_ih =>
PProd (PProd (motive tail) tail_ih) PUnit
```
of type
```
α → List α → Sort (max 1 u_1) → Sort (max 1 u_1)
```
The parameter `typeFormers` are the `motive`s.
-/
private def buildMinorPremise (rlvl : Level) (typeFormers : Array Expr) (minorType : Expr) : MetaM Expr :=
forallTelescope minorType fun minor_args _ => do go #[] minor_args.toList
where
ibelow := rlvl matches .zero
go (prods : Array Expr) : List Expr → MetaM Expr
| [] => mkNProd rlvl prods
| arg::args => do
let argType ← inferType arg
forallTelescope argType fun arg_args arg_type => do
if typeFormers.contains arg_type.getAppFn then
let name ← arg.fvarId!.getUserName
let type' ← forallTelescope argType fun args _ => mkForallFVars args (.sort rlvl)
withLocalDeclD name type' fun arg' => do
let snd ← mkForallFVars arg_args (mkAppN arg' arg_args)
let e' ← mkPProd argType snd
mkLambdaFVars #[arg'] (← go (prods.push e') args)
else
mkLambdaFVars #[arg] (← go prods args)

/--
Constructs the `.below` or `.ibelow` definition for a inductive predicate.
For example for the `List` type, it constructs,
```
@[reducible] protected def List.below.{u_1, u} : {α : Type u} →
{motive : List α → Sort u_1} → List α → Sort (max 1 u_1) :=
fun {α} {motive} t =>
List.rec PUnit (fun head tail tail_ih => PProd (PProd (motive tail) tail_ih) PUnit) t
```
and
```
@[reducible] protected def List.ibelow.{u} : {α : Type u} →
{motive : List α → Prop} → List α → Prop :=
fun {α} {motive} t =>
List.rec True (fun head tail tail_ih => (motive tail ∧ tail_ih) ∧ True) t
```
-/
private def mkBelowOrIBelow (indName : Name) (ibelow : Bool) : MetaM Unit := do
let indVal ← getConstInfoInduct indName
let recName := mkRecName indName
-- The construction follows the type of `ind.rec`
let .recInfo recVal ← getConstInfo recName
| throwError "{recName} not a .recInfo"
let lvl::lvls := recVal.levelParams.map (Level.param ·)
| throwError "recursor {recName} has no levelParams"
let lvlParam := recVal.levelParams.head!
-- universe parameter names of ibelow/below
let blvls :=
-- For ibelow we instantiate the first universe parameter of `.rec` to `.zero`
if ibelow then recVal.levelParams.tail!
else recVal.levelParams
let .some ilvl ← typeFormerTypeLevel indVal.type
| throwError "type {indVal.type} of inductive {indVal.name} not a type former?"

-- universe level of the resultant type
let rlvl : Level :=
if ibelow then
0
else if indVal.isReflexive then
if let .max 1 lvl := ilvl then
mkLevelMax' (.succ lvl) lvl
else
mkLevelMax' (.succ lvl) ilvl
else
mkLevelMax' 1 lvl

let refType :=
if ibelow then
recVal.type.instantiateLevelParams [lvlParam] [0]
else if indVal.isReflexive then
recVal.type.instantiateLevelParams [lvlParam] [lvl.succ]
else
recVal.type

let decl ← forallTelescope refType fun refArgs _ => do
assert! refArgs.size == indVal.numParams + recVal.numMotives + recVal.numMinors + indVal.numIndices + 1
let params : Array Expr := refArgs[:indVal.numParams]
let typeFormers : Array Expr := refArgs[indVal.numParams:indVal.numParams + recVal.numMotives]
let minors : Array Expr := refArgs[indVal.numParams + recVal.numMotives:indVal.numParams + recVal.numMotives + recVal.numMinors]
let remaining : Array Expr := refArgs[indVal.numParams + recVal.numMotives + recVal.numMinors:]

let mut val := .const recName (rlvl.succ :: lvls)
-- add parameters
val := mkAppN val params
-- add type formers
for typeFormer in typeFormers do
let arg ← forallTelescope (← inferType typeFormer) fun targs _ =>
mkLambdaFVars targs (.sort rlvl)
val := .app val arg
-- add minor premises
for minor in minors do
let arg ← buildMinorPremise rlvl typeFormers (← inferType minor)
val := .app val arg
-- add indices and major premise
val := mkAppN val remaining

-- All paramaters of `.rec` besides the `minors` become parameters of `.below`
let below_params := params ++ typeFormers ++ remaining
let type ← mkForallFVars below_params (.sort rlvl)
val ← mkLambdaFVars below_params val

let name := if ibelow then mkIBelowName indName else mkBelowName indName
mkDefinitionValInferrringUnsafe name blvls type val .abbrev

addDecl (.defnDecl decl)
setReducibleAttribute decl.name
modifyEnv fun env => markAuxRecursor env decl.name
modifyEnv fun env => addProtected env decl.name

def mkBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName true
def mkIBelow (declName : Name) : MetaM Unit := mkBelowOrIBelow declName false
111 changes: 0 additions & 111 deletions src/library/constructions/brec_on.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,113 +32,6 @@ static optional<unsigned> is_typeformer_app(buffer<name> const & typeformer_name
return optional<unsigned>();
}

static declaration mk_below(environment const & env, name const & n, bool ibelow) {
local_ctx lctx;
constant_info ind_info = env.get(n);
inductive_val ind_val = ind_info.to_inductive_val();
name_generator ngen = mk_constructions_name_generator();
unsigned nparams = ind_val.get_nparams();
constant_info rec_info = env.get(mk_rec_name(n));
recursor_val rec_val = rec_info.to_recursor_val();
unsigned nminors = rec_val.get_nminors();
unsigned ntypeformers = rec_val.get_nmotives();
names lps = rec_info.get_lparams();
bool is_reflexive = ind_val.is_reflexive();
level lvl = mk_univ_param(head(lps));
levels lvls = lparams_to_levels(tail(lps));
names blvls; // universe parameter names of ibelow/below
level rlvl; // universe level of the resultant type
// The arguments of below (ibelow) are the ones in the recursor - minor premises.
// The universe we map to is also different (l+1 for below of reflexive types) and (0 fo ibelow).
expr ref_type;
expr Type_result;
if (ibelow) {
// we are eliminating to Prop
blvls = tail(lps);
rlvl = mk_level_zero();
ref_type = instantiate_lparam(rec_info.get_type(), param_id(lvl), mk_level_zero());
} else if (is_reflexive) {
blvls = lps;
rlvl = get_datatype_level(env, ind_info.get_type());
// if rlvl is of the form (max 1 l), then rlvl <- l
if (is_max(rlvl) && is_one(max_lhs(rlvl)))
rlvl = max_rhs(rlvl);
rlvl = mk_max(mk_succ(lvl), rlvl);
ref_type = instantiate_lparam(rec_info.get_type(), param_id(lvl), mk_succ(lvl));
} else {
// we can simplify the universe levels for non-reflexive datatypes
blvls = lps;
rlvl = mk_max(mk_level_one(), lvl);
ref_type = rec_info.get_type();
}
Type_result = mk_sort(rlvl);
buffer<expr> ref_args;
to_telescope(lctx, ngen, ref_type, ref_args);
lean_assert(ref_args.size() == nparams + ntypeformers + nminors + ind_val.get_nindices() + 1);

// args contains the below/ibelow arguments
buffer<expr> args;
buffer<name> typeformer_names;
// add parameters and typeformers
for (unsigned i = 0; i < nparams; i++)
args.push_back(ref_args[i]);
for (unsigned i = nparams; i < nparams + ntypeformers; i++) {
args.push_back(ref_args[i]);
typeformer_names.push_back(fvar_name(ref_args[i]));
}
// we ignore minor premises in below/ibelow
for (unsigned i = nparams + ntypeformers + nminors; i < ref_args.size(); i++)
args.push_back(ref_args[i]);

// We define below/ibelow using the recursor for this type
levels rec_lvls = cons(mk_succ(rlvl), lvls);
expr rec = mk_constant(rec_info.get_name(), rec_lvls);
for (unsigned i = 0; i < nparams; i++)
rec = mk_app(rec, args[i]);
// add type formers
for (unsigned i = nparams; i < nparams + ntypeformers; i++) {
buffer<expr> targs;
to_telescope(lctx, ngen, lctx.get_type(args[i]), targs);
rec = mk_app(rec, lctx.mk_lambda(targs, Type_result));
}
// add minor premises
for (unsigned i = nparams + ntypeformers; i < nparams + ntypeformers + nminors; i++) {
expr minor = ref_args[i];
expr minor_type = lctx.get_type(minor);
buffer<expr> minor_args;
minor_type = to_telescope(lctx, ngen, minor_type, minor_args);
buffer<expr> prod_pairs;
for (expr & minor_arg : minor_args) {
buffer<expr> minor_arg_args;
expr minor_arg_type = to_telescope(env, lctx, ngen, lctx.get_type(minor_arg), minor_arg_args);
if (is_typeformer_app(typeformer_names, minor_arg_type)) {
expr fst = lctx.get_type(minor_arg);
minor_arg = lctx.mk_local_decl(ngen, lctx.get_local_decl(minor_arg).get_user_name(), lctx.mk_pi(minor_arg_args, Type_result));
expr snd = lctx.mk_pi(minor_arg_args, mk_app(minor_arg, minor_arg_args));
type_checker tc(env, lctx);
prod_pairs.push_back(mk_pprod(tc, fst, snd, ibelow));
}
}
type_checker tc(env, lctx);
expr new_arg = foldr([&](expr const & a, expr const & b) { return mk_pprod(tc, a, b, ibelow); },
[&]() { return mk_unit(rlvl, ibelow); },
prod_pairs.size(), prod_pairs.data());
rec = mk_app(rec, lctx.mk_lambda(minor_args, new_arg));
}

// add indices and major premise
for (unsigned i = nparams + ntypeformers; i < args.size(); i++) {
rec = mk_app(rec, args[i]);
}

name below_name = ibelow ? name{n, "ibelow"} : name{n, "below"};
expr below_type = lctx.mk_pi(args, Type_result);
expr below_value = lctx.mk_lambda(args, rec);

return mk_definition_inferring_unsafe(env, below_name, blvls, below_type, below_value,
reducibility_hints::mk_abbreviation());
}

static declaration mk_brec_on(environment const & env, name const & n, bool ind) {
local_ctx lctx;
constant_info ind_info = env.get(n);
Expand Down Expand Up @@ -308,10 +201,6 @@ static declaration mk_brec_on(environment const & env, name const & n, bool ind)
reducibility_hints::mk_abbreviation());
}

extern "C" LEAN_EXPORT object * lean_mk_below(object * env, object * n, uint8 ibelow) {
return catch_kernel_exceptions<declaration>([&]() { return mk_below(environment(env), name(n, true), ibelow); });
}

extern "C" LEAN_EXPORT object * lean_mk_brec_on(object * env, object * n, uint8 ind) {
return catch_kernel_exceptions<declaration>([&]() { return mk_brec_on(environment(env), name(n, true), ind); });
}
Expand Down
6 changes: 0 additions & 6 deletions src/library/constructions/brec_on.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ Author: Leonardo de Moura
#include "kernel/environment.h"

namespace lean {
/** \brief Given an inductive datatype \c n in \c env, return declaration for
<tt>n.below</tt> or <tt>.nibelow</tt> auxiliary construction for <tt>n.brec_on</t>
(aka below recursion on).
*/
declaration mk_below(environment const & env, name const & n, bool ibelow);

/** \brief Given an inductive datatype \c n in \c env, return declaration for
<tt>n.brec_on</tt> or <tt>n.binduction_on</tt>.
*/
Expand Down

0 comments on commit bf1073a

Please sign in to comment.