diff --git a/eland/ml/pytorch/nlp_ml_model.py b/eland/ml/pytorch/nlp_ml_model.py
index ff80b729..c485cbd8 100644
--- a/eland/ml/pytorch/nlp_ml_model.py
+++ b/eland/ml/pytorch/nlp_ml_model.py
@@ -218,6 +218,20 @@ def __init__(
         self.num_top_classes = num_top_classes
 
 
+class TextSimilarityInferenceOptions(InferenceConfig):
+    def __init__(
+        self,
+        *,
+        tokenization: NlpTokenizationConfig,
+        results_field: t.Optional[str] = None,
+        text: t.Optional[str] = None,
+    ):
+        super().__init__(configuration_type="text_similarity")
+        self.tokenization = tokenization
+        self.results_field = results_field
+        self.text = text
+
+
 class TextEmbeddingInferenceOptions(InferenceConfig):
     def __init__(
         self,
diff --git a/eland/ml/pytorch/transformers.py b/eland/ml/pytorch/transformers.py
index ad900766..4b662c03 100644
--- a/eland/ml/pytorch/transformers.py
+++ b/eland/ml/pytorch/transformers.py
@@ -51,6 +51,7 @@
     QuestionAnsweringInferenceOptions,
     TextClassificationInferenceOptions,
     TextEmbeddingInferenceOptions,
+    TextSimilarityInferenceOptions,
     TrainedModelInput,
     ZeroShotClassificationInferenceOptions,
 )
@@ -64,11 +65,16 @@
     "text_embedding",
     "zero_shot_classification",
     "question_answering",
+    "text_similarity",
 }
 ARCHITECTURE_TO_TASK_TYPE = {
     "MaskedLM": ["fill_mask", "text_embedding"],
     "TokenClassification": ["ner"],
-    "SequenceClassification": ["text_classification", "zero_shot_classification"],
+    "SequenceClassification": [
+        "text_classification",
+        "zero_shot_classification",
+        "text_similarity",
+    ],
     "QuestionAnswering": ["question_answering"],
     "DPRQuestionEncoder": ["text_embedding"],
     "DPRContextEncoder": ["text_embedding"],
@@ -82,6 +88,7 @@
     "zero_shot_classification": ZeroShotClassificationInferenceOptions,
     "pass_through": PassThroughInferenceOptions,
     "question_answering": QuestionAnsweringInferenceOptions,
+    "text_similarity": TextSimilarityInferenceOptions,
 }
 SUPPORTED_TASK_TYPES_NAMES = ", ".join(sorted(SUPPORTED_TASK_TYPES))
 SUPPORTED_TOKENIZERS = (
@@ -124,6 +131,12 @@ def task_type_from_model_config(model_config: PretrainedConfig) -> Optional[str]
                     potential_task_types.add(t)
     if len(potential_task_types) == 0:
         return None
+    if (
+        "text_classification" in potential_task_types
+        and model_config.id2label
+        and len(model_config.id2label) == 1
+    ):
+        return "text_similarity"
     if len(potential_task_types) > 1:
         if "zero_shot_classification" in potential_task_types:
             if model_config.label2id:
@@ -529,6 +542,16 @@ def _prepare_inputs(self) -> transformers.BatchEncoding:
         )
 
 
+class _TraceableTextSimilarityModel(_TransformerTraceableModel):
+    def _prepare_inputs(self) -> transformers.BatchEncoding:
+        return self._tokenizer(
+            "What is the meaning of life?"
+            "The meaning of life, according to the hitchikers guide, is 42.",
+            padding="max_length",
+            return_tensors="pt",
+        )
+
+
 class TransformerModel:
     def __init__(self, model_id: str, task_type: str, quantize: bool = False):
         self._model_id = model_id
@@ -674,6 +697,12 @@ def _create_traceable_model(self) -> TraceableModel:
         elif self._task_type == "question_answering":
             model = _QuestionAnsweringWrapperModule.from_pretrained(self._model_id)
             return _TraceableQuestionAnsweringModel(self._tokenizer, model)
+        elif self._task_type == "text_similarity":
+            model = transformers.AutoModelForSequenceClassification.from_pretrained(
+                self._model_id, torchscript=True
+            )
+            model = _DistilBertWrapper.try_wrapping(model)
+            return _TraceableTextSimilarityModel(self._tokenizer, model)
         else:
             raise TypeError(
                 f"Unknown task type {self._task_type}, must be one of: {SUPPORTED_TASK_TYPES_NAMES}"