Fix direct usage of TransformerModel #619
Changes from all commits
```diff
@@ -112,6 +112,8 @@ def test(session, pandas_version: str):
         "python",
         "-m",
         "pytest",
+        "-ra",
+        "--tb=native",
         "--cov-report=term-missing",
         "--cov=eland/",
         "--cov-config=setup.cfg",
```

Review comment on `--tb=native`: This shows the native Python traceback instead of the overly verbose pytest tracebacks that display the source code (I don't need it, I can see it in my editor).
```diff
@@ -22,7 +22,7 @@ lightgbm>=2,<4
 torch>=1.13.1,<2.0; python_version<'3.11'
 # Versions known to be compatible with PyTorch 1.13.1
 sentence-transformers>=2.1.0,<=2.2.2; python_version<'3.11'
-transformers[torch]>=4.12.0,<=4.27.4; python_version<'3.11'
+transformers[torch]>=4.31.0,<=4.33.2; python_version<'3.11'

 #
 # Testing
```

Review comment on the `transformers` pin: A newer version of transformers is needed to get access token support. This was already upgraded in setup.py, but not here; I'll have to avoid the duplication in the future.
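For context, a minimal sketch of the access-token usage the newer range supports (the model id and token below are hypothetical placeholders; the pinned 4.31–4.33 releases accept `use_auth_token`, with later releases preferring `token`):

```python
# Sketch only: loading a gated/private Hugging Face model with an access
# token, which requires the transformers range pinned above.
from transformers import AutoModel, AutoTokenizer

model_id = "my-org/private-model"  # hypothetical private model id
hf_token = "hf_..."  # placeholder Hugging Face access token

tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hf_token)
model = AutoModel.from_pretrained(model_id, use_auth_token=hf_token)
```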
```diff
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import platform
 import tempfile

 import pytest
```

```diff
@@ -82,6 +83,14 @@ def setup_and_tear_down():
     pass


+@pytest.fixture(scope="session")
+def quantize():
+    # quantization does not work on ARM processors
+    # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
+    # revisit this when we upgrade to PyTorch 2.0.
+    return platform.machine() not in ["arm64", "aarch64"]
+
+
 def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):
     print("Loading HuggingFace transformer tokenizer and model")
     tm = TransformerModel(
```
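For background on why the fixture gates on the platform, here is a hedged sketch of the dynamic quantization that a `quantize` flag typically maps to in PyTorch (an illustration, not eland's implementation; `maybe_quantize` is a hypothetical helper). PyTorch's default x86 quantization backend (fbgemm) is unavailable on ARM, so quantization is skipped there:

```python
# Sketch: dynamic int8 quantization guarded by the same platform check
# the fixture uses. Assumed code, not taken from eland.
import platform

import torch


def maybe_quantize(model: torch.nn.Module) -> torch.nn.Module:
    if platform.machine() in ("arm64", "aarch64"):
        # No supported quantization engine on these CPUs (see fixture above).
        return model
    # Replace Linear layers with dynamically quantized int8 equivalents.
    return torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
```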
```diff
@@ -103,31 +112,17 @@ def download_model_and_start_deployment(tmp_dir, quantize, model_id, task):


 class TestPytorchModel:
-    def __init__(self):
-        # quantization does not work on ARM processors
-        # TODO: It seems that PyTorch 2.0 supports OneDNN for aarch64. We should
-        # revisit this when we upgrade to PyTorch 2.0.
-        import platform
-
-        self.quantize = (
-            True if platform.machine() not in ["arm64", "aarch64"] else False
-        )
-
     @pytest.mark.parametrize("model_id,task,text_input,value", TEXT_PREDICTION_MODELS)
-    def test_text_prediction(self, model_id, task, text_input, value):
+    def test_text_prediction(self, model_id, task, text_input, value, quantize):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            ptm = download_model_and_start_deployment(
-                tmp_dir, self.quantize, model_id, task
-            )
-            result = ptm.infer(docs=[{"text_field": text_input}])
-            assert result["predicted_value"] == value
+            ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
+            results = ptm.infer(docs=[{"text_field": text_input}])
+            assert results.body["inference_results"][0]["predicted_value"] == value

     @pytest.mark.parametrize("model_id,task,text_input", TEXT_EMBEDDING_MODELS)
-    def test_text_embedding(self, model_id, task, text_input):
+    def test_text_embedding(self, model_id, task, text_input, quantize):
         with tempfile.TemporaryDirectory() as tmp_dir:
-            ptm = download_model_and_start_deployment(
-                tmp_dir, self.quantize, model_id, task
-            )
+            ptm = download_model_and_start_deployment(tmp_dir, quantize, model_id, task)
             ptm.infer(docs=[{"text_field": text_input}])

     if ES_VERSION >= (8, 8, 0):
```

Review comment on the removed `__init__`: pytest will not run any test classes when an `__init__` method is defined, which is why the quantization check moves into the session-scoped `quantize` fixture above. The example after this comment illustrates the pitfall.
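A minimal, self-contained sketch of that collection behavior (class and fixture names here are hypothetical): pytest refuses to collect a test class that defines `__init__` and only emits a `PytestCollectionWarning`, so per-class setup belongs in a fixture instead:

```python
import pytest


class TestWithInit:
    # pytest will not collect this class ("cannot collect test class ...
    # because it has a __init__ constructor"), so its test never runs.
    def __init__(self):
        self.quantize = True

    def test_never_runs(self):
        assert self.quantize


@pytest.fixture(scope="session")
def quantize():
    return True


class TestWithFixture:
    # No __init__, so pytest collects the class; shared state arrives
    # through the fixture argument instead.
    def test_runs(self, quantize):
        assert quantize
```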
Review comment on `-ra`: This reports all test outcomes except passes, which is useful for seeing why a specific test was skipped.
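For reference, a hedged sketch of running the suite with the same reporting flags outside of nox (the `tests/` path is illustrative):

```python
# Sketch: the same pytest reporting flags, invoked programmatically.
import pytest

# -ra          short summary line for every outcome except passes
#              (skips, xfails, failures, errors), including the reason
# --tb=native  plain Python tracebacks instead of pytest's long,
#              source-quoting format
exit_code = pytest.main(["-ra", "--tb=native", "tests/"])
```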