diff --git a/NEWS.md b/NEWS.md index f974f247e8ded..b502f5bdca35d 100644 --- a/NEWS.md +++ b/NEWS.md @@ -10,6 +10,9 @@ New language features and associativity as other arrow-like operators ([#36666]). * Compilation and type inference can now be enabled or disabled at the module level using the experimental macro `Base.Experimental.@compiler_options` ([#37041]). +* The library name passed to `ccall` or `@ccall` can now be an expression involving + global variables and function calls. The expression will be evaluated the first + time the `ccall` executes ([#36458]). Language changes ---------------- diff --git a/doc/src/manual/calling-c-and-fortran-code.md b/doc/src/manual/calling-c-and-fortran-code.md index ee5e2bc12cec3..628b6d9be44c8 100644 --- a/doc/src/manual/calling-c-and-fortran-code.md +++ b/doc/src/manual/calling-c-and-fortran-code.md @@ -878,7 +878,15 @@ it must be handled in other ways. ## Non-constant Function Specifications -A `(name, library)` function specification must be a constant expression. However, it is possible +In some cases, the exact name or path of the needed library is not known in advance and must +be computed at run time. To handle such cases, the library component of a `(name, library)` +specification can be a function call, e.g. `(:dgemm_, find_blas())`. The call expression will +be executed when the `ccall` itself is executed. However, it is assumed that the library +location does not change once it is determined, so the result of the call can be cached and +reused. Therefore, the number of times the expression executes is undefined, and returning +different values for multiple calls results in undefined behavior. + +If even more flexibility is needed, it is possible to use computed values as function names by staging through [`eval`](@ref) as follows: ``` diff --git a/src/ast.scm b/src/ast.scm index 230e477d1ee1a..e94e56c56de61 100644 --- a/src/ast.scm +++ b/src/ast.scm @@ -355,6 +355,9 @@ (define (ssavalue? e) (and (pair? e) (eq? (car e) 'ssavalue))) +(define (slot? e) + (and (pair? e) (eq? (car e) 'slot))) + (define (globalref? e) (and (pair? e) (eq? (car e) 'globalref))) @@ -439,6 +442,11 @@ (let ((x (cadr e))) (not (simple-atom? x))))) +(define (tuple-call? e) + (and (length> e 1) + (eq? (car e) 'call) + (equal? (cadr e) '(core tuple)))) + (define (eq-sym? a b) (or (eq? a b) (and (ssavalue? a) (ssavalue? b) (eqv? (cdr a) (cdr b))))) diff --git a/src/ccall.cpp b/src/ccall.cpp index fa7aa70564f2a..5a8fc45200eaf 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -73,7 +73,8 @@ static bool runtime_sym_gvs(jl_codegen_params_t &emission_context, const char *f static Value *runtime_sym_lookup( jl_codegen_params_t &emission_context, IRBuilder<> &irbuilder, - PointerType *funcptype, const char *f_lib, + jl_codectx_t *ctx, + PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr, const char *f_name, Function *f, GlobalVariable *libptrgv, GlobalVariable *llvmgv, bool runtime_lib) @@ -106,16 +107,25 @@ static Value *runtime_sym_lookup( assert(f->getParent() != NULL); f->getBasicBlockList().push_back(dlsym_lookup); irbuilder.SetInsertPoint(dlsym_lookup); - Value *libname; - if (runtime_lib) { - libname = stringConstPtr(emission_context, irbuilder, f_lib); + Instruction *llvmf; + Value *nameval = stringConstPtr(emission_context, irbuilder, f_name); + if (lib_expr) { + jl_cgval_t libval = emit_expr(*ctx, lib_expr); + llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jllazydlsym_func), + { boxed(*ctx, libval), nameval }); } else { - // f_lib is actually one of the special sentinel values - libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8); + Value *libname; + if (runtime_lib) { + libname = stringConstPtr(emission_context, irbuilder, f_lib); + } + else { + // f_lib is actually one of the special sentinel values + libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8); + } + llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func), + { libname, nameval, libptrgv }); } - Value *llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func), - { libname, stringConstPtr(emission_context, irbuilder, f_name), libptrgv }); StoreInst *store = irbuilder.CreateAlignedStore(llvmf, llvmgv, Align(sizeof(void*))); store->setAtomic(AtomicOrdering::Release); irbuilder.CreateBr(ccall_bb); @@ -124,21 +134,49 @@ static Value *runtime_sym_lookup( irbuilder.SetInsertPoint(ccall_bb); PHINode *p = irbuilder.CreatePHI(T_pvoidfunc, 2); p->addIncoming(llvmf_orig, enter_bb); - p->addIncoming(llvmf, dlsym_lookup); + p->addIncoming(llvmf, llvmf->getParent()); return irbuilder.CreateBitCast(p, funcptype); } static Value *runtime_sym_lookup( jl_codectx_t &ctx, - PointerType *funcptype, const char *f_lib, + PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr, + const char *f_name, Function *f, + GlobalVariable *libptrgv, + GlobalVariable *llvmgv, bool runtime_lib) +{ + return runtime_sym_lookup(ctx.emission_context, ctx.builder, &ctx, funcptype, f_lib, lib_expr, + f_name, f, libptrgv, llvmgv, runtime_lib); +} + +static Value *runtime_sym_lookup( + jl_codectx_t &ctx, + PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr, const char *f_name, Function *f) { GlobalVariable *libptrgv; GlobalVariable *llvmgv; - bool runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv); - libptrgv = prepare_global_in(jl_Module, libptrgv); + bool runtime_lib; + if (lib_expr) { + // for computed library names, generate a global variable to cache the function + // pointer just for this call site. + runtime_lib = true; + libptrgv = NULL; + std::string gvname = "libname_"; + gvname += f_name; + gvname += "_"; + gvname += std::to_string(globalUnique++); + Module *M = ctx.emission_context.shared_module(jl_LLVMContext); + llvmgv = new GlobalVariable(*M, T_pvoidfunc, false, + GlobalVariable::ExternalLinkage, + Constant::getNullValue(T_pvoidfunc), gvname); + } + else { + runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv); + libptrgv = prepare_global_in(jl_Module, libptrgv); + } llvmgv = prepare_global_in(jl_Module, llvmgv); - return runtime_sym_lookup(ctx.emission_context, ctx.builder, funcptype, f_lib, f_name, f, libptrgv, llvmgv, runtime_lib); + return runtime_sym_lookup(ctx, funcptype, f_lib, lib_expr, f_name, f, libptrgv, llvmgv, runtime_lib); } // Emit a "PLT" entry that will be lazily initialized @@ -169,7 +207,7 @@ static GlobalVariable *emit_plt_thunk( fname); BasicBlock *b0 = BasicBlock::Create(jl_LLVMContext, "top", plt); IRBuilder<> irbuilder(b0); - Value *ptr = runtime_sym_lookup(emission_context, irbuilder, funcptype, f_lib, f_name, plt, libptrgv, + Value *ptr = runtime_sym_lookup(emission_context, irbuilder, NULL, funcptype, f_lib, NULL, f_name, plt, libptrgv, llvmgv, runtime_lib); StoreInst *store = irbuilder.CreateAlignedStore(irbuilder.CreateBitCast(ptr, T_pvoidfunc), got, Align(sizeof(void*))); store->setAtomic(AtomicOrdering::Release); @@ -475,6 +513,7 @@ typedef struct { void (*fptr)(void); // if the argument is a constant pointer const char *f_name; // if the symbol name is known const char *f_lib; // if a library name is specified + jl_value_t *lib_expr; // expression to compute library path lazily jl_value_t *gcroot; } native_sym_arg_t; @@ -488,6 +527,24 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va jl_value_t *ptr = static_eval(ctx, arg); if (ptr == NULL) { + if (jl_is_expr(arg) && ((jl_expr_t*)arg)->head == call_sym && jl_expr_nargs(arg) == 3 && + jl_is_globalref(jl_exprarg(arg,0)) && jl_globalref_mod(jl_exprarg(arg,0)) == jl_core_module && + jl_globalref_name(jl_exprarg(arg,0)) == jl_symbol("tuple")) { + // attempt to interpret a non-constant 2-tuple expression as (func_name, lib_name()), where + // `lib_name()` will be executed when first used. + jl_value_t *name_val = static_eval(ctx, jl_exprarg(arg,1)); + if (name_val && jl_is_symbol(name_val)) { + f_name = jl_symbol_name((jl_sym_t*)name_val); + out.lib_expr = jl_exprarg(arg, 2); + return; + } + else if (name_val && jl_is_string(name_val)) { + f_name = jl_string_data(name_val); + out.gcroot = name_val; + out.lib_expr = jl_exprarg(arg, 2); + return; + } + } jl_cgval_t arg1 = emit_expr(ctx, arg); jl_value_t *ptr_ty = arg1.typ; if (!jl_is_cpointer_type(ptr_ty)) { @@ -586,8 +643,11 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg jl_printf(JL_STDERR,"WARNING: literal address used in cglobal for %s; code cannot be statically compiled\n", sym.f_name); } else { - if (imaging_mode) { - res = runtime_sym_lookup(ctx, cast(T_pint8), sym.f_lib, sym.f_name, ctx.f); + if (sym.lib_expr) { + res = runtime_sym_lookup(ctx, cast(T_pint8), NULL, sym.lib_expr, sym.f_name, ctx.f); + } + else if (imaging_mode) { + res = runtime_sym_lookup(ctx, cast(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f); res = ctx.builder.CreatePtrToInt(res, lrt); } else { @@ -597,7 +657,7 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg if (!libsym || !jl_dlsym(libsym, sym.f_name, &symaddr, 0)) { // Error mode, either the library or the symbol couldn't be find during compiletime. // Fallback to a runtime symbol lookup. - res = runtime_sym_lookup(ctx, cast(T_pint8), sym.f_lib, sym.f_name, ctx.f); + res = runtime_sym_lookup(ctx, cast(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f); res = ctx.builder.CreatePtrToInt(res, lrt); } else { // since we aren't saving this code, there's no sense in @@ -1737,11 +1797,14 @@ jl_cgval_t function_sig_t::emit_a_ccall( else { assert(symarg.f_name != NULL); PointerType *funcptype = PointerType::get(functype, 0); - if (imaging_mode) { + if (symarg.lib_expr) { + llvmf = runtime_sym_lookup(ctx, funcptype, NULL, symarg.lib_expr, symarg.f_name, ctx.f); + } + else if (imaging_mode) { // vararg requires musttail, // but musttail is incompatible with noreturn. if (functype->isVarArg()) - llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f); + llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f); else llvmf = emit_plt(ctx, functype, attributes, cc, symarg.f_lib, symarg.f_name); } @@ -1751,7 +1814,7 @@ jl_cgval_t function_sig_t::emit_a_ccall( if (!libsym || !jl_dlsym(libsym, symarg.f_name, &symaddr, 0)) { // either the library or the symbol could not be found, place a runtime // lookup here instead. - llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f); + llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f); } else { // since we aren't saving this code, there's no sense in // putting anything complicated here: just JIT the function address diff --git a/src/codegen.cpp b/src/codegen.cpp index 6a4211a39f9e0..1fe9827d762ae 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -704,6 +704,12 @@ static const auto jldlsym_func = new JuliaFunction{ {T_pint8, T_pint8, PointerType::get(T_pint8, 0)}, false); }, nullptr, }; +static const auto jllazydlsym_func = new JuliaFunction{ + "jl_lazy_load_and_lookup", + [](LLVMContext &C) { return FunctionType::get(T_pvoidfunc, + {T_prjlvalue, T_pint8}, false); }, + nullptr, +}; static const auto jltypeassert_func = new JuliaFunction{ "jl_typeassert", [](LLVMContext &C) { return FunctionType::get(T_void, diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index b0f95053da807..a1ce6486d58c3 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -3906,11 +3906,9 @@ f(x) = yt(x) (cond ((eq? (car e) 'foreigncall) ;; NOTE: 2nd to 5th arguments of ccall must be left in place ;; the 1st should be compiled if an atom. - (append (if (or (atom? (cadr e)) - (let ((fptr (cadr e))) - (not (and (length> fptr 1) - (eq? (car fptr) 'call) - (equal? (cadr fptr) '(core tuple)))))) + (append (if (let ((fptr (cadr e))) + (or (atom? fptr) + (not (tuple-call? fptr)))) (compile-args (list (cadr e)) break-labels) (list (cadr e))) (list-head (cddr e) 4) @@ -4466,8 +4464,14 @@ f(x) = yt(x) `(gotoifnot ,(renumber-stuff (cadr e)) ,(get label-table (caddr e)))) ((eq? (car e) 'lambda) (renumber-lambda e 'none 0)) - (else (cons (car e) - (map renumber-stuff (cdr e)))))) + (else + (let ((e (cons (car e) + (map renumber-stuff (cdr e))))) + (if (and (eq? (car e) 'foreigncall) + (tuple-call? (cadr e)) + (expr-contains-p (lambda (x) (or (ssavalue? x) (slot? x))) (cadr e))) + (error "ccall function name and library expression cannot reference local variables")) + e)))) (let ((body (renumber-stuff (lam:body lam))) (vi (lam:vinfo lam))) (listify-lambda diff --git a/src/julia_internal.h b/src/julia_internal.h index ea97c67525295..066f05c8faddb 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -988,6 +988,7 @@ void *jl_get_library_(const char *f_lib, int throw_err) JL_NOTSAFEPOINT; #define jl_get_library(f_lib) jl_get_library_(f_lib, 1) JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name, void **hnd) JL_NOTSAFEPOINT; +JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name); JL_DLLEXPORT jl_value_t *jl_get_cfunction_trampoline( jl_value_t *fobj, jl_datatype_t *result, htable_t *cache, jl_svec_t *fill, void *(*init_trampoline)(void *tramp, void **nval), diff --git a/src/runtime_ccall.cpp b/src/runtime_ccall.cpp index 234084fbcdc0e..b9b3ed4dea41a 100644 --- a/src/runtime_ccall.cpp +++ b/src/runtime_ccall.cpp @@ -64,6 +64,23 @@ void *jl_load_and_lookup(const char *f_lib, const char *f_name, void **hnd) JL_N return ptr; } +// jl_load_and_lookup, but with library computed at run time on first call +extern "C" JL_DLLEXPORT +void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name) +{ + char *f_lib; + + if (jl_is_symbol(lib_val)) + f_lib = jl_symbol_name((jl_sym_t*)lib_val); + else if (jl_is_string(lib_val)) + f_lib = jl_string_data(lib_val); + else + jl_type_error("ccall", (jl_value_t*)jl_symbol_type, lib_val); + void *ptr; + jl_dlsym(jl_get_library(f_lib), f_name, &ptr, 1); + return ptr; +} + // miscellany std::string jl_get_cpu_name_llvm(void) { diff --git a/test/ccall.jl b/test/ccall.jl index 8108e8e8fb4f3..424fe80368855 100644 --- a/test/ccall.jl +++ b/test/ccall.jl @@ -1701,3 +1701,11 @@ end str = GC.@preserve buffer unsafe_string(Cwstring(pointer(buffer))) @test str == "α+β=15" end + +# issue #36458 +compute_lib_name() = "libcc" * "alltest" +ccall_lazy_lib_name(x) = ccall((:testUcharX, compute_lib_name()), Int32, (UInt8,), x % UInt8) +@test ccall_lazy_lib_name(0) == 0 +@test ccall_lazy_lib_name(3) == 1 +ccall_with_undefined_lib() = ccall((:time, xx_nOt_DeFiNeD_xx), Cint, (Ptr{Cvoid},), C_NULL) +@test_throws UndefVarError(:xx_nOt_DeFiNeD_xx) ccall_with_undefined_lib() diff --git a/test/syntax.jl b/test/syntax.jl index 76223f2a3591e..451d83a292ef3 100644 --- a/test/syntax.jl +++ b/test/syntax.jl @@ -1656,6 +1656,8 @@ end # #6080 @test Meta.lower(@__MODULE__, :(ccall(:a, Cvoid, (Cint,), &x))) == Expr(:error, "invalid syntax &x") +@test Meta.lower(@__MODULE__, :(f(x) = (y = x + 1; ccall((:a, y), Cvoid, ())))) == Expr(:error, "ccall function name and library expression cannot reference local variables") + @test_throws ParseError Meta.parse("x.'") @test_throws ParseError Meta.parse("0.+1")