diff --git a/voyager-androidx/src/main/java/cafe/adriel/voyager/androidx/AndroidScreenLifecycleOwner.kt b/voyager-androidx/src/main/java/cafe/adriel/voyager/androidx/AndroidScreenLifecycleOwner.kt index cb59010f..d0a3502c 100644 --- a/voyager-androidx/src/main/java/cafe/adriel/voyager/androidx/AndroidScreenLifecycleOwner.kt +++ b/voyager-androidx/src/main/java/cafe/adriel/voyager/androidx/AndroidScreenLifecycleOwner.kt @@ -122,35 +122,27 @@ public class AndroidScreenLifecycleOwner private constructor() : override fun getDefaultViewModelProviderFactory(): ViewModelProvider.Factory { return SavedStateViewModelFactory( - (atomicContext.get()?.applicationContext as? Application), - this + application = atomicContext.get()?.applicationContext?.getApplication(), + owner = this ) } override fun registerLifecycleListener(outState: SavedState) { - val context = atomicContext.get() - if (context != null && context is LifecycleOwner) { + val activity = atomicContext.get()?.getActivity() + if (activity != null && activity is LifecycleOwner) { val observer = object : DefaultLifecycleObserver { override fun onStop(owner: LifecycleOwner) { performSave(outState) } } - val lifecycle = context.lifecycle + val lifecycle = activity.lifecycle lifecycle.addObserver(observer) deactivateLifecycleListener = { lifecycle.removeObserver(observer) } } } override fun getDefaultViewModelCreationExtras(): CreationExtras = MutableCreationExtras().apply { - var application: Application? = null - var context = atomicContext.get()?.applicationContext - while (context is ContextWrapper) { - if (context is Application) { - application = context - break - } - context = context.baseContext - } + val application = atomicContext.get()?.applicationContext?.getApplication() if (application != null) { set(AndroidViewModelFactory.APPLICATION_KEY, application) } @@ -162,6 +154,18 @@ public class AndroidScreenLifecycleOwner private constructor() : }*/ } + private tailrec fun Context.getActivity(): Activity? = when (this) { + is Activity -> this + is ContextWrapper -> baseContext.getActivity() + else -> null + } + + private tailrec fun Context.getApplication(): Application? = when (this) { + is Application -> this + is ContextWrapper -> baseContext.getApplication() + else -> null + } + public companion object { private val initEvents = arrayOf(