Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warn about reloading dependencies after downloading models #13081

Merged
merged 10 commits into from
Nov 10, 2023
30 changes: 29 additions & 1 deletion spacy/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@

from .. import about
from ..errors import OLD_MODEL_SHORTCUTS
from ..util import get_minor_version, is_package, is_prerelease_version, run_command
from ..util import (
get_minor_version,
is_in_interactive,
is_in_jupyter,
is_package,
is_prerelease_version,
run_command,
)
from ._util import SDIST_SUFFIX, WHEEL_SUFFIX, Arg, Opt, app


Expand Down Expand Up @@ -77,6 +84,27 @@ def download(
"Download and installation successful",
f"You can now load the package via spacy.load('{model_name}')",
)
if is_in_jupyter():
reload_deps_msg = (
"If you are in a Jupyter or Colab notebook, you may need to "
"restart Python in order to load all the package's dependencies. "
"You can do this by selecting the 'Restart kernel' or 'Restart "
"runtime' option."
)
msg.warn(
"Restart to reload dependencies",
reload_deps_msg,
)
elif is_in_interactive():
reload_deps_msg = (
"If you are in an interactive Python session, you may need to "
"exit and restart Python to load all the package's dependencies. "
"You can exit with Ctrl-D (or Ctrl-Z and Enter on Windows)."
)
msg.warn(
"Restart to reload dependencies",
reload_deps_msg,
)


def get_model_filename(model_name: str, version: str, sdist: bool = False) -> str:
Expand Down
1 change: 0 additions & 1 deletion spacy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,6 @@ class Errors(metaclass=ErrorsWithCodes):
E002 = ("Can't find factory for '{name}' for language {lang} ({lang_code}). "
"This usually happens when spaCy calls `nlp.{method}` with a custom "
"component name that's not registered on the current language class. "
"If you're using a Transformer, make sure to install 'spacy-transformers'. "
"If you're using a custom component, make sure you've added the "
"decorator `@Language.component` (for function components) or "
"`@Language.factory` (for class components).\n\nAvailable "
Expand Down
30 changes: 24 additions & 6 deletions spacy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,20 +1077,38 @@ def force_remove(rmfunc, path, ex):


def is_in_jupyter() -> bool:
"""Check if user is running spaCy from a Jupyter notebook by detecting the
IPython kernel. Mainly used for the displaCy visualizer.
RETURNS (bool): True if in Jupyter, False if not.
"""Check if user is running spaCy from a Jupyter or Colab notebook by
detecting the IPython kernel. Mainly used for the displaCy visualizer.
RETURNS (bool): True if in Jupyter/Colab, False if not.
"""
# https://stackoverflow.com/a/39662359/6400719
# https://stackoverflow.com/questions/15411967
try:
shell = get_ipython().__class__.__name__ # type: ignore[name-defined]
if shell == "ZMQInteractiveShell":
if get_ipython().__class__.__name__ == "ZMQInteractiveShell": # type: ignore[name-defined]
return True # Jupyter notebook or qtconsole
if get_ipython().__class__.__module__ == "google.colab._shell": # type: ignore[name-defined]
return True # Colab notebook
except NameError:
return False # Probably standard Python interpreter
pass # Probably standard Python interpreter
# additional check for Colab
try:
import google.colab

return True # Colab notebook
except ImportError:
pass
return False


def is_in_interactive() -> bool:
"""Check if user is running spaCy from an interactive Python
shell. Will return True in Jupyter notebooks too.
RETURNS (bool): True if in interactive mode, False if not.
"""
# https://stackoverflow.com/questions/2356399/tell-if-python-is-in-interactive-mode
return hasattr(sys, "ps1") or hasattr(sys, "ps2")


def get_object_name(obj: Any) -> str:
"""Get a human-readable name of a Python object, e.g. a pipeline component.

Expand Down