Skip to content

Commit 65c7ca9

Browse files
Add NameEntityRecognition and Q&A deep learning tasks. (#6760)
* NER * QA almost done, runtime error * QA finished * fixes from PR comments * fixed build * build fixes * perf changes * made disposable * fixed not disposing model * added some disposables to TensorFlow for memory * build testing * fixing build * added missing dispose * build fixes * build fixes * testing macos fix
1 parent 321158d commit 65c7ca9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+90032
-383
lines changed

build/ci/job-template.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
steps:
6969
# Extra MacOS step required to install OS-specific dependencies
7070
- ${{ if and(contains(parameters.pool.vmImage, 'macOS'), not(contains(parameters.name, 'cross'))) }}:
71-
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=TRUE && brew update && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
71+
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=TRUE && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
7272
displayName: Install MacOS build dependencies
7373
# Extra Apple MacOS step required to install OS-specific dependencies
7474
- ${{ if and(contains(parameters.pool.vmImage, 'macOS'), contains(parameters.name, 'cross')) }}:

src/Microsoft.ML.Tokenizers/Model/BPE.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
195195
return null;
196196
}
197197

198+
/// <summary>
199+
/// Map the tokenized Id to the token.
200+
/// </summary>
201+
/// <param name="id">The Id to map to the token.</param>
202+
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
203+
/// <returns>The mapped token of the Id.</returns>
204+
public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException();
205+
198206
/// <summary>
199207
/// Gets the dictionary mapping tokens to Ids.
200208
/// </summary>
@@ -443,6 +451,11 @@ internal List<Token> TokenizeWithCache(string sequence)
443451
return tokens;
444452
}
445453

454+
public override bool IsValidChar(char ch)
455+
{
456+
throw new NotImplementedException();
457+
}
458+
446459
internal static readonly List<Token> EmptyTokensList = new();
447460
}
448461
}

src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,28 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
135135
public override string? IdToToken(int id, bool skipSpecialTokens = false) =>
136136
skipSpecialTokens && id < 0 ? null : _vocabReverse.TryGetValue(id, out var value) ? value : null;
137137

138+
/// <summary>
139+
/// Map the tokenized Id to the original string.
140+
/// </summary>
141+
/// <param name="id">The Id to map to the string.</param>
142+
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
143+
/// <returns>The mapped token of the Id.</returns>
144+
public override string? IdToString(int id, bool skipSpecialTokens = false)
145+
{
146+
if (skipSpecialTokens && id < 0)
147+
return null;
148+
if (_vocabReverse.TryGetValue(id, out var value))
149+
{
150+
var textChars = string.Join("", value)
151+
.Where(c => _unicodeToByte.ContainsKey(c))
152+
.Select(c => _unicodeToByte[c]);
153+
var text = new string(textChars.ToArray());
154+
return text;
155+
}
156+
157+
return null;
158+
}
159+
138160
/// <summary>
139161
/// Save the model data into the vocabulary, merges, and occurrence mapping files.
140162
/// </summary>
@@ -565,6 +587,11 @@ private List<Token> BpeToken(Span<char> token, Span<int> indexMapping)
565587

566588
return pairs;
567589
}
590+
591+
public override bool IsValidChar(char ch)
592+
{
593+
return _byteToUnicode.ContainsKey(ch);
594+
}
568595
}
569596

570597
/// <summary>

src/Microsoft.ML.Tokenizers/Model/Model.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ public abstract class Model
3535
/// <returns>The mapped token of the Id.</returns>
3636
public abstract string? IdToToken(int id, bool skipSpecialTokens = false);
3737

38+
public abstract string? IdToString(int id, bool skipSpecialTokens = false);
39+
3840
/// <summary>
3941
/// Gets the dictionary mapping tokens to Ids.
4042
/// </summary>
@@ -57,6 +59,14 @@ public abstract class Model
5759
/// Gets a trainer object to use in training the model.
5860
/// </summary>
5961
public abstract Trainer? GetTrainer();
62+
63+
/// <summary>
64+
/// Return true if the char is valid in the tokenizer; otherwise return false.
65+
/// </summary>
66+
/// <param name="ch"></param>
67+
/// <returns></returns>
68+
public abstract bool IsValidChar(char ch);
69+
6070
}
6171

6272
}

src/Microsoft.ML.Tokenizers/Tokenizer.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ public TokenizerResult Encode(string sequence)
137137

138138
foreach (int id in ids)
139139
{
140-
tokens.Add(Model.IdToToken(id) ?? "");
140+
if (Model.GetType() == typeof(EnglishRoberta))
141+
tokens.Add(Model.IdToString(id) ?? "");
142+
else
143+
tokens.Add(Model.IdToToken(id) ?? "");
141144
}
142145

