Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

178 changes: 120 additions & 58 deletions src/main/scala/com/johnsnowlabs/ml/ai/CamemBertClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.Identifiable
import org.tensorflow.ndarray.buffer.{IntDataBuffer, LongDataBuffer}
import org.tensorflow.ndarray.buffer.LongDataBuffer

import scala.collection.JavaConverters._

Expand All @@ -39,7 +40,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class CamemBertClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val spp: SentencePieceWrapper,
configProtoBytes: Option[Array[Byte]] = None,
tags: Map[String, Int],
Expand All @@ -50,6 +52,10 @@ private[johnsnowlabs] class CamemBertClassification(

val _tfCamemBertSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

/** HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated
* in the actual # sentencepiece vocabulary (this is the case for '''<s>''' and '''</s>''')
Expand Down Expand Up @@ -112,51 +118,13 @@ private[johnsnowlabs] class CamemBertClassification(
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)

runner
.feed(
_tfCamemBertSignatures
.getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"),
tokenTensors)
.feed(
_tfCamemBertSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.fetch(_tfCamemBertSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
Expand All @@ -169,7 +137,7 @@ private[johnsnowlabs] class CamemBertClassification(
batchScores
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
Expand All @@ -191,7 +159,7 @@ private[johnsnowlabs] class CamemBertClassification(
.write(sentence.map(x => if (x == sentencePadTokenId) 0L else 1L))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand All @@ -217,6 +185,51 @@ private[johnsnowlabs] class CamemBertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}

/** Runs the model with the ONNX runtime and returns the raw (unnormalized)
  * logits for the whole batch as one flat float array.
  *
  * @param batch
  *   encoded sentences: one array of SentencePiece token ids per sentence
  * @return
  *   flattened float logits from the model's "logits" output
  */
private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {

  // [nb of encoded sentences , maxSentenceLength]
  val (runner, env) = onnxWrapper.get.getSession()

  val tokenTensors =
    OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
  // NOTE(review): the attention mask keys off `x == 0L`, while the TF path
  // masks on `sentencePadTokenId` — confirm the exported ONNX model pads with 0.
  val maskTensors =
    OnnxTensor.createTensor(
      env,
      batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

  val inputs =
    Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

  try {
    val results = runner.run(inputs)
    try {
      results
        .get("logits")
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()
    } finally if (results != null) results.close()
  } finally {
    // Always release the native input tensors, even when inference throws.
    // The original closed them only on the success path, leaking native
    // (off-heap) memory on any failed run.
    tokenTensors.close()
    maskTensors.close()
  }
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
Expand All @@ -239,6 +252,24 @@ private[johnsnowlabs] class CamemBertClassification(
activation: String): Array[Array[Float]] = ???

/** Computes span-classification scores (question answering) for a batch.
  *
  * Dispatches to the ONNX or TensorFlow backend, then converts the flat
  * start/end logits into per-sentence softmax probability rows.
  *
  * @param batch
  *   encoded sentences: one array of token ids per sentence
  * @return
  *   a pair (startScores, endScores), one probability row per sentence
  */
def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
  val sentenceCount = batch.length

  // Pick the inference backend chosen when the model was loaded.
  val (rawStart, rawEnd) =
    if (detectedEngine == ONNX.name) computeLogitsWithOnnx(batch)
    else computeLogitsWithTF(batch)

  // Regroup the flat logits into one row per sentence and normalize each
  // row with softmax.
  def toScores(flatLogits: Array[Float]): Array[Array[Float]] = {
    val rowLength = flatLogits.length / sentenceCount
    flatLogits.grouped(rowLength).map(calculateSoftmax).toArray
  }

  (toScores(rawStart), toScores(rawEnd))
}

private def computeLogitsWithTF(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
Expand All @@ -261,7 +292,7 @@ private[johnsnowlabs] class CamemBertClassification(
.write(sentence.map(x => if (x == sentencePadTokenId) 0L else 1L))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -290,15 +321,46 @@ private[johnsnowlabs] class CamemBertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray
(startLogits, endLogits)
}

(startScores, endScores)
/** Runs the span-classification model with the ONNX runtime and returns the
  * raw start/end logits for the whole batch as flat float arrays.
  *
  * @param batch
  *   encoded sentences: one array of SentencePiece token ids per sentence
  * @return
  *   pair of flattened (startLogits, endLogits) from the "start_logits" and
  *   "end_logits" model outputs
  */
private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = {
  // [nb of encoded sentences]
  val (runner, env) = onnxWrapper.get.getSession()

  val tokenTensors =
    OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
  // NOTE(review): the attention mask keys off `x == 0L`, while the TF path
  // masks on `sentencePadTokenId` — confirm the exported ONNX model pads with 0.
  val maskTensors =
    OnnxTensor.createTensor(
      env,
      batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

  val inputs =
    Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

  try {
    val output = runner.run(inputs)
    try {
      val startLogits = output
        .get("start_logits")
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()

      val endLogits = output
        .get("end_logits")
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()

      (startLogits, endLogits)
    } finally if (output != null) output.close()
  } finally {
    // Always release the native input tensors, even when inference throws.
    // The original closed them only on the success path, leaking native
    // (off-heap) memory on any failed run.
    tokenTensors.close()
    maskTensors.close()
  }
}

def findIndexedToken(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ trait ReadAlbertForTokenDLModel

instance.getEngine match {
case TensorFlow.name =>
val tf =
val tfWrapper =
readTensorflowModel(path, spark, "_albert_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tf), None, spp)
instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp)
case ONNX.name =>
val onnxWrapper =
readOnnxModel(
Expand Down
Loading