Skip to content

Commit a60be5f

Browse files
authored
Rename NameEntity to NamedEntity (#6917)
1 parent b8f71b9 commit a60be5f

File tree

5 files changed

+59
-16
lines changed

5 files changed

+59
-16
lines changed

src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.ComponentModel;
78
using System.Text;
89

910
namespace Microsoft.ML.TorchSharp.NasBert
@@ -17,7 +18,10 @@ public enum BertTaskType
1718
MaskedLM = 1,
1819
TextClassification = 2,
1920
SentenceRegression = 3,
20-
NameEntityRecognition = 4,
21+
NamedEntityRecognition = 4,
22+
[Obsolete("Please use NamedEntityRecognition instead", false)]
23+
[EditorBrowsable(EditorBrowsableState.Never)]
24+
NameEntityRecognition = NamedEntityRecognition,
2125
QuestionAnswering = 5
2226
}
2327
}

src/Microsoft.ML.TorchSharp/NasBert/NasBertTrainer.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ private protected override Module CreateModule(IChannel ch, IDataView input)
204204
EnglishRoberta tokenizerModel = Tokenizer.RobertaModel();
205205

206206
NasBertModel model;
207-
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
207+
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
208208
model = new NerModel(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
209209
else
210210
model = new ModelForPrediction(Parent.BertOptions, tokenizerModel.PadIndex, tokenizerModel.SymbolsCount, Parent.Option.NumberOfClasses);
@@ -268,7 +268,7 @@ private protected override torch.Tensor PrepareRowTensor()
268268
private protected override void RunModelAndBackPropagate(ref List<Tensor> inputTensors, ref Tensor targetsTensor)
269269
{
270270
Tensor logits = default;
271-
if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
271+
if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
272272
{
273273
int[,] lengthArray = new int[inputTensors.Count, 1];
274274
for (int i = 0; i < inputTensors.Count; i++)
@@ -293,7 +293,7 @@ private protected override void RunModelAndBackPropagate(ref List<Tensor> inputT
293293
torch.Tensor loss;
294294
if (Parent.BertOptions.TaskType == BertTaskType.TextClassification)
295295
loss = torch.nn.CrossEntropyLoss(reduction: Parent.BertOptions.Reduction).forward(logits, targetsTensor);
296-
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
296+
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
297297
{
298298
targetsTensor = targetsTensor.@long().view(-1);
299299
logits = logits.view(-1, logits.size(-1));
@@ -338,7 +338,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
338338
outColumns[Option.ScoreColumnName] = new SchemaShape.Column(Option.ScoreColumnName, SchemaShape.Column.VectorKind.Vector,
339339
NumberDataViewType.Single, false, new SchemaShape(AnnotationUtils.AnnotationsForMulticlassScoreColumn(labelCol)));
340340
}
341-
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
341+
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
342342
{
343343
var metadata = new List<SchemaShape.Column>();
344344
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector,
@@ -387,7 +387,7 @@ private protected override void CheckInputSchema(SchemaShape inputSchema)
387387
TextDataViewType.Instance.ToString(), sentenceCol2.GetTypeString());
388388
}
389389
}
390-
else if (BertOptions.TaskType == BertTaskType.NameEntityRecognition)
390+
else if (BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
391391
{
392392
if (labelCol.ItemType != NumberDataViewType.UInt32)
393393
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "label", Option.LabelColumnName,
@@ -535,7 +535,7 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
535535
info[1] = new DataViewSchema.DetachedColumn(Parent.Options.ScoreColumnName, new VectorDataViewType(NumberDataViewType.Single, Parent.Options.NumberOfClasses), meta.ToAnnotations());
536536
return info;
537537
}
538-
else if (Parent.BertOptions.TaskType == BertTaskType.NameEntityRecognition)
538+
else if (Parent.BertOptions.TaskType == BertTaskType.NamedEntityRecognition)
539539
{
540540
var info = new DataViewSchema.DetachedColumn[1];
541541
var keyType = Parent.LabelColumn.Annotations.Schema.GetColumnOrNull(AnnotationUtils.Kinds.KeyValues)?.Type as VectorDataViewType;

src/Microsoft.ML.TorchSharp/NasBert/NerTrainer.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
3535
/// </summary>
3636
/// <remarks>
3737
/// <format type="text/markdown"><![CDATA[
38-
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NameEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
38+
/// To create this trainer, use [NER](xref:Microsoft.ML.TorchSharpCatalog.NamedEntityRecognition(Microsoft.ML.MulticlassClassificationCatalog.MulticlassClassificationTrainers,System.String,System.String,System.String,Int32,Int32,Int32,Microsoft.ML.TorchSharp.NasBert.BertArchitecture,Microsoft.ML.IDataView)).
3939
///
4040
/// ### Input and Output Columns
4141
/// The input label column data must be a Vector of [string](xref:Microsoft.ML.Data.TextDataViewType) type and the sentence columns must be of type<xref:Microsoft.ML.Data.TextDataViewType>.
@@ -54,7 +54,7 @@ namespace Microsoft.ML.TorchSharp.NasBert
5454
/// | Exportable to ONNX | No |
5555
///
5656
/// ### Training Algorithm Details
57-
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of name entity recognition.
57+
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained NAS-BERT roBERTa model for the purpose of named entity recognition.
5858
/// ]]>
5959
/// </format>
6060
/// </remarks>
@@ -93,7 +93,7 @@ internal NerTrainer(IHostEnvironment env,
9393
BatchSize = batchSize,
9494
MaxEpoch = maxEpochs,
9595
ValidationSet = validationSet,
96-
TaskType = BertTaskType.NameEntityRecognition
96+
TaskType = BertTaskType.NamedEntityRecognition
9797
})
9898
{
9999
}
@@ -295,7 +295,7 @@ private static NerTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
295295

296296
options.Sentence1ColumnName = ctx.LoadString();
297297
options.Sentence2ColumnName = ctx.LoadStringOrNull();
298-
options.TaskType = BertTaskType.NameEntityRecognition;
298+
options.TaskType = BertTaskType.NamedEntityRecognition;
299299

300300
BinarySaver saver = new BinarySaver(env, new BinarySaver.Arguments());
301301
DataViewType type;

src/Microsoft.ML.TorchSharp/TorchSharpCatalog.cs

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.ComponentModel;
78
using System.Text;
89
using Microsoft.ML.Data;
910
using Microsoft.ML.TorchSharp.AutoFormerV2;
@@ -161,7 +162,45 @@ public static ObjectDetectionMetrics EvaluateObjectDetection(
161162
}
162163

163164
/// <summary>
164-
/// Fine tune a NAS-BERT model for Name Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
165+
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, string, string, string, int, int, BertArchitecture, IDataView)"/> method instead
166+
/// </summary>
167+
/// <param name="catalog">The transform's catalog.</param>
168+
/// <param name="labelColumnName">Name of the label column. Column should be a key type.</param>
169+
/// <param name="outputColumnName">Name of the output column. It will be a key type. It is the predicted label.</param>
170+
/// <param name="sentence1ColumnName">Name of the column for the first sentence.</param>
171+
/// <param name="batchSize">Number of rows in the batch.</param>
172+
/// <param name="maxEpochs">Maximum number of times to loop through your training set.</param>
173+
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
174+
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
175+
/// <returns></returns>
176+
[Obsolete("Please use NamedEntityRecognition method instead", false)]
177+
[EditorBrowsable(EditorBrowsableState.Never)]
178+
public static NerTrainer NameEntityRecognition(
179+
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
180+
string labelColumnName = DefaultColumnNames.Label,
181+
string outputColumnName = DefaultColumnNames.PredictedLabel,
182+
string sentence1ColumnName = "Sentence",
183+
int batchSize = 32,
184+
int maxEpochs = 10,
185+
BertArchitecture architecture = BertArchitecture.Roberta,
186+
IDataView validationSet = null)
187+
=> NamedEntityRecognition(catalog, labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, architecture, validationSet);
188+
189+
/// <summary>
190+
/// Obsolete: please use the <see cref="NamedEntityRecognition(MulticlassClassificationCatalog.MulticlassClassificationTrainers, NerTrainer.NerOptions)"/> method instead
191+
/// </summary>
192+
/// <param name="catalog">The transform's catalog.</param>
193+
/// <param name="options">The full set of advanced options.</param>
194+
/// <returns></returns>
195+
[Obsolete("Please use NamedEntityRecognition method instead", false)]
196+
[EditorBrowsable(EditorBrowsableState.Never)]
197+
public static NerTrainer NameEntityRecognition(
198+
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
199+
NerTrainer.NerOptions options)
200+
=> NamedEntityRecognition(catalog, options);
201+
202+
/// <summary>
203+
/// Fine tune a NAS-BERT model for Named Entity Recognition. The limit for any sentence is 512 tokens. Each word typically
165204
/// will map to a single token, and we automatically add 2 specical tokens (a start token and a separator token)
166205
/// so in general this limit will be 510 words for all sentences.
167206
/// </summary>
@@ -174,7 +213,7 @@ public static ObjectDetectionMetrics EvaluateObjectDetection(
174213
/// <param name="architecture">Architecture for the model. Defaults to Roberta.</param>
175214
/// <param name="validationSet">The validation set used while training to improve model quality.</param>
176215
/// <returns></returns>
177-
public static NerTrainer NameEntityRecognition(
216+
public static NerTrainer NamedEntityRecognition(
178217
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
179218
string labelColumnName = DefaultColumnNames.Label,
180219
string outputColumnName = DefaultColumnNames.PredictedLabel,
@@ -186,12 +225,12 @@ public static NerTrainer NameEntityRecognition(
186225
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), labelColumnName, outputColumnName, sentence1ColumnName, batchSize, maxEpochs, validationSet, architecture);
187226

188227
/// <summary>
189-
/// Fine tune a Name Entity Recognition model.
228+
/// Fine tune a Named Entity Recognition model.
190229
/// </summary>
191230
/// <param name="catalog">The transform's catalog.</param>
192231
/// <param name="options">The full set of advanced options.</param>
193232
/// <returns></returns>
194-
public static NerTrainer NameEntityRecognition(
233+
public static NerTrainer NamedEntityRecognition(
195234
this MulticlassClassificationCatalog.MulticlassClassificationTrainers catalog,
196235
NerTrainer.NerOptions options)
197236
=> new NerTrainer(CatalogUtils.GetEnvironment(catalog), options);

test/Microsoft.ML.Tests/NerTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public void TestSimpleNer()
5454
}));
5555
var chain = new EstimatorChain<ITransformer>();
5656
var estimator = chain.Append(ML.Transforms.Conversion.MapValueToKey("Label", keyData: labels))
57-
.Append(ML.MulticlassClassification.Trainers.NameEntityRecognition(outputColumnName: "outputColumn"))
57+
.Append(ML.MulticlassClassification.Trainers.NamedEntityRecognition(outputColumnName: "outputColumn"))
5858
.Append(ML.Transforms.Conversion.MapKeyToValue("outputColumn"));
5959

6060
var estimatorSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema));

0 commit comments

Comments
 (0)