Skip to content

Commit e76a39d

Browse files
committed
Add if, elseif, and else.
1 parent 3d8b6dd commit e76a39d

File tree

4 files changed

+115
-4
lines changed

4 files changed

+115
-4
lines changed

graphkit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
# For backwards compatibility
1010
from .base import Operation
1111
from .network import Network
12+
from .control import If, ElseIf, Else

graphkit/base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, **kwargs):
5454
self.provides = kwargs.get('provides')
5555
self.params = kwargs.get('params', {})
5656
self.color = kwargs.get('color', None)
57+
self.order = 0
5758

5859
# call _after_init as final step of initialization
5960
self._after_init()
@@ -165,3 +166,23 @@ def __getstate__(self):
165166
state = Operation.__getstate__(self)
166167
state['net'] = self.__dict__['net']
167168
return state
169+
170+
171+
class Control(Operation):
172+
173+
def __init__(self, **kwargs):
174+
super(Control, self).__init__(**kwargs)
175+
176+
def __repr__(self):
177+
"""
178+
Display more informative names for the Operation class
179+
"""
180+
if hasattr(self, 'condition_needs'):
181+
return u"%s(name='%s', needs=%s, provides=%s, condition_needs=%s)" % \
182+
(self.__class__.__name__,
183+
self.name,
184+
self.needs,
185+
self.provides,
186+
self.condition_needs)
187+
else:
188+
return super(Control, self).__repr__()

graphkit/control.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
This sub-module contains statements that can be used for conditional evaluation of the graph.
3+
"""
4+
5+
from .base import Control
6+
from .functional import compose
7+
8+
9+
class If(Control):
10+
11+
def __init__(self, condition_needs, condition, **kwargs):
12+
super(If, self).__init__(**kwargs)
13+
self.condition_needs = condition_needs
14+
self.condition = condition
15+
self.order = 1
16+
17+
def __call__(self, *args):
18+
self.graph = compose(name=self.name)(*args)
19+
return self
20+
21+
def _compute_condition(self, named_inputs):
22+
inputs = [named_inputs[d] for d in self.condition_needs]
23+
return self.condition(*inputs)
24+
25+
def _compute(self, named_inputs):
26+
return self.graph(named_inputs)
27+
28+
29+
class ElseIf(If):
30+
31+
def __init__(self, condition_needs, condition, **kwargs):
32+
super(ElseIf, self).__init__(condition_needs, condition, **kwargs)
33+
self.order = 2
34+
35+
36+
class Else(Control):
37+
38+
def __init__(self, **kwargs):
39+
super(Else, self).__init__(**kwargs)
40+
self.order = 3
41+
42+
def __call__(self, *args):
43+
self.graph = compose(name=self.name)(*args)
44+
return self
45+
46+
def _compute(self, named_inputs):
47+
return self.graph(named_inputs)

graphkit/network.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from io import StringIO
99

10-
from .base import Operation
10+
from .base import Operation, Control
1111

1212

1313
class DataPlaceholderNode(str):
@@ -83,6 +83,10 @@ def add_op(self, operation):
8383
for p in operation.provides:
8484
self.graph.add_edge(operation, DataPlaceholderNode(p))
8585

86+
if isinstance(operation, Control) and hasattr(operation, 'condition_needs'):
87+
for n in operation.condition_needs:
88+
self.graph.add_edge(DataPlaceholderNode(n), operation)
89+
8690
# clear compiled steps (must recompile after adding new layers)
8791
self.steps = []
8892

@@ -97,6 +101,8 @@ def show_layers(self):
97101
print("\t", "needs: ", step.needs)
98102
print("\t", "provides: ", step.provides)
99103
print("\t", "color: ", step.color)
104+
if hasattr(step, 'condition_needs'):
105+
print("\t", "condition needs: ", step.condition_needs)
100106
print("")
101107

102108
def compile(self):
@@ -107,14 +113,37 @@ def compile(self):
107113
self.steps = []
108114

109115
# create an execution order such that each layer's needs are provided.
110-
ordered_nodes = list(nx.dag.topological_sort(self.graph))
116+
try:
117+
def key(node):
118+
119+
if hasattr(node, 'order'):
120+
return node.order
121+
elif isinstance(node, DataPlaceholderNode):
122+
return float('-inf')
123+
else:
124+
return 0
125+
126+
ordered_nodes = list(nx.dag.lexicographical_topological_sort(self.graph,
127+
key=key))
128+
except TypeError as e:
129+
if self._debug:
130+
print("Lexicographical topological sort failed! Falling back to topological sort.")
131+
132+
if not any(map(lambda node: isinstance(node, Control), self.graph.nodes)):
133+
ordered_nodes = list(nx.dag.topological_sort(self.graph))
134+
else:
135+
print("Topological sort failed!")
136+
raise e
111137

112138
# add Operations evaluation steps, and instructions to free data.
113139
for i, node in enumerate(ordered_nodes):
114140

115141
if isinstance(node, DataPlaceholderNode):
116142
continue
117143

144+
elif isinstance(node, Control):
145+
self.steps.append(node)
146+
118147
elif isinstance(node, Operation):
119148

120149
# add layer to list of steps
@@ -256,11 +285,24 @@ def compute(self, outputs, named_inputs, color=None):
256285
# Find the subset of steps we need to run to get to the requested
257286
# outputs from the provided inputs.
258287
all_steps = self._find_necessary_steps(outputs, named_inputs, color)
259-
288+
# import pdb
260289
self.times = {}
290+
if_true = False
261291
for step in all_steps:
262292

263-
if isinstance(step, Operation):
293+
if isinstance(step, Control):
294+
# pdb.set_trace()
295+
if hasattr(step, 'condition'):
296+
if_true = step._compute_condition(cache)
297+
if if_true:
298+
layer_outputs = step._compute(cache)
299+
cache.update(layer_outputs)
300+
elif not if_true:
301+
layer_outputs = step._compute(cache)
302+
cache.update(layer_outputs)
303+
if_true = False
304+
305+
elif isinstance(step, Operation):
264306

265307
if self._debug:
266308
print("-"*32)

0 commit comments

Comments
 (0)