From 8c4fb1922c7f35ce6115e762213577ad637a00d1 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Tue, 6 Jun 2023 16:06:26 +0000 Subject: [PATCH 1/5] Added Instructor Embeddings --- .../com/johnsnowlabs/ml/ai/Instructor.scala | 205 +++++++++ .../sign/ModelSignatureConstants.scala | 5 + .../nlp/embeddings/InstructorEmbeddings.scala | 433 ++++++++++++++++++ 3 files changed, 643 insertions(+) create mode 100644 src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala create mode 100644 src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala new file mode 100644 index 00000000000000..6902fa50b86480 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala @@ -0,0 +1,205 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper +import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} +import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} + +import scala.collection.JavaConverters._ + +/** InstructorEmbeddings provides the functionality to generate embeddings for instruction and + * task + * @param tensorflow + * tensorflow wrapper + * @param configProtoBytes + * configProtoBytes + * @param spp + * SentencePieceWrapper + * @param signatures + * signatures + */ + +private[johnsnowlabs] class Instructor( + val tensorflow: TensorflowWrapper, + val spp: SentencePieceWrapper, + configProtoBytes: Option[Array[Byte]] = None, + signatures: Option[Map[String, String]] = None) + extends Serializable { + + private val _tfInstructorSignatures: Map[String, String] = + signatures.getOrElse(ModelSignatureManager.apply()) + private val paddingTokenId = 0 + private val eosTokenId = 1 + + /** + * Get sentence embeddings for a batch of sentences + * @param batch batch of sentences + * @param contextLengths context lengths + * @return sentence embeddings + */ + private def getSentenceEmbedding( + batch: Seq[Array[Int]], + contextLengths: Seq[Int]): Array[Array[Float]] = { + // get max sentence length + val sequencesLength = batch.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max + val batchLength = batch.length + + // encode batch + val tensorEncoder = new TensorResources() + val inputDim = batch.length * maxSentenceLength + + // create buffers + val encoderInputBuffers = tensorEncoder.createIntBuffer(inputDim) + val encoderAttentionMaskBuffers = tensorEncoder.createIntBuffer(inputDim) + val encoderContextMaskBuffers = tensorEncoder.createIntBuffer(inputDim) + + val shape = Array(batch.length.toLong, maxSentenceLength) + + batch.zipWithIndex.foreach { case (tokenIds, idx) => + val offset = idx * maxSentenceLength + val diff = maxSentenceLength - tokenIds.length + + // pad with 0 + val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(this.paddingTokenId) + encoderInputBuffers.offset(offset).write(s) + + // create attention mask + val mask = s.map(x => if (x != this.paddingTokenId) 1 else 0) + encoderAttentionMaskBuffers.offset(offset).write(mask) + + // create context mask + val contextMask = mask.zipWithIndex.map { + case (x, i) => { if (i < contextLengths(idx)) 0 else x } + } + encoderContextMaskBuffers.offset(offset).write(contextMask) + } + + // create tensors + val encoderInputTensors = tensorEncoder.createIntBufferTensor(shape, encoderInputBuffers) + val encoderAttentionMaskTensors = + tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers) + val encoderContextMaskTensors = + tensorEncoder.createIntBufferTensor(shape, encoderContextMaskBuffers) + + // run model + val runner = tensorflow + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + initAllTables = false, + savedSignatures = signatures) + .runner + + runner + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderInputIds.key, + "missing_encoder_input_ids"), + encoderInputTensors) + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderAttentionMask.key, + "missing_encoder_attention_mask"), + encoderAttentionMaskTensors) + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderContextMask.key, + "missing_encoder_context_mask"), + encoderContextMaskTensors) + .fetch(_tfInstructorSignatures + .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_last_hidden_state")) + + // get embeddings + val sentenceEmbeddings = runner.run().asScala + val sentenceEmbeddingsFloats = TensorResources.extractFloats(sentenceEmbeddings.head) + val dim = sentenceEmbeddingsFloats.length / batchLength + + // group embeddings + val sentenceEmbeddingsFloatsArray = sentenceEmbeddingsFloats.grouped(dim).toArray + + // close buffers + sentenceEmbeddings.foreach(_.close()) + encoderInputTensors.close() + encoderAttentionMaskTensors.close() + encoderContextMaskTensors.close() + tensorEncoder.clearTensors() + tensorEncoder.clearSession(sentenceEmbeddings) + + sentenceEmbeddingsFloatsArray + } + + /** + * Tokenize sentences + * @param sentences sentences + * @param task task + * @param maxSentenceLength max sentence length + * @return + */ + def tokenize( + sentences: Seq[Annotation], + task: String, + maxSentenceLength: Int): Seq[Array[Int]] = { + sentences.map(s => { + val sentWithTask = if (task.nonEmpty) task.concat("").concat(s.result) else s.result + spp.getSppModel.encodeAsIds(sentWithTask).take(maxSentenceLength - 1) ++ Array( + this.eosTokenId) + }) + } + + + /** + * Predict sentence embeddings + * @param sentences sentences + * @param batchSize batch size + * @param maxSentenceLength max sentence length + * @param instruction instruction + * @return + */ + def predict( + sentences: Seq[Annotation], + batchSize: Int, + maxSentenceLength: Int, + instruction: String): Seq[Annotation] = { + + val instructionTokenized = spp.getSppModel.encodeAsIds(instruction) + // repeat instruction length for each sentence + val instructionTokenizedRepeated: Array[Int] = + Array.fill(sentences.length)(instructionTokenized.length) + + val batchEmbeddings = sentences.grouped(batchSize).toArray.flatMap { batch => + // encode batch + val batchSP = tokenize(batch, instruction, maxSentenceLength) + // get sentence embeddings + val sentenceEmbeddings = getSentenceEmbedding(batchSP, instructionTokenizedRepeated) + + // create annotations + batch.zip(sentenceEmbeddings).map { case (sentence, vectors) => + Annotation( + annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS, + begin = sentence.begin, + end = sentence.end, + result = sentence.result, + metadata = sentence.metadata, + embeddings = vectors) + } + } + batchEmbeddings + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/tensorflow/sign/ModelSignatureConstants.scala b/src/main/scala/com/johnsnowlabs/ml/tensorflow/sign/ModelSignatureConstants.scala index 383c6b8751582a..94516652e6f30b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/tensorflow/sign/ModelSignatureConstants.scala +++ b/src/main/scala/com/johnsnowlabs/ml/tensorflow/sign/ModelSignatureConstants.scala @@ -273,6 +273,11 @@ object ModelSignatureConstants { override val value: String = "StatefulPartitionedCall_1:2" } + case object EncoderContextMask extends TFInfoNameMapper { + override val key: String = "encoder_context_mask" + override val value: String = "encoder_encoder_context_mask:0" + } + /** Retrieve signature patterns for a given provider * * @param modelProvider diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala new file mode 100644 index 00000000000000..b690ae338f9cd1 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala @@ -0,0 +1,433 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.ml.ai.Instructor +import com.johnsnowlabs.ml.tensorflow.* +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadSentencePieceAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.nlp.* +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.storage.HasStorageRef +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param.* +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.slf4j.{Logger, LoggerFactory} + +/** Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val embeddings = InstructorEmbeddings.pretrained() + * .setInputCols("document") + * .setOutputCol("instructor_embeddings") + * }}} + * The default model is `"instructor_xl"`, if no name is provided. + * + * For available pretrained models please see the + * [[https://sparknlp.org/models?task=Instructor Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala InstructorEmbeddingsTestSpec]]. + * + * '''Sources''' : + * + * [[https://arxiv.org/abs/2212.09741 One Embedder, Any Task: Instruction-Finetuned Text Embeddings]] + * + * [[https://github.com/HKUNLP/instructor-embedding/ INSTRUCTOR Github Repository]] + * + * ''' Paper abstract ''' + * + * ''We introduce INSTRUCTOR, a new method for computing text embeddings given task instructions: + * every text input is embedded together with instructions explaining the use case (e.g., task + * and domain descriptions). Unlike encoders from prior work that are more specialized, + * INSTRUCTOR is a single embedder that can generate text embeddings tailored to different + * downstream tasks and domains, without any further training. We first annotate instructions for + * 330 diverse tasks and train INSTRUCTOR on this multitask mixture with a contrastive loss. We + * evaluate INSTRUCTOR on 70 embedding evaluation tasks (66 of which are unseen during training), + * ranging from classification and information retrieval to semantic textual similarity and text + * generation evaluation. INSTRUCTOR, while having an order of magnitude fewer parameters than + * the previous best model, achieves state-of-the-art performance, with an average improvement of + * 3.4% compared to the previous best results on the 70 diverse datasets. Our analysis suggests + * that INSTRUCTOR is robust to changes in instructions, and that instruction finetuning + * mitigates the challenge of training a single model on diverse datasets. Our model, code, and + * data are available at this https URL. [[https://instructor-embedding.github.io/]] '' + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.Tokenizer + * import com.johnsnowlabs.nlp.embeddings.InstructorEmbeddings + * import com.johnsnowlabs.nlp.EmbeddingsFinisher + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("document") + * + * val embeddings = InstructorEmbeddings.pretrained("instructor_xl", "en") + * .setInputCols("document") + * .setInstruction("Represent the Medicine sentence for clustering: ") + * .setOutputCol("instructor_embeddings") + * + * val embeddingsFinisher = new EmbeddingsFinisher() + * .setInputCols("instructor_embeddings") + * .setOutputCols("finished_embeddings") + * .setOutputAsVector(true) + * + * val pipeline = new Pipeline().setStages(Array( + * documentAssembler, + * embeddings, + * embeddingsFinisher + * )) + * + * val data = Seq("Dynamical Scalar Degree of Freedom in Horava-Lifshitz Gravity").toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * result.selectExpr("explode(finished_embeddings) as result").show(1, 80) + * +--------------------------------------------------------------------------------+ + * | result| + * +--------------------------------------------------------------------------------+ + * |[-2.3497989177703857,0.480538547039032,-0.3238905668258667,-1.612930893898010...| + * +--------------------------------------------------------------------------------+ + * }}} + * + * @see + * [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer + * based embeddings + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class InstructorEmbeddings(override val uid: String) + extends AnnotatorModel[InstructorEmbeddings] + with HasBatchedAnnotate[InstructorEmbeddings] + with WriteTensorflowModel + with HasEmbeddingsProperties + with HasStorageRef + with WriteSentencePieceModel + with HasCaseSensitiveProperties + with HasEngine { + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + + /** ConfigProto from tensorflow, serialized into byte array. Get with + * `config_proto.SerializeToString()` + * + * @group param + */ + val configProtoBytes = new IntArrayParam( + this, + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()") + + /** Max sentence length to process (Default: `128`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + /** Set transformer instruction, e.g. 'summarize' format: `"instruction:"`. + * + * @group param + */ + val instruction = + new Param[String](this, "instruction", "Set transformer instruction, e.g. 'summarize'") + + /** It contains TF model signatures for the laded saved model + * + * @group param + */ + val signatures = + new MapFeature[String, String](model = this, name = "signatures").setProtected() + private var _model: Option[Broadcast[Instructor]] = None + + def this() = this(Identifiable.randomUID("INSTRUCTOR_EMBEDDINGS")) + + /** @group setParam */ + def setConfigProtoBytes(bytes: Array[Int]): InstructorEmbeddings.this.type = + set(this.configProtoBytes, bytes) + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "Instructor models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + + def setInstruction(value: String): InstructorEmbeddings.this.type = { + set(instruction, value) + this + } + + /** @group setParam */ + def setSignatures(value: Map[String, String]): this.type = { + if (get(signatures).isEmpty) + set(signatures, value) + this + } + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: TensorflowWrapper, + spp: SentencePieceWrapper): InstructorEmbeddings = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new Instructor( + tensorflowWrapper, + spp = spp, + configProtoBytes = getConfigProtoBytes, + signatures = getSignatures))) + } + + this + } + + /** Set Embeddings dimensions for the INSTRUCTOR model Only possible to set this when the first time + * is saved dimension is not changeable, it comes from INSTRUCTOR config file + * + * @group setParam + */ + override def setDimension(value: Int): this.type = { + if (get(dimension).isEmpty) + set(this.dimension, value) + this + } + + /** Whether to lowercase tokens or not + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = { + if (get(caseSensitive).isEmpty) + set(this.caseSensitive, value) + this + } + + setDefault( + dimension -> 768, + batchSize -> 8, + maxSentenceLength -> 128, + caseSensitive -> false, + instruction -> "") + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + + val allAnnotations = batchedAnnotations + .filter(_.nonEmpty) + .zipWithIndex + .flatMap { case (annotations, i) => + annotations.filter(_.result.nonEmpty).map(x => (x, i)) + } + val processedAnnotations = if (allAnnotations.nonEmpty) { + this.getModelIfNotSet.predict( + sentences = allAnnotations.map(_._1), + batchSize = $(batchSize), + maxSentenceLength = $(maxSentenceLength), + instruction = $(instruction)) + } else { + Seq() + } + + // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence + batchedAnnotations.indices.map(rowIndex => { + val rowAnnotations = processedAnnotations + // zip each annotation with its corresponding row index + .zip(allAnnotations) + // select the sentences belonging to the current row + .filter(_._2._2 == rowIndex) + // leave the annotation only + .map(_._1) + + if (rowAnnotations.nonEmpty) + rowAnnotations + else + Seq.empty[Annotation] + }) + + } + + /** @group getParam */ + def getModelIfNotSet: Instructor = _model.get.value + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflow, + "_instructor", + InstructorEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes, + savedSignatures = getSignatures) + writeSentencePieceModel( + path, + spark, + getModelIfNotSet.spp, + "_instructor", + InstructorEmbeddings.sppFile) + + } + + /** @group getParam */ + def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte)) + + /** @group getParam */ + def getSignatures: Option[Map[String, String]] = get(this.signatures) + + override protected def afterAnnotate(dataset: DataFrame): DataFrame = { + dataset.withColumn( + getOutputCol, + wrapSentenceEmbeddingsMetadata( + dataset.col(getOutputCol), + $(dimension), + Some($(storageRef)))) + } + +} + +trait ReadablePretrainedInstructorModel + extends ParamsAndFeaturesReadable[InstructorEmbeddings] + with HasPretrained[InstructorEmbeddings] { + override val defaultModelName: Some[String] = Some("instructor_xl") + + /** Java compliant-overrides */ + override def pretrained(): InstructorEmbeddings = super.pretrained() + + override def pretrained(name: String): InstructorEmbeddings = super.pretrained(name) + + override def pretrained(name: String, lang: String): InstructorEmbeddings = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): InstructorEmbeddings = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadInstructorDLModel extends ReadTensorflowModel with ReadSentencePieceModel { + this: ParamsAndFeaturesReadable[InstructorEmbeddings] => + + override val tfFile: String = "instructor_tensorflow" + override val sppFile: String = "instructor_spp" + def readModel(instance: InstructorEmbeddings, path: String, spark: SparkSession): Unit = { + + val tf = readTensorflowModel( + path, + spark, + "_instructor_tf", + savedSignatures = instance.getSignatures, + initAllTables = false) + val spp = readSentencePieceModel(path, spark, "_instructor_spp", sppFile) + instance.setModelIfNotSet(spark, tf, spp) + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): InstructorEmbeddings = { + + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + + /*Universal parameters for all engines*/ + val annotatorModel = new InstructorEmbeddings() + + annotatorModel.set(annotatorModel.engine, detectedEngine) + val spModel = loadSentencePieceAsset(localModelPath, "spiece.model") + detectedEngine match { + case ModelEngine.tensorflow => + val (wrapper, signatures) = TensorflowWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + tags = Array("serve"), + initAllTables = false) + + val _signatures = signatures match { + case Some(s) => s + case None => throw new Exception("Cannot load signature definitions from model!") + } + + /** the order of setSignatures is important if we use getSignatures inside + * setModelIfNotSet + */ + annotatorModel + .setSignatures(_signatures) + .setModelIfNotSet(spark, wrapper, spModel) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } +} + +/** This is the companion object of [[InstructorEmbeddings]]. Please refer to that class for the + * documentation. + */ +object InstructorEmbeddings + extends ReadablePretrainedInstructorModel + with ReadInstructorDLModel + with ReadSentencePieceModel { + private[InstructorEmbeddings] val logger: Logger = + LoggerFactory.getLogger("InstructorEmbeddings") +} From c1193fa46126bf70a8dd4865626f13714876c3bd Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 8 Jun 2023 09:20:04 +0000 Subject: [PATCH 2/5] Added Instructor Embeddings python code --- .../sparknlp/annotator/embeddings/__init__.py | 1 + .../embeddings/instructor_embeddings.py | 204 ++++++++++++++++++ python/sparknlp/internal/__init__.py | 5 + .../embeddings/instructor_embeddings_test.py | 85 ++++++++ .../com/johnsnowlabs/ml/ai/Instructor.scala | 50 +++-- .../nlp/embeddings/InstructorEmbeddings.scala | 33 +-- .../InstructorEmbeddingsTestSpec.scala | 65 ++++++ 7 files changed, 409 insertions(+), 34 deletions(-) create mode 100755 python/sparknlp/annotator/embeddings/instructor_embeddings.py create mode 100644 python/test/annotator/embeddings/instructor_embeddings_test.py create mode 100644 src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala diff --git a/python/sparknlp/annotator/embeddings/__init__.py b/python/sparknlp/annotator/embeddings/__init__.py index 13501c98901163..02ee80f98fa264 100644 --- a/python/sparknlp/annotator/embeddings/__init__.py +++ b/python/sparknlp/annotator/embeddings/__init__.py @@ -22,6 +22,7 @@ from sparknlp.annotator.embeddings.distil_bert_embeddings import * from sparknlp.annotator.embeddings.doc2vec import * from sparknlp.annotator.embeddings.elmo_embeddings import * +from sparknlp.annotator.embeddings.instructor_embeddings import * from sparknlp.annotator.embeddings.longformer_embeddings import * from sparknlp.annotator.embeddings.roberta_embeddings import * from sparknlp.annotator.embeddings.roberta_sentence_embeddings import * diff --git a/python/sparknlp/annotator/embeddings/instructor_embeddings.py b/python/sparknlp/annotator/embeddings/instructor_embeddings.py new file mode 100755 index 00000000000000..31ca3c7fd52723 --- /dev/null +++ b/python/sparknlp/annotator/embeddings/instructor_embeddings.py @@ -0,0 +1,204 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for BertEmbeddings.""" + +from sparknlp.common import * + + +class InstructorEmbeddings(AnnotatorModel, + HasEmbeddingsProperties, + HasCaseSensitiveProperties, + HasStorageRef, + HasBatchedAnnotate, + HasMaxSentenceLengthLimit): + """Sentence embeddings using INSTRUCTOR. + + Instructor👨‍🏫, an instruction-finetuned text embedding model that can generate text embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation, etc.) and domains (e.g., science, finance, etc.) by simply providing the task instruction, without any finetuning. Instructor👨‍ achieves sota on 70 diverse embedding tasks! + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> embeddings = InstructorEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setInstruction("Represent the Medicine sentence for clustering: ") \\ + ... .setOutputCol("instructor_embeddings") + + + The default model is ``"instructor_base"``, if no name is provided. + + For available pretrained models please see the + `Models Hub `__. + + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``SENTENCE_EMBEDDINGS`` + ====================== ====================== + + Parameters + ---------- + batchSize + Size of every batch , by default 8 + dimension + Number of embedding dimensions, by default 768 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default False + instruction + Set transformer instruction, e.g. 'summarize:' + maxSentenceLength + Max sentence length to process, by default 128 + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + + References + ---------- + `One Embedder, Any Task: Instruction-Finetuned Text Embeddings `__ + + https://github.com/HKUNLP/instructor-embedding/ + + **Paper abstract** + + *We introduce INSTRUCTOR, a new method for computing text embeddings given task instructions: + every text input is embedded together with instructions explaining the use case (e.g., task and + domain descriptions). Unlike encoders from prior work that are more specialized, INSTRUCTOR is a + single embedder that can generate text embeddings tailored to different downstream tasks and domains, + without any further training. We first annotate instructions for 330 diverse tasks and train INSTRUCTOR + on this multitask mixture with a contrastive loss. We evaluate INSTRUCTOR on 70 embedding evaluation tasks + (66 of which are unseen during training), ranging from classification and information retrieval to semantic + textual similarity and text generation evaluation. INSTRUCTOR, while having an order of magnitude fewer + parameters than the previous best model, achieves state-of-the-art performance, with an average improvement + of 3.4% compared to the previous best results on the 70 diverse datasets. Our analysis suggests that + INSTRUCTOR is robust to changes in instructions, and that instruction finetuning mitigates the challenge of + training a single model on diverse datasets. Our model, code, and data are available at this https + URL .* + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> embeddings = InstructorEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setInstruction("Represent the Medicine sentence for clustering: ") \\ + ... .setOutputCol("instructor_embeddings") + >>> embeddingsFinisher = EmbeddingsFinisher() \\ + ... .setInputCols(["instructor_embeddings"]) \\ + ... .setOutputCols("finished_embeddings") \\ + ... .setOutputAsVector(True) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... embeddings, + ... embeddingsFinisher + ... ]) + >>> data = spark.createDataFrame([["Dynamical Scalar Degree of Freedom in Horava-Lifshitz Gravity"]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80) + +--------------------------------------------------------------------------------+ + | result| + +--------------------------------------------------------------------------------+ + |[-2.3497989177703857,0.480538547039032,-0.3238905668258667,-1.612930893898010...| + +--------------------------------------------------------------------------------+ + """ + + name = "InstructorEmbeddings" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + instruction = Param(Params._dummy(), "instruction", "Set transformer instruction, e.g. 'summarize:'", + typeConverter=TypeConverters.toString) + configProtoBytes = Param(Params._dummy(), + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", + TypeConverters.toListInt) + + def setInstruction(self, value): + """ Sets transformer instruction, e.g. 'summarize:'. + + Parameters + ---------- + value : str + """ + return self._set(instruction=value) + + def setConfigProtoBytes(self, b): + """Sets configProto from tensorflow, serialized into byte array. + + Parameters + ---------- + b : List[int] + ConfigProto from tensorflow, serialized into byte array + """ + return self._set(configProtoBytes=b) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.InstructorEmbeddings", java_model=None): + super(InstructorEmbeddings, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + dimension=768, + batchSize=8, + maxSentenceLength=128, + caseSensitive=False, + instruction="", + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + InstructorEmbeddings + The restored model + """ + from sparknlp.internal import _InstructorLoader + jModel = _InstructorLoader(folder, spark_session._jsparkSession)._java_obj + return InstructorEmbeddings(java_model=jModel) + + @staticmethod + def pretrained(name="instructor_base", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "instructor_base" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + InstructorEmbeddings + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(InstructorEmbeddings, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index a525a53ca2047a..3c8fae34737540 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -529,3 +529,8 @@ def __init__(self, path, jspark): super(_RoBertaForZeroShotClassification, self).__init__( "com.johnsnowlabs.nlp.annotators.classifier.dl.RoBertaForZeroShotClassification.loadSavedModel", path, jspark) + + +class _InstructorLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_InstructorLoader, self).__init__("com.johnsnowlabs.nlp.embeddings.InstructorEmbeddings.loadSavedModel", path, jspark) \ No newline at end of file diff --git a/python/test/annotator/embeddings/instructor_embeddings_test.py b/python/test/annotator/embeddings/instructor_embeddings_test.py new file mode 100644 index 00000000000000..4dcfd32f060d45 --- /dev/null +++ b/python/test/annotator/embeddings/instructor_embeddings_test.py @@ -0,0 +1,85 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.annotator.common.has_max_sentence_length_test import HasMaxSentenceLengthTests +from test.util import SparkContextForTest + + +@pytest.mark.slow +class InstructorEmbeddingsTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.tested_annotator = InstructorEmbeddings.pretrained() \ + .setInstruction("Represent the Wikipedia document for retrieval: ") \ + .setInputCols(["documents"]) \ + .setOutputCol("instructor") + + def runTest(self): + data = self.spark.createDataFrame([ + [1, """Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?] that + the term "mixed economies" more precisely describes most contemporary economies, due to their containing both + private-owned and state-owned enterprises. In capitalism, prices determine the demand-supply scale. For + example, higher demand for certain goods and services lead to higher prices and lower demand for certain + goods lead to lower prices. """]]).toDF("id", "text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + instruction = self.tested_annotator + + pipeline = Pipeline().setStages([document_assembler, instruction]) + results = pipeline.fit(data).transform(data) + + results.select("instructor.embeddings").show(truncate=False) + +# +# @pytest.mark.slow +# class BertEmbeddingsLoadSavedModelTestSpec(unittest.TestCase): +# +# def setUp(self): +# self.data = SparkContextForTest.spark.read.option("header", "true") \ +# .csv(path="file:///" + os.getcwd() + "/../src/test/resources/embeddings/sentence_embeddings.csv") +# +# def runTest(self): +# document_assembler = DocumentAssembler() \ +# .setInputCol("text") \ +# .setOutputCol("document") +# sentence_detector = SentenceDetector() \ +# .setInputCols(["document"]) \ +# .setOutputCol("sentence") +# tokenizer = Tokenizer() \ +# .setInputCols(["sentence"]) \ +# .setOutputCol("token") +# albert = BertEmbeddings.loadSavedModel(os.getcwd() + "/../src/test/resources/tf-hub-bert/model", +# SparkContextForTest.spark) \ +# .setInputCols(["sentence", "token"]) \ +# .setOutputCol("embeddings") +# +# pipeline = Pipeline(stages=[ +# document_assembler, +# sentence_detector, +# tokenizer, +# albert +# ]) +# +# model = pipeline.fit(self.data) +# model.write().overwrite().save("./tmp_bert_pipeline_model") +# model.transform(self.data).show() diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala index 6902fa50b86480..1507da55a59de7 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Instructor.scala @@ -47,12 +47,14 @@ private[johnsnowlabs] class Instructor( private val paddingTokenId = 0 private val eosTokenId = 1 - /** - * Get sentence embeddings for a batch of sentences - * @param batch batch of sentences - * @param contextLengths context lengths - * @return sentence embeddings - */ + /** Get sentence embeddings for a batch of sentences + * @param batch + * batch of sentences + * @param contextLengths + * context lengths + * @return + * sentence embeddings + */ private def getSentenceEmbedding( batch: Seq[Array[Int]], contextLengths: Seq[Int]): Array[Array[Float]] = { @@ -144,13 +146,15 @@ private[johnsnowlabs] class Instructor( sentenceEmbeddingsFloatsArray } - /** - * Tokenize sentences - * @param sentences sentences - * @param task task - * @param maxSentenceLength max sentence length - * @return - */ + /** Tokenize sentences + * @param sentences + * sentences + * @param task + * task + * @param maxSentenceLength + * max sentence length + * @return + */ def tokenize( sentences: Seq[Annotation], task: String, @@ -162,15 +166,17 @@ private[johnsnowlabs] class Instructor( }) } - - /** - * Predict sentence embeddings - * @param sentences sentences - * @param batchSize batch size - * @param maxSentenceLength max sentence length - * @param instruction instruction - * @return - */ + /** Predict sentence embeddings + * @param sentences + * sentences + * @param batchSize + * batch size + * @param maxSentenceLength + * max sentence length + * @param instruction + * instruction + * @return + */ def predict( sentences: Seq[Annotation], batchSize: Int, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala index b690ae338f9cd1..e7f7f5b45cc02e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Instructor -import com.johnsnowlabs.ml.tensorflow.* +import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, SentencePieceWrapper, @@ -29,25 +29,33 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ notSupportedEngineError } import com.johnsnowlabs.ml.util.ModelEngine -import com.johnsnowlabs.nlp.* +import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import com.johnsnowlabs.storage.HasStorageRef import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.param.* +import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.storage.StorageLevel import org.slf4j.{Logger, LoggerFactory} -/** Pretrained models can be loaded with `pretrained` of the companion object: +/** Sentence embeddings using INSTRUCTOR. + * + * Instructor👨‍🏫, an instruction-finetuned text embedding model that can generate text + * embeddings tailored to any task (e.g., classification, retrieval, clustering, text evaluation, + * etc.) and domains (e.g., science, finance, etc.) by simply providing the task instruction, + * without any finetuning. Instructor👨‍ achieves sota on 70 diverse embedding tasks! + * + * Pretrained models can be loaded with `pretrained` of the companion object: * {{{ * val embeddings = InstructorEmbeddings.pretrained() * .setInputCols("document") * .setOutputCol("instructor_embeddings") * }}} - * The default model is `"instructor_xl"`, if no name is provided. + * The default model is `"instructor_base"`, if no name is provided. * * For available pretrained models please see the - * [[https://sparknlp.org/models?task=Instructor Models Hub]]. + * [[https://sparknlp.org/models?q=Instructor Models Hub]]. * * For extended examples of usage, see * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala InstructorEmbeddingsTestSpec]]. @@ -88,9 +96,9 @@ import org.slf4j.{Logger, LoggerFactory} * .setInputCol("text") * .setOutputCol("document") * - * val embeddings = InstructorEmbeddings.pretrained("instructor_xl", "en") + * val embeddings = InstructorEmbeddings.pretrained("instructor_base", "en") * .setInputCols("document") - * .setInstruction("Represent the Medicine sentence for clustering: ") + * .setInstruction("Represent the Medicine sentence for clustering: ") * .setOutputCol("instructor_embeddings") * * val embeddingsFinisher = new EmbeddingsFinisher() @@ -229,14 +237,15 @@ class InstructorEmbeddings(override val uid: String) tensorflowWrapper, spp = spp, configProtoBytes = getConfigProtoBytes, - signatures = getSignatures))) + signatures = getSignatures), + StorageLevel.MEMORY_AND_DISK)) } this } - /** Set Embeddings dimensions for the INSTRUCTOR model Only possible to set this when the first time - * is saved dimension is not changeable, it comes from INSTRUCTOR config file + /** Set Embeddings dimensions for the BERT model Only possible to set this when the first time + * is saved dimension is not changeable, it comes from BERT config file * * @group setParam */ @@ -350,7 +359,7 @@ class InstructorEmbeddings(override val uid: String) trait ReadablePretrainedInstructorModel extends ParamsAndFeaturesReadable[InstructorEmbeddings] with HasPretrained[InstructorEmbeddings] { - override val defaultModelName: Some[String] = Some("instructor_xl") + override val defaultModelName: Some[String] = Some("instructor_base") /** Java compliant-overrides */ override def pretrained(): InstructorEmbeddings = super.pretrained() diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala new file mode 100644 index 00000000000000..33fa7af5da5437 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala @@ -0,0 +1,65 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.FastTest +import org.apache.spark.ml.Pipeline +import org.scalatest.flatspec.AnyFlatSpec + +class InstructorEmbeddingsTestSpec extends AnyFlatSpec { + + "Instructor Embeddings" should "correctly embed multiple sentences" taggedAs FastTest in { + + import ResourceHelper.spark.implicits._ + + val ddd = Seq( + "Capitalism has been dominant in the Western world since the end of feudalism, but most feel[who?]" + + " that the term \"mixed economies\" more precisely describes most contemporary economies, due to their " + + "containing both private-owned and state-owned enterprises. In capitalism, prices determine the " + + "demand-supply scale. For example, higher demand for certain goods and services lead to higher prices " + + "and lower demand for certain goods lead to lower prices.", + "The disparate impact theory is especially controversial under the Fair Housing Act because the Act " + + "regulates many activities relating to housing, insurance, and mortgage loans—and some scholars" + + " have argued that the theory's use under the Fair Housing Act, combined with extensions of the " + + "Community Reinvestment Act, contributed to rise of sub-prime lending and the crash of the U.S. " + + "housing market and ensuing global economic recession", + "Disparate impact in United States labor law refers to practices in employment, housing, and other" + + " areas that adversely affect one group of people of a protected characteristic more than another, " + + "even though rules applied by employers or landlords are formally neutral. Although the protected classes " + + "vary by statute, most federal civil rights laws protect based on race, color, religion, national origin, " + + "and sex as protected traits, and some laws include disability status and other traits as well.") + .toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = InstructorEmbeddings + .pretrained() + .setInstruction("Represent the Wikipedia document for retrieval: ") + .setInputCols(Array("document")) + .setOutputCol("instructor") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(ddd).transform(ddd) + pipelineDF.select("instructor.embeddings").show(truncate = false) + + } +} From 9e064a9e5f7f746ad28da49af4fd3a583f5efecb Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 8 Jun 2023 11:00:27 +0000 Subject: [PATCH 3/5] fixed broadcast bug --- .../nlp/embeddings/InstructorEmbeddings.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala index e7f7f5b45cc02e..6e2f94fe7f6a64 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Instructor -import com.johnsnowlabs.ml.tensorflow._ +import com.johnsnowlabs.ml.tensorflow.* import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, SentencePieceWrapper, @@ -29,14 +29,13 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ notSupportedEngineError } import com.johnsnowlabs.ml.util.ModelEngine -import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.* import com.johnsnowlabs.nlp.serialization.MapFeature import com.johnsnowlabs.storage.HasStorageRef import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.* import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.{DataFrame, SparkSession} -import org.apache.spark.storage.StorageLevel import org.slf4j.{Logger, LoggerFactory} /** Sentence embeddings using INSTRUCTOR. @@ -237,8 +236,7 @@ class InstructorEmbeddings(override val uid: String) tensorflowWrapper, spp = spp, configProtoBytes = getConfigProtoBytes, - signatures = getSignatures), - StorageLevel.MEMORY_AND_DISK)) + signatures = getSignatures))) } this From ac7f7f9b5fb63bd7e67faef49572c4d745118e18 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 8 Jun 2023 11:12:47 +0000 Subject: [PATCH 4/5] fixed broadcast bug --- .../johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala index 6e2f94fe7f6a64..8caee7749a00ad 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddings.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Instructor -import com.johnsnowlabs.ml.tensorflow.* +import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, SentencePieceWrapper, @@ -29,11 +29,11 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ notSupportedEngineError } import com.johnsnowlabs.ml.util.ModelEngine -import com.johnsnowlabs.nlp.* +import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import com.johnsnowlabs.storage.HasStorageRef import org.apache.spark.broadcast.Broadcast -import org.apache.spark.ml.param.* +import org.apache.spark.ml.param._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.{DataFrame, SparkSession} import org.slf4j.{Logger, LoggerFactory} From 2d6a689543ef4da87efb29f5ed75c1a7447a8917 Mon Sep 17 00:00:00 2001 From: Prabod Rathnayaka Date: Thu, 8 Jun 2023 11:52:46 +0000 Subject: [PATCH 5/5] Changed test type to slow --- .../nlp/embeddings/InstructorEmbeddingsTestSpec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala index 33fa7af5da5437..717dc494e0c120 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/InstructorEmbeddingsTestSpec.scala @@ -18,13 +18,13 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.nlp.base.DocumentAssembler import com.johnsnowlabs.nlp.util.io.ResourceHelper -import com.johnsnowlabs.tags.FastTest +import com.johnsnowlabs.tags.SlowTest import org.apache.spark.ml.Pipeline import org.scalatest.flatspec.AnyFlatSpec class InstructorEmbeddingsTestSpec extends AnyFlatSpec { - "Instructor Embeddings" should "correctly embed multiple sentences" taggedAs FastTest in { + "Instructor Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in { import ResourceHelper.spark.implicits._