Commit 78df120

danieldk authored and jikanter committed
Add spacy.TextCatParametricAttention.v1 (explosion#13201)
* Add spacy.TextCatParametricAttention.v1

  This layer is a simplification of the ensemble classifier that only uses parametric attention. We have found empirically that, with a sufficient amount of training data, using the ensemble classifier with BoW does not provide a significant improvement in classifier accuracy. However, plugging in a BoW classifier does reduce GPU training and inference performance substantially, since it uses a CPU-only kernel.

* Fix merge fallout
1 parent 01cd177 commit 78df120

5 files changed: +47 -5 lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
@@ -5,8 +5,9 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=9.0.0.dev4,<9.1.0",
-    "numpy>=1.15.0",
+    "thinc>=8.2.2,<8.3.0",
+    "numpy>=1.15.0; python_version < '3.9'",
+    "numpy>=1.25.0; python_version >= '3.9'",
 ]
 build-backend = "setuptools.build_meta"
 
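
The two `numpy` lines in the hunk above rely on environment markers: the installer evaluates `python_version` and keeps only the requirement whose marker matches. The selection can be sketched in plain Python as follows. `select_numpy_requirement` is a hypothetical helper written for illustration (it is not part of pip, setuptools, or spaCy), and it assumes only the two markers shown in the diff.

```python
import sys
from typing import Optional, Tuple


def select_numpy_requirement(version: Optional[Tuple[int, int]] = None) -> str:
    """Return the numpy requirement that applies to a given Python version.

    Hypothetical helper for illustration only; real installers evaluate
    the environment markers themselves.
    """
    if version is None:
        # Default to the interpreter running this script.
        version = (sys.version_info.major, sys.version_info.minor)
    # Marker: python_version < '3.9'
    if version < (3, 9):
        return "numpy>=1.15.0"
    # Marker: python_version >= '3.9'
    return "numpy>=1.25.0"


print(select_numpy_requirement((3, 8)))
print(select_numpy_requirement((3, 12)))
```

Comparing `(major, minor)` tuples avoids the string-comparison pitfall where `'3.10'` would sort before `'3.9'`.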

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ spacy-legacy>=4.0.0.dev0,<4.1.0
 spacy-loggers>=1.0.0,<2.0.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=9.0.0.dev4,<9.1.0
+thinc>=8.2.2,<8.3.0
 ml_datasets>=0.2.0,<0.3.0
 murmurhash>=0.28.0,<1.1.0
 wasabi>=0.9.1,<1.2.0

setup.cfg

Lines changed: 2 additions & 2 deletions
@@ -38,15 +38,15 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc>=9.0.0.dev4,<9.1.0
+    thinc>=8.2.2,<8.3.0
 install_requires =
     # Our libraries
     spacy-legacy>=4.0.0.dev0,<4.1.0
     spacy-loggers>=1.0.0,<2.0.0
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=9.0.0.dev4,<9.1.0
+    thinc>=8.2.2,<8.3.0
     wasabi>=0.9.1,<1.2.0
     srsly>=2.4.3,<3.0.0
     catalogue>=2.0.6,<2.1.0

spacy/tests/pipeline/test_textcat.py

Lines changed: 3 additions & 0 deletions
@@ -755,6 +755,9 @@ def test_overfitting_IO_multi():
         # CNN V2 (legacy)
         ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
         ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
+        # PARAMETRIC ATTENTION V1
+        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatParametricAttention.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
+        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatParametricAttention.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
         # REDUCE V1
         ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),
         ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatReduce.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False, "use_reduce_first": True, "use_reduce_last": True, "use_reduce_max": True, "use_reduce_mean": True}),

website/docs/api/architectures.mdx

Lines changed: 38 additions & 0 deletions
@@ -1056,6 +1056,44 @@ the others, but may not be as accurate, especially if texts are short.
 
 </Accordion>
 
+### spacy.TextCatParametricAttention.v1 {id="TextCatParametricAttention"}
+
+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.TextCatParametricAttention.v1"
+> exclusive_classes = true
+> nO = null
+>
+> [model.tok2vec]
+> @architectures = "spacy.Tok2Vec.v2"
+>
+> [model.tok2vec.embed]
+> @architectures = "spacy.MultiHashEmbed.v2"
+> width = 64
+> rows = [2000, 2000, 1000, 1000, 1000, 1000]
+> attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
+> include_static_vectors = false
+>
+> [model.tok2vec.encode]
+> @architectures = "spacy.MaxoutWindowEncoder.v2"
+> width = ${model.tok2vec.embed.width}
+> window_size = 1
+> maxout_pieces = 3
+> depth = 2
+> ```
+
+A neural network model that is built upon Tok2Vec and uses parametric attention
+to attend to tokens that are relevant to text classification.
+
+| Name                | Description                                                                                                                                                                                    |
+| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `tok2vec`           | The `tok2vec` layer to build the neural network upon. ~~Model[List[Doc], List[Floats2d]]~~                                                                                                     |
+| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
+| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
+| **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
+
 ### spacy.TextCatReduce.v1 {id="TextCatReduce"}
 
 > #### Example Config
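
The documentation added in this hunk describes the model as using parametric attention to attend to relevant tokens, with `exclusive_classes` choosing between mutually exclusive and independent labels. The commit does not show the layer's source, so the following is only a general illustration of attention-weighted pooling, not spaCy's or Thinc's implementation. All names (`attention_pool`, `classify`, `query`, `label_weights`) and numbers are made up for the example, and it assumes a single query vector with no hidden layers or key transforms.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(tokens: List[Vector], query: Vector) -> Vector:
    """Weight each token vector by softmax(token . query) and sum them."""
    scores = [sum(t * q for t, q in zip(tok, query)) for tok in tokens]
    weights = softmax(scores)
    width = len(query)
    return [sum(w * tok[i] for w, tok in zip(weights, tokens)) for i in range(width)]


def classify(pooled: Vector, label_weights: List[Vector], exclusive_classes: bool) -> Vector:
    """Score each label: softmax if labels are exclusive, else per-label sigmoid."""
    logits = [sum(p * w for p, w in zip(pooled, row)) for row in label_weights]
    if exclusive_classes:
        return softmax(logits)
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


# Toy document: three token vectors of width 2 (made-up numbers).
doc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = [2.0, 0.0]  # stands in for the learned attention parameter
label_weights = [[1.0, -1.0], [-1.0, 1.0]]  # one row per label

pooled = attention_pool(doc, query)
print(classify(pooled, label_weights, exclusive_classes=True))
print(classify(pooled, label_weights, exclusive_classes=False))
```

With `exclusive_classes=True` the scores sum to one across labels. With `False` each label gets an independent probability, which matches the single-label and multi-label cases exercised by the new test entries.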
