@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546
546
case "NewSet" :
547
547
pset , errs := oc .processNewSet (info , pkgPath , call , nil , varName )
548
548
return pset , notePositionAll (exprPos , errs )
549
+ case "Subtract" :
550
+ pset , errs := oc .processSubtract (info , pkgPath , call , nil , varName )
551
+ return pset , notePositionAll (exprPos , errs )
549
552
case "Bind" :
550
553
b , err := processBind (oc .fset , info , call )
551
554
if err != nil {
@@ -880,6 +883,116 @@ func isPrevented(tag string) bool {
880
883
return reflect .StructTag (tag ).Get ("wire" ) == "-"
881
884
}
882
885
886
+ func (oc * objectCache ) processSubtract (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (interface {}, []error ) {
887
+ // Assumes that call.Fun is wire.Subtract.
888
+ if len (call .Args ) < 2 {
889
+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
890
+ errors .New ("call to Subtract must specify types to be subtracted" ))}
891
+ }
892
+ firstArg , errs := oc .processExpr (info , pkgPath , call .Args [0 ], "" )
893
+ if len (errs ) > 0 {
894
+ return nil , errs
895
+ }
896
+ set , ok := firstArg .(* ProviderSet )
897
+ if ! ok {
898
+ return nil , []error {
899
+ notePosition (oc .fset .Position (call .Pos ()),
900
+ fmt .Errorf ("first argument to Subtract must be a Set" )),
901
+ }
902
+ }
903
+ pset := & ProviderSet {
904
+ Pos : call .Pos (),
905
+ InjectorArgs : args ,
906
+ PkgPath : pkgPath ,
907
+ VarName : varName ,
908
+ // Copy the other fields.
909
+ Providers : set .Providers ,
910
+ Bindings : set .Bindings ,
911
+ Values : set .Values ,
912
+ Fields : set .Fields ,
913
+ Imports : set .Imports ,
914
+ }
915
+ ec := new (errorCollector )
916
+ for _ , arg := range call .Args [1 :] {
917
+ ptr , ok := info .TypeOf (arg ).(* types.Pointer )
918
+ if ! ok {
919
+ ec .add (notePosition (oc .fset .Position (arg .Pos ()),
920
+ fmt .Errorf ("argument to Subtract must be a pointer" ),
921
+ ))
922
+ continue
923
+ }
924
+ ec .add (oc .filterType (pset , ptr .Elem ())... )
925
+ }
926
+ if len (ec .errors ) > 0 {
927
+ return nil , ec .errors
928
+ }
929
+ return pset , nil
930
+ }
931
+
932
+ func (oc * objectCache ) filterType (set * ProviderSet , t types.Type ) []error {
933
+ hasType := func (outs []types.Type ) bool {
934
+ for _ , o := range outs {
935
+ if types .Identical (o , t ) {
936
+ return true
937
+ }
938
+ pt , ok := o .(* types.Pointer )
939
+ if ok && types .Identical (pt .Elem (), t ) {
940
+ return true
941
+ }
942
+ }
943
+ return false
944
+ }
945
+
946
+ providers := make ([]* Provider , 0 , len (set .Providers ))
947
+ for _ , p := range set .Providers {
948
+ if ! hasType (p .Out ) {
949
+ providers = append (providers , p )
950
+ }
951
+ }
952
+ set .Providers = providers
953
+
954
+ bindings := make ([]* IfaceBinding , 0 , len (set .Bindings ))
955
+ for _ , i := range set .Bindings {
956
+ if ! types .Identical (i .Iface , t ) {
957
+ bindings = append (bindings , i )
958
+ }
959
+ }
960
+ set .Bindings = bindings
961
+
962
+ values := make ([]* Value , 0 , len (set .Values ))
963
+ for _ , v := range set .Values {
964
+ if ! types .Identical (v .Out , t ) {
965
+ values = append (values , v )
966
+ }
967
+ }
968
+ set .Values = values
969
+
970
+ fields := make ([]* Field , 0 , len (set .Fields ))
971
+ for _ , f := range set .Fields {
972
+ if ! hasType (f .Out ) {
973
+ fields = append (fields , f )
974
+ }
975
+ }
976
+ set .Fields = fields
977
+
978
+ imports := make ([]* ProviderSet , 0 , len (set .Imports ))
979
+ for _ , p := range set .Imports {
980
+ clone := * p
981
+ if errs := oc .filterType (& clone , t ); len (errs ) > 0 {
982
+ return errs
983
+ }
984
+ imports = append (imports , & clone )
985
+ }
986
+ set .Imports = imports
987
+
988
+ var errs []error
989
+ set .providerMap , set .srcMap , errs = buildProviderMap (oc .fset , oc .hasher , set )
990
+ if len (errs ) > 0 {
991
+ return errs
992
+ }
993
+ return nil
994
+ }
995
+
883
996
// processBind creates an interface binding from a wire.Bind call.
884
997
func processBind (fset * token.FileSet , info * types.Info , call * ast.CallExpr ) (* IfaceBinding , error ) {
885
998
// Assumes that call.Fun is wire.Bind.
@@ -1122,7 +1235,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
1122
1235
default :
1123
1236
invalid = true
1124
1237
}
1125
-
1126
1238
}
1127
1239
if wireBuildCall == nil {
1128
1240
return nil , nil
0 commit comments