Skip to content

Commit deb8241

Browse files
committed
feat(subtract): added Subtract function to wire
PR: google#382
1 parent e57deea commit deb8241

File tree

11 files changed

+385
-1
lines changed

11 files changed

+385
-1
lines changed

internal/wire/parse.go

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546546
case "NewSet":
547547
pset, errs := oc.processNewSet(info, pkgPath, call, nil, varName)
548548
return pset, notePositionAll(exprPos, errs)
549+
case "Subtract":
550+
pset, errs := oc.processSubtract(info, pkgPath, call, nil, varName)
551+
return pset, notePositionAll(exprPos, errs)
549552
case "Bind":
550553
b, err := processBind(oc.fset, info, call)
551554
if err != nil {
@@ -880,6 +883,116 @@ func isPrevented(tag string) bool {
880883
return reflect.StructTag(tag).Get("wire") == "-"
881884
}
882885

886+
func (oc *objectCache) processSubtract(info *types.Info, pkgPath string, call *ast.CallExpr, args *InjectorArgs, varName string) (interface{}, []error) {
887+
// Assumes that call.Fun is wire.Subtract.
888+
if len(call.Args) < 2 {
889+
return nil, []error{notePosition(oc.fset.Position(call.Pos()),
890+
errors.New("call to Subtract must specify types to be subtracted"))}
891+
}
892+
firstArg, errs := oc.processExpr(info, pkgPath, call.Args[0], "")
893+
if len(errs) > 0 {
894+
return nil, errs
895+
}
896+
set, ok := firstArg.(*ProviderSet)
897+
if !ok {
898+
return nil, []error{
899+
notePosition(oc.fset.Position(call.Pos()),
900+
fmt.Errorf("first argument to Subtract must be a Set")),
901+
}
902+
}
903+
pset := &ProviderSet{
904+
Pos: call.Pos(),
905+
InjectorArgs: args,
906+
PkgPath: pkgPath,
907+
VarName: varName,
908+
// Copy the other fields.
909+
Providers: set.Providers,
910+
Bindings: set.Bindings,
911+
Values: set.Values,
912+
Fields: set.Fields,
913+
Imports: set.Imports,
914+
}
915+
ec := new(errorCollector)
916+
for _, arg := range call.Args[1:] {
917+
ptr, ok := info.TypeOf(arg).(*types.Pointer)
918+
if !ok {
919+
ec.add(notePosition(oc.fset.Position(arg.Pos()),
920+
fmt.Errorf("argument to Subtract must be a pointer"),
921+
))
922+
continue
923+
}
924+
ec.add(oc.filterType(pset, ptr.Elem())...)
925+
}
926+
if len(ec.errors) > 0 {
927+
return nil, ec.errors
928+
}
929+
return pset, nil
930+
}
931+
932+
func (oc *objectCache) filterType(set *ProviderSet, t types.Type) []error {
933+
hasType := func(outs []types.Type) bool {
934+
for _, o := range outs {
935+
if types.Identical(o, t) {
936+
return true
937+
}
938+
pt, ok := o.(*types.Pointer)
939+
if ok && types.Identical(pt.Elem(), t) {
940+
return true
941+
}
942+
}
943+
return false
944+
}
945+
946+
providers := make([]*Provider, 0, len(set.Providers))
947+
for _, p := range set.Providers {
948+
if !hasType(p.Out) {
949+
providers = append(providers, p)
950+
}
951+
}
952+
set.Providers = providers
953+
954+
bindings := make([]*IfaceBinding, 0, len(set.Bindings))
955+
for _, i := range set.Bindings {
956+
if !types.Identical(i.Iface, t) {
957+
bindings = append(bindings, i)
958+
}
959+
}
960+
set.Bindings = bindings
961+
962+
values := make([]*Value, 0, len(set.Values))
963+
for _, v := range set.Values {
964+
if !types.Identical(v.Out, t) {
965+
values = append(values, v)
966+
}
967+
}
968+
set.Values = values
969+
970+
fields := make([]*Field, 0, len(set.Fields))
971+
for _, f := range set.Fields {
972+
if !hasType(f.Out) {
973+
fields = append(fields, f)
974+
}
975+
}
976+
set.Fields = fields
977+
978+
imports := make([]*ProviderSet, 0, len(set.Imports))
979+
for _, p := range set.Imports {
980+
clone := *p
981+
if errs := oc.filterType(&clone, t); len(errs) > 0 {
982+
return errs
983+
}
984+
imports = append(imports, &clone)
985+
}
986+
set.Imports = imports
987+
988+
var errs []error
989+
set.providerMap, set.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, set)
990+
if len(errs) > 0 {
991+
return errs
992+
}
993+
return nil
994+
}
995+
883996
// processBind creates an interface binding from a wire.Bind call.
884997
func processBind(fset *token.FileSet, info *types.Info, call *ast.CallExpr) (*IfaceBinding, error) {
885998
// Assumes that call.Fun is wire.Bind.
@@ -1122,7 +1235,6 @@ func findInjectorBuild(info *types.Info, fn *ast.FuncDecl) (*ast.CallExpr, error
11221235
default:
11231236
invalid = true
11241237
}
1125-
11261238
}
11271239
if wireBuildCall == nil {
11281240
return nil, nil
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"github.com/google/wire"
19+
)
20+
21+
type context struct{}
22+
23+
func main() {}
24+
25+
type (
26+
FooOptions struct{}
27+
Foo string
28+
Bar struct{}
29+
BarName string
30+
)
31+
32+
func (b *Bar) Bar() {}
33+
34+
func provideFooOptions() *FooOptions {
35+
return &FooOptions{}
36+
}
37+
38+
func provideFoo(*FooOptions) Foo {
39+
return Foo("foo")
40+
}
41+
42+
func provideBar(Foo, BarName) *Bar {
43+
return &Bar{}
44+
}
45+
46+
type BarService interface {
47+
Bar()
48+
}
49+
50+
type FooBar struct {
51+
BarService
52+
Foo
53+
}
54+
55+
var Set = wire.NewSet(
56+
provideFooOptions,
57+
provideFoo,
58+
provideBar,
59+
)
60+
61+
var SuperSet = wire.NewSet(Set,
62+
wire.Struct(new(FooBar), "*"),
63+
wire.Bind(new(BarService), new(*Bar)),
64+
)
65+
66+
type FakeBarService struct{}
67+
68+
func (f *FakeBarService) Bar() {}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build wireinject
16+
// +build wireinject
17+
18+
package main
19+
20+
import (
21+
"github.com/google/wire"
22+
)
23+
24+
func inject(name BarName, opts *FooOptions) *Bar {
25+
panic(wire.Build(wire.Subtract(Set, new(FooOptions))))
26+
}
27+
28+
func injectBarService(name BarName, opts *FakeBarService) *FooBar {
29+
panic(wire.Build(
30+
wire.Subtract(SuperSet, new(BarService)),
31+
wire.Bind(new(BarService), new(*FakeBarService)),
32+
))
33+
}
34+
35+
func injectFooBarService(name BarName, opts *FooOptions, bar *FakeBarService) *FooBar {
36+
panic(wire.Build(
37+
wire.Subtract(SuperSet, new(FooOptions), new(BarService)),
38+
wire.Bind(new(BarService), new(*FakeBarService)),
39+
))
40+
}
41+
42+
func injectNone(name BarName, foo Foo, bar *FakeBarService) *FooBar {
43+
panic(wire.Build(
44+
wire.Subtract(SuperSet, new(Foo), new(BarService)),
45+
wire.Bind(new(BarService), new(*FakeBarService)),
46+
))
47+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
example.com/foo
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

internal/wire/testdata/Subtract/want/wire_gen.go

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"github.com/google/wire"
19+
)
20+
21+
type context struct{}
22+
23+
func main() {}
24+
25+
type (
26+
FooOptions struct{}
27+
Foo string
28+
Bar struct{}
29+
BarName string
30+
)
31+
32+
func provideFooOptions() *FooOptions {
33+
return &FooOptions{}
34+
}
35+
36+
func provideFoo(*FooOptions) Foo {
37+
return Foo("foo")
38+
}
39+
40+
func provideBar(Foo, BarName) *Bar {
41+
return &Bar{}
42+
}
43+
44+
var Set = wire.NewSet(
45+
provideFooOptions,
46+
provideFoo,
47+
provideBar,
48+
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2018 The Wire Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//go:build wireinject
16+
// +build wireinject
17+
18+
package main
19+
20+
import (
21+
"github.com/google/wire"
22+
)
23+
24+
func injectMissArgs(opts *FooOptions) Foo {
25+
panic(wire.Build(wire.Subtract(provideFoo)))
26+
}
27+
28+
func injectNonSet(opts *FooOptions) Foo {
29+
panic(wire.Build(wire.Subtract(provideFoo, new(FooOptions))))
30+
}
31+
32+
func injectNonPointer(name BarName, opts *FooOptions) *Bar {
33+
panic(wire.Build(wire.Subtract(Set, FooOptions{})))
34+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
example.com/foo
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
example.com/foo/wire.go:x:y: call to Subtract must specify types to be subtracted
2+
3+
example.com/foo/wire.go:x:y: first argument to Subtract must be a Set
4+
5+
example.com/foo/wire.go:x:y: argument to Subtract must be a pointer

0 commit comments

Comments
 (0)