Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix direct usage of TransformerModel #619

Merged
merged 2 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eland/ml/pytorch/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,10 +584,10 @@ def __init__(
self,
*,
model_id: str,
access_token: Optional[str],
task_type: str,
es_version: Optional[Tuple[int, int, int]] = None,
quantize: bool = False,
access_token: Optional[str] = None,
):
"""
Loads a model from the Hugging Face repository or local file and creates
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def test(session, pandas_version: str):
"python",
"-m",
"pytest",
"-ra",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reports all tests except passes, which is useful to see why a specific test would have been skipped.

"--tb=native",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shows the Python native traceback, instead of the overly verbose pytest tracebacks that display the source code (I don't need it, I can see it in my editor).

"--cov-report=term-missing",
"--cov=eland/",
"--cov-config=setup.cfg",
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ lightgbm>=2,<4
torch>=1.13.1,<2.0; python_version<'3.11'
# Versions known to be compatible with PyTorch 1.13.1
sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
transformers[torch]>=4.12.0,<=4.27.4; python_version<'3.11'
transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A newer version of transformers is needed to get access token support. This was already upgraded in setup.py, but not here. I'll have to avoid the duplication in the future.


#
# Testing
Expand Down
35 changes: 15 additions & 20 deletions tests/ml/pytorch/test_pytorch_model_upload_pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import platform
import tempfile

import pytest
Expand Down Expand Up @@ -82,6 +83,14 @@ def setup_and_tear_down():
pass


@pytest.fixture(scope="session")
def quantize():
    """Session-wide flag: quantize models only on non-ARM machines.

    Quantization does not work on ARM processors.
    TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
    revisit this when we upgrade to PyTorch 2.0.
    """
    arm_machines = ("arm64", "aarch64")
    return platform.machine() not in arm_machines


def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
print("Loading HuggingFace transformer tokenizer and model")
tm = TransformerModel(
Expand All @@ -103,31 +112,17 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):


class TestPytorchModel:
def __init__(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest will not run any test classes when an __init__ method is defined. So we had a test but it was not run, and pytest only warns about it. I've switched to a pytest fixture instead.

# quantization does not work on ARM processors
# TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
# revisit this when we upgrade to PyTorch 2.0.
import platform

self.quantize = (
True if platform.machine() not in ["arm64", "aarch64"] else False
)

@pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
def test_text_prediction(self, model_id, task, text_input, value):
def test_text_prediction(self, model_id, task, text_input, value, quantize):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(
tmp_dir, self.quantize, model_id, task
)
result = ptm.infer(docs=[{"text_field": text_input}])
assert result["predicted_value"] == value
ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
results = ptm.infer(docs=[{"text_field": text_input}])
assert results.body["inference_results"][0]["predicted_value"] == value

@pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
def test_text_embedding(self, model_id, task, text_input):
def test_text_embedding(self, model_id, task, text_input, quantize):
with tempfile.TemporaryDirectory() as tmp_dir:
ptm = download_model_and_start_deployment(
tmp_dir, self.quantize, model_id, task
)
ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
ptm.infer(docs=[{"text_field": text_input}])

if ES_VERSION >= (8, 8, 0):
Expand Down