143146
return Decoder?.Decode(tokens) ?? string.Join("", tokens);
@@ -187,5 +190,10 @@ public void TrainFromFiles(
187190
// To Do: support added vocabulary in the tokenizer which will include this returned special_tokens.
188191
// self.add_special_tokens(&special_tokens);
189192
}
193+
194+
public bool IsValidChar(char ch)
195+
{
196+
return Model.IsValidChar(ch);
197+
}
190198
}
191199
}

src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,12 @@
1717
using static TorchSharp.torch.optim.lr_scheduler;
1818
using Microsoft.ML.TorchSharp.Utils;
1919
using Microsoft.ML;
20-
using Microsoft.ML.TorchSharp.NasBert;
2120
using System.IO;
2221
using Microsoft.ML.Data.IO;
2322
using Microsoft.ML.TorchSharp.Loss;
2423
using Microsoft.ML.Transforms.Image;
2524
using static Microsoft.ML.TorchSharp.AutoFormerV2.ObjectDetectionTrainer;
2625
using Microsoft.ML.TorchSharp.AutoFormerV2;
27-
using Microsoft.ML.Tokenizers;
28-
using Microsoft.ML.TorchSharp.Extensions;
29-
using Microsoft.ML.TorchSharp.NasBert.Models;
30-
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
31-
using TorchSharp.Modules;
32-
using System.Text;
3326
using static Microsoft.ML.Data.AnnotationUtils;
3427

3528
[assembly: LoadableClass(typeof(ObjectDetectionTransformer), null, typeof(SignatureLoadModel),
@@ -503,7 +496,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
503496
}
504497
}
505498

506-
public class ObjectDetectionTransformer : RowToRowTransformerBase
499+
public class ObjectDetectionTransformer : RowToRowTransformerBase, IDisposable
507500
{
508501
private protected readonly Device Device;
509502
private protected readonly AutoFormerV2 Model;
@@ -522,6 +515,7 @@ public class ObjectDetectionTransformer : RowToRowTransformerBase
522515

523516
private static readonly FuncStaticMethodInfo1<object, Delegate> _decodeInitMethodInfo
524517
= new FuncStaticMethodInfo1<object, Delegate>(DecodeInit<int>);
518+
private bool _disposedValue;
525519

526520
internal ObjectDetectionTransformer(IHostEnvironment env, ObjectDetectionTrainer.Options options, AutoFormerV2 model, DataViewSchema.DetachedColumn labelColumn)
527521
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTransformer)))
@@ -992,5 +986,31 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
992986
return col => (activeOutput(0) || activeOutput(1) || activeOutput(2)) && _inputColIndices.Any(i => i == col);
993987
}
994988
}
989+
990+
protected virtual void Dispose(bool disposing)
991+
{
992+
if (!_disposedValue)
993+
{
994+
if (disposing)
995+
{
996+
}
997+
998+
Model.Dispose();
999+
_disposedValue = true;
1000+
}
1001+
}
1002+
1003+
~ObjectDetectionTransformer()
1004+
{
1005+
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
1006+
Dispose(disposing: false);
1007+
}
1008+
1009+
public void Dispose()
1010+
{
1011+
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
1012+
Dispose(disposing: true);
1013+
GC.SuppressFinalize(this);
1014+
}
9951015
}
9961016
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
9+
namespace Microsoft.ML.TorchSharp.NasBert
10+
{
11+
internal enum BertModelType
12+
{
13+
NasBert,
14+
Roberta
15+
}
16+
}

src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public enum BertTaskType
1616
None = 0,
1717
MaskedLM = 1,
1818
TextClassification = 2,
19-
SentenceRegression = 3
19+
SentenceRegression = 3,
20+
NameEntityRecognition = 4,
21+
QuestionAnswering = 5
2022
}
2123
}

src/Microsoft.ML.TorchSharp/NasBert/Models/BaseHead.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
6-
using System.Collections.Generic;
7-
using System.Text;
85
using TorchSharp;
96

107
namespace Microsoft.ML.TorchSharp.NasBert.Models

src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,22 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.Collections.Generic;
7-
using System.Linq;
8-
using System.Text;
9-
using Microsoft.ML.TorchSharp.Utils;
106
using TorchSharp;
117

128
namespace Microsoft.ML.TorchSharp.NasBert.Models
139
{
1410
internal abstract class BaseModel : torch.nn.Module<torch.Tensor, torch.Tensor, torch.Tensor>
1511
{
1612
protected readonly NasBertTrainer.NasBertOptions Options;
17-
public BertTaskType HeadType => Options.TaskType;
13+
public BertModelType EncoderType => Options.ModelType;
1814

19-
//public ModelType EncoderType => Options.ModelType;
15+
public BertTaskType HeadType => Options.TaskType;
2016

2117
#pragma warning disable CA1024 // Use properties where appropriate: Modules should be fields in TorchSharp
2218
public abstract TransformerEncoder GetEncoder();
19+
20+
public abstract BaseHead GetHead();
21+
2322
#pragma warning restore CA1024 // Use properties where appropriate
2423

2524
protected BaseModel(NasBertTrainer.NasBertOptions options)

0 commit comments

Comments
 (0)