diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h
index 7aac0e874..885dddd8e 100644
--- a/cpp/include/wholememory/wholememory.h
+++ b/cpp/include/wholememory/wholememory.h
@@ -83,9 +83,10 @@ enum wholememory_distributed_backend_t {
 /**
  * Initialize WholeMemory library
  * @param flags : reserved should be 0
+ * @param wm_log_level : wholememory log level, 1=error, 2=warn, 3=info (default), 4=debug, 5=trace
  * @return : wholememory_error_code_t
  */
-wholememory_error_code_t wholememory_init(unsigned int flags);
+wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level = 3);
 
 /**
  * Finalize WholeMemory library
diff --git a/cpp/src/wholememory/initialize.cpp b/cpp/src/wholememory/initialize.cpp
index 2e80ab3c3..b7d1e54ac 100644
--- a/cpp/src/wholememory/initialize.cpp
+++ b/cpp/src/wholememory/initialize.cpp
@@ -17,6 +17,7 @@
 #include
 #include
+#include
 #include
 
 #include "communicator.hpp"
@@ -32,7 +33,7 @@ static bool is_wm_init = false;
 static const std::string RAFT_NAME = "wholememory";
 static cudaDeviceProp* device_props = nullptr;
 
-wholememory_error_code_t init(unsigned int flags) noexcept
+wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept
 {
   try {
     std::unique_lock lock(mu);
@@ -50,6 +51,7 @@ wholememory_error_code_t init(unsigned int flags) noexcept
       WM_CUDA_CHECK(cudaGetDeviceProperties(device_props + i, i));
     }
     is_wm_init = true;
+    wholememory::set_log_level(std::pow(10, wm_log_level));
     return WHOLEMEMORY_SUCCESS;
   } catch (raft::logic_error& logic_error) {
     WHOLEMEMORY_ERROR("init failed, logic_error=%s", logic_error.what());
diff --git a/cpp/src/wholememory/initialize.hpp b/cpp/src/wholememory/initialize.hpp
index 2b9d0366b..77870f989 100644
--- a/cpp/src/wholememory/initialize.hpp
+++ b/cpp/src/wholememory/initialize.hpp
@@ -21,7 +21,7 @@
 
 namespace wholememory {
 
-wholememory_error_code_t init(unsigned int flags) noexcept;
+wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept;
 
 wholememory_error_code_t finalize() noexcept;
 
diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp
index dbdce12e6..2f5f33a36 100644
--- a/cpp/src/wholememory/wholememory.cpp
+++ b/cpp/src/wholememory/wholememory.cpp
@@ -25,7 +25,10 @@
 extern "C" {
 #endif
 
-wholememory_error_code_t wholememory_init(unsigned int flags) { return wholememory::init(flags); }
+wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level)
+{
+  return wholememory::init(flags, wm_log_level);
+}
 
 wholememory_error_code_t wholememory_finalize() { return wholememory::finalize(); }
 
diff --git a/python/pylibwholegraph/examples/node_classfication.py b/python/pylibwholegraph/examples/node_classfication.py
index 27b035fb9..fb77ffb88 100644
--- a/python/pylibwholegraph/examples/node_classfication.py
+++ b/python/pylibwholegraph/examples/node_classfication.py
@@ -130,7 +130,8 @@ def main_func():
         wgth.get_world_size(),
         wgth.get_local_rank(),
         wgth.get_local_size(),
-        args.distributed_backend_type
+        args.distributed_backend_type,
+        args.log_level
     )
 
     if args.use_cpp_ext:
diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
index 97d84c228..263fbd62f 100644
--- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
+++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
@@ -71,7 +71,7 @@ cdef extern from "wholememory/wholememory.h":
WHOLEMEMORY_DB_NONE "WHOLEMEMORY_DB_NONE" WHOLEMEMORY_DB_NCCL "WHOLEMEMORY_DB_NCCL" WHOLEMEMORY_DB_NVSHMEM "WHOLEMEMORY_DB_NVSHMEM" - cdef wholememory_error_code_t wholememory_init(unsigned int flags) + cdef wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level) cdef wholememory_error_code_t wholememory_finalize() @@ -981,8 +981,8 @@ cdef class PyWholeMemoryUniqueID: def __dlpack_device__(self): return (kDLCPU, 0) -def init(unsigned int flags): - check_wholememory_error_code(wholememory_init(flags)) +def init(unsigned int flags, unsigned int wm_log_level = 3): + check_wholememory_error_code(wholememory_init(flags, wm_log_level)) def finalize(): check_wholememory_error_code(wholememory_finalize()) diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index 42746add8..14955305b 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -9,7 +9,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License.ß from argparse import ArgumentParser @@ -68,6 +68,12 @@ def add_training_options(argparser: ArgumentParser): default="nccl", help="distributed backend type, should be: nccl, nvshmem ", ) + argparser.add_argument( + "--log-level", + dest="log_level", + default="info", + help="Logging level of wholegraph, should be: trace, debug, info, warn, error" + ) def add_common_graph_options(argparser: ArgumentParser): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 3259a0e82..94ee74261 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -18,12 +18,13 @@ from .comm import set_world_info, get_global_communicator, get_local_node_communicator, reset_communicators -def init(world_rank: int, world_size: int, local_rank: int, local_size: int): - wmb.init(0) +def init(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level="info"): + log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5} + wmb.init(0, log_level_dic[wm_log_level]) set_world_info(world_rank, world_size, local_rank, local_size) -def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int): +def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int, wm_log_level): r"""Init WholeGraph environment for PyTorch. 
     :param world_rank: world rank of current process
     :param world_size: world size of all processes
@@ -44,7 +45,8 @@ def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size
         print("[WARNING] MASTER_PORT not set, resetting to 12335")
         os.environ["MASTER_PORT"] = "12335"
 
-    wmb.init(0)
+    log_level_dic = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5}
+    wmb.init(0, log_level_dic[wm_log_level])
     torch.set_num_threads(1)
     torch.cuda.set_device(local_rank)
     torch.distributed.init_process_group(backend="nccl", init_method="env://")
@@ -52,7 +54,12 @@
 
 
 def init_torch_env_and_create_wm_comm(
-    world_rank: int, world_size: int, local_rank: int, local_size: int , distributed_backend_type="nccl"
+    world_rank: int,
+    world_size: int,
+    local_rank: int,
+    local_size: int,
+    distributed_backend_type="nccl",
+    wm_log_level="info"
 ):
     r"""Init WholeGraph environment for PyTorch and create single communicator for all ranks.
     :param world_rank: world rank of current process
@@ -61,7 +68,7 @@ def init_torch_env_and_create_wm_comm(
     :param local_size: local size
     :return: global and local node Communicator
     """
-    init_torch_env(world_rank, world_size, local_rank, local_size)
+    init_torch_env(world_rank, world_size, local_rank, local_size, wm_log_level)
     global_comm = get_global_communicator(distributed_backend_type)
 
     local_comm = get_local_node_communicator()
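
At the binding level, the string level name is translated to the small integer that the C API now accepts. The sketch below only restates that path under the assumption of a `pylibwholegraph` build that includes this patch; the `LOG_LEVELS` name and `init_with_level` helper are illustrative and not part of the diff.

```python
import pylibwholegraph.binding.wholememory_binding as wmb

# Same mapping this patch adds in torch/initialize.py.
LOG_LEVELS = {"error": 1, "warn": 2, "info": 3, "debug": 4, "trace": 5}


def init_with_level(level_name: str = "info") -> None:
    # wmb.init(flags, wm_log_level); wm_log_level defaults to 3 ("info").
    # On the C++ side, wholememory::init() forwards the value via
    # wholememory::set_log_level(std::pow(10, wm_log_level)).
    wmb.init(0, LOG_LEVELS[level_name])
```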
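
End to end, a training script would normally go through the higher-level torch helper instead. A minimal sketch, assuming `pylibwholegraph.torch` re-exports `init_torch_env_and_create_wm_comm` as the bundled example uses it, and that ranks come from torchrun-style environment variables; the `bootstrap` helper itself is hypothetical.

```python
import os

import pylibwholegraph.torch as wgth


def bootstrap(log_level: str = "info"):
    # Rank/size values as set by a torchrun-style launcher.
    world_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

    # wm_log_level is the keyword this patch adds; accepted values are
    # "error", "warn", "info" (default), "debug" and "trace".
    global_comm, local_comm = wgth.init_torch_env_and_create_wm_comm(
        world_rank,
        world_size,
        local_rank,
        local_size,
        distributed_backend_type="nccl",
        wm_log_level=log_level,
    )
    return global_comm, local_comm
```

With the `--log-level` option registered in `add_training_options`, the example script simply forwards `args.log_level` into this call, as the `node_classfication.py` hunk above shows.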