diff --git a/src/Lean/MetavarContext.lean b/src/Lean/MetavarContext.lean index 2b887f67432d..9390090903e9 100644 --- a/src/Lean/MetavarContext.lean +++ b/src/Lean/MetavarContext.lean @@ -336,6 +336,8 @@ structure MetavarContext where For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/ dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {} +instance : Inhabited MetavarContext := ⟨{}⟩ + /-- A monad with a stateful metavariable context, defining `getMCtx` and `modifyMCtx`. -/ class MonadMCtx (m : Type → Type) where getMCtx : m MetavarContext @@ -358,15 +360,27 @@ abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit := abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : LMVarId) : m (Option Level) := return (← getMCtx).lAssignment.find? mvarId +@[export lean_get_lmvar_assignment] +def getLevelMVarAssignmentExp (m : MetavarContext) (mvarId : LMVarId) : Option Level := + m.lAssignment.find? mvarId + def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr := m.eAssignment.find? mvarId +@[export lean_get_mvar_assignment] +def MetavarContext.getExprAssignmentExp (m : MetavarContext) (mvarId : MVarId) : Option Expr := + m.eAssignment.find? mvarId + def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) := return (← getMCtx).getExprAssignmentCore? mvarId def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment := mctx.dAssignment.find? mvarId +@[export lean_get_delayed_mvar_assignment] +def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment := + mctx.dAssignment.find? mvarId + def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) := return (← getMCtx).getDelayedMVarAssignmentCore? mvarId @@ -478,6 +492,10 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit := modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val } +@[export lean_assign_lmvar] +def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext := + { m with lAssignment := m.lAssignment.insert mvarId val } + /-- Add `mvarId := x` to the metavariable assignment. This method does not check whether `mvarId` is already assigned, nor it checks whether @@ -487,6 +505,10 @@ This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`. def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit := modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val } +@[export lean_assign_mvar] +def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext := + { m with eAssignment := m.eAssignment.insert mvarId val } + /-- Add a delayed assignment for the given metavariable. You must make sure that the metavariable is not already assigned or delayed-assigned. @@ -516,6 +538,9 @@ To avoid this term eta-expanded term, we apply beta-reduction when instantiating This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `levelMVarToParam`. -/ +@[extern "lean_instantiate_level_mvars"] +opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level + partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level | lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl (← instantiateLevelMVars lvl₁) | lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂) @@ -531,6 +556,9 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level | none => pure lvl | lvl => pure lvl +@[extern "lean_instantiate_expr_mvars"] +opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr + /-- instantiateExprMVars main function -/ partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLiftT (ST ω) m] (e : Expr) : MonadCacheT ExprStructEq Expr m Expr := if !e.hasMVar then @@ -792,8 +820,6 @@ def localDeclDependsOnPred [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf : namespace MetavarContext -instance : Inhabited MetavarContext := ⟨{}⟩ - @[export lean_mk_metavar_ctx] def mkMetavarContext : Unit → MetavarContext := fun _ => {} diff --git a/src/kernel/CMakeLists.txt b/src/kernel/CMakeLists.txt index 9b9c1e42d65c..2d21c3af4968 100644 --- a/src/kernel/CMakeLists.txt +++ b/src/kernel/CMakeLists.txt @@ -2,4 +2,4 @@ add_library(kernel OBJECT level.cpp expr.cpp expr_eq_fn.cpp for_each_fn.cpp replace_fn.cpp abstract.cpp instantiate.cpp local_ctx.cpp declaration.cpp environment.cpp type_checker.cpp init_module.cpp expr_cache.cpp equiv_manager.cpp quot.cpp -inductive.cpp trace.cpp) +inductive.cpp trace.cpp instantiate_mvars.cpp) diff --git a/src/kernel/instantiate_mvars.cpp b/src/kernel/instantiate_mvars.cpp new file mode 100644 index 000000000000..21ead4284057 --- /dev/null +++ b/src/kernel/instantiate_mvars.cpp @@ -0,0 +1,95 @@ +/* +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Authors: Leonardo de Moura +*/ +#include +#include "runtime/option_ref.h" +#include "kernel/instantiate.h" +#include "kernel/abstract.h" + +/* +This module is not used by the kernel. It just provides an efficient implementation of +`instantiateExprMVars` +*/ + +namespace lean { + +extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid); +extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val); + +typedef object_ref metavar_ctx; +void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) { + object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg()); + mctx.set_box(r); +} + +option_ref get_lmvar_assignment(metavar_ctx & mctx, name const & mid) { + return option_ref(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg())); +} + +class instantiate_lmvar_fn { + metavar_ctx & m_mctx; + std::unordered_map m_cache; + + inline level cache(level const & l, level && r, bool shared) { + if (shared) { + m_cache.insert(mk_pair(l.raw(), r.raw())); + } + return r; + } +public: + instantiate_lmvar_fn(metavar_ctx & mctx):m_mctx(mctx) {} + level visit(level const & l) { + if (!has_mvar(l)) + return l; + bool shared = false; + if (is_shared(l)) { + auto it = m_cache.find(l.raw()); + if (it != m_cache.end()) { + return level(it->second, true); + } + shared = true; + } + switch (l.kind()) { + case level_kind::Succ: + return cache(l, update_succ(l, visit(succ_of(l))), shared); + case level_kind::Max: case level_kind::IMax: + return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared); + case level_kind::Zero: case level_kind::Param: + lean_unreachable(); + case level_kind::MVar: { + option_ref r = get_lmvar_assignment(m_mctx, mvar_id(l)); + if (!r) { + return l; + } else { + level a(r.get_val()); + if (!has_mvar(a)) { + return a; + } else { + level a_new = visit(a); + if (!is_eqp(a, a_new)) { + assign_lmvar(m_mctx, mvar_id(l), a_new); + } + return a_new; + } + } + }} + } + level operator()(level const & l) { return visit(l); } +}; + +extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) { + metavar_ctx mctx(m); + level l_new = instantiate_lmvar_fn(mctx)(level(l)); + object * r = alloc_cnstr(0, 2, 0); + cnstr_set(r, 0, mctx.steal()); + cnstr_set(r, 1, l_new.steal()); + return r; +} + +extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object *, object *) { + lean_internal_panic("not implemented yet"); +} +} diff --git a/src/kernel/level.h b/src/kernel/level.h index 32bfade7c2e0..0739ab6e8fc2 100644 --- a/src/kernel/level.h +++ b/src/kernel/level.h @@ -82,6 +82,8 @@ inline bool operator!=(level const & l1, level const & l2) { return !operator==( struct level_hash { unsigned operator()(level const & n) const { return n.hash(); } }; struct level_eq { bool operator()(level const & n1, level const & n2) const { return n1 == n2; } }; +inline bool is_shared(level const & l) { return !is_exclusive(l.raw()); } + inline optional none_level() { return optional(); } inline optional some_level(level const & e) { return optional(e); } inline optional some_level(level && e) { return optional(std::forward(e)); } diff --git a/src/runtime/object_ref.h b/src/runtime/object_ref.h index c0209676f329..e543384dc307 100644 --- a/src/runtime/object_ref.h +++ b/src/runtime/object_ref.h @@ -35,6 +35,10 @@ class object_ref { s.m_obj = box(0); return *this; } + void set_box(object * o) { + lean_assert(is_scalar(m_obj)); + m_obj = o; + } object * raw() const { return m_obj; } object * steal() { object * r = m_obj; m_obj = box(0); return r; } object * to_obj_arg() const { inc(m_obj); return m_obj; } diff --git a/src/runtime/option_ref.h b/src/runtime/option_ref.h index 18c3f915ec4f..5191bca14e4d 100644 --- a/src/runtime/option_ref.h +++ b/src/runtime/option_ref.h @@ -28,6 +28,8 @@ class option_ref : public object_ref { explicit operator bool() const { return !is_scalar(raw()); } optional get() const { return *this ? some(static_cast(cnstr_get_ref(*this, 0))) : optional(); } + T get_val() const { lean_assert(*this); return static_cast(cnstr_get_ref(*this, 0)); } + /** \brief Structural equality. */ friend bool operator==(option_ref const & o1, option_ref const & o2) { return o1.get() == o2.get();