7
7
8
8
from io import StringIO
9
9
10
- from .base import Operation
10
+ from .base import Operation , Control
11
11
12
12
13
13
class DataPlaceholderNode (str ):
@@ -83,6 +83,10 @@ def add_op(self, operation):
83
83
for p in operation .provides :
84
84
self .graph .add_edge (operation , DataPlaceholderNode (p ))
85
85
86
+ if isinstance (operation , Control ) and hasattr (operation , 'condition_needs' ):
87
+ for n in operation .condition_needs :
88
+ self .graph .add_edge (DataPlaceholderNode (n ), operation )
89
+
86
90
# clear compiled steps (must recompile after adding new layers)
87
91
self .steps = []
88
92
@@ -97,6 +101,8 @@ def show_layers(self):
97
101
print ("\t " , "needs: " , step .needs )
98
102
print ("\t " , "provides: " , step .provides )
99
103
print ("\t " , "color: " , step .color )
104
+ if hasattr (step , 'condition_needs' ):
105
+ print ("\t " , "condition needs: " , step .condition_needs )
100
106
print ("" )
101
107
102
108
def compile (self ):
@@ -107,14 +113,37 @@ def compile(self):
107
113
self .steps = []
108
114
109
115
# create an execution order such that each layer's needs are provided.
110
- ordered_nodes = list (nx .dag .topological_sort (self .graph ))
116
+ try :
117
+ def key (node ):
118
+
119
+ if hasattr (node , 'order' ):
120
+ return node .order
121
+ elif isinstance (node , DataPlaceholderNode ):
122
+ return float ('-inf' )
123
+ else :
124
+ return 0
125
+
126
+ ordered_nodes = list (nx .dag .lexicographical_topological_sort (self .graph ,
127
+ key = key ))
128
+ except TypeError as e :
129
+ if self ._debug :
130
+ print ("Lexicographical topological sort failed! Falling back to topological sort." )
131
+
132
+ if not any (map (lambda node : isinstance (node , Control ), self .graph .nodes )):
133
+ ordered_nodes = list (nx .dag .topological_sort (self .graph ))
134
+ else :
135
+ print ("Topological sort failed!" )
136
+ raise e
111
137
112
138
# add Operations evaluation steps, and instructions to free data.
113
139
for i , node in enumerate (ordered_nodes ):
114
140
115
141
if isinstance (node , DataPlaceholderNode ):
116
142
continue
117
143
144
+ elif isinstance (node , Control ):
145
+ self .steps .append (node )
146
+
118
147
elif isinstance (node , Operation ):
119
148
120
149
# add layer to list of steps
@@ -256,11 +285,24 @@ def compute(self, outputs, named_inputs, color=None):
256
285
# Find the subset of steps we need to run to get to the requested
257
286
# outputs from the provided inputs.
258
287
all_steps = self ._find_necessary_steps (outputs , named_inputs , color )
259
-
288
+ # import pdb
260
289
self .times = {}
290
+ if_true = False
261
291
for step in all_steps :
262
292
263
- if isinstance (step , Operation ):
293
+ if isinstance (step , Control ):
294
+ # pdb.set_trace()
295
+ if hasattr (step , 'condition' ):
296
+ if_true = step ._compute_condition (cache )
297
+ if if_true :
298
+ layer_outputs = step ._compute (cache )
299
+ cache .update (layer_outputs )
300
+ elif not if_true :
301
+ layer_outputs = step ._compute (cache )
302
+ cache .update (layer_outputs )
303
+ if_true = False
304
+
305
+ elif isinstance (step , Operation ):
264
306
265
307
if self ._debug :
266
308
print ("-" * 32 )
0 commit comments