In this post we’ll look at an interesting issue that crops up when auto-clustering TensorFlow graphs. I’ve deliberately focused more on the problem than on the solution – the possible solutions are, in my opinion, fairly obvious once the problem is clear.
Control flow in TensorFlow
First we need a high level overview of how control flow is represented in TensorFlow graphs.
The canonical reference to control flow in TensorFlow is “Dynamic control flow in large-scale machine learning”1 but we’ll do a quick ‘n dirty partial overview for this post.
Control Flow in Acyclic Graphs
TensorFlow represents computations as directed graphs where nodes are operations (e.g. matrix multiply) and edges are data flowing between operations (e.g. dense N-dimensional arrays). Data dependencies constrain the producer to execute before the consumer2 and there may be control edges between nodes to further constrain their execution order. TensorFlow operations can have multiple outputs, and can have side effects.
TensorFlow graphs represent control flow via “deadness”. During execution some nodes can be “dead” which, roughly speaking, means they’re not executed3. The vast majority of TensorFlow operations obey the following rules:
- A node is dead if any of its inputs are dead.
- If a node is dead, all of its outputs are dead.
- If a node is alive, all of its outputs are alive.
Terminology: above and elsewhere by “alive” I simply mean “not dead”.
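These three rules are simple enough to state as code. Here is a minimal Python sketch (a toy model for this post, not TensorFlow internals) of how liveness flows through an ordinary operation:

```python
# Toy model of the deadness rules above for an ordinary operation:
# the op, and hence every one of its outputs, is alive exactly when
# all of its inputs are alive.
def op_is_alive(input_liveness):
    """input_liveness: list of bools (True = alive), one per input edge."""
    return all(input_liveness)

assert op_is_alive([True, True])       # all inputs alive -> op and outputs alive
assert not op_is_alive([True, False])  # any dead input -> op and outputs dead
```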
There are some special operations that break these rules, since otherwise we would only have trivial control flow: TensorFlow has a Switch operation which, very roughly speaking, fills the role of a “conditional branch”, and a Merge operation which, again very roughly speaking, is like a “phi” node4 5. In terms of deadness:
- Switch takes two inputs, a predicate pred and a value value, and has two outputs:
  - If pred is false then the first output is dead and the second output is value.
  - If pred is true then the second output is dead and the first output is value.
  - If any of the inputs to the Switch itself are dead then all outputs are dead.
- Merge takes N inputs and propagates one of the live inputs to its output. If all the inputs are dead then the Merge produces a dead value.
For example, an if-then-else diamond that computes Condition ? (X - 1) : (X + 1) looks like:
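The Switch and Merge semantics above, and this diamond, can be emulated with a small Python toy model (the names switch, merge, sub_one, add_one and the use of None as a “dead” marker are illustrative assumptions, not TensorFlow APIs):

```python
DEAD = None  # stand-in for a dead value in this toy model

def switch(pred, value):
    """First output is live when pred is true, second when pred is false."""
    if pred is DEAD or value is DEAD:  # any dead input kills both outputs
        return DEAD, DEAD
    return (value, DEAD) if pred else (DEAD, value)

def merge(*inputs):
    """Propagate the first live input; dead only if every input is dead."""
    return next((v for v in inputs if v is not DEAD), DEAD)

def sub_one(v):  # an ordinary op: dead input -> dead output
    return DEAD if v is DEAD else v - 1

def add_one(v):
    return DEAD if v is DEAD else v + 1

def diamond(condition, x):
    """Condition ? (X - 1) : (X + 1) as a Switch/Merge diamond."""
    true_out, false_out = switch(condition, x)
    return merge(sub_one(true_out), add_one(false_out))

assert diamond(True, 10) == 9    # X - 1 on the true branch
assert diamond(False, 10) == 11  # X + 1 on the false branch
```

Note the use of `is not DEAD` rather than truthiness, so a live value of 0 is not mistaken for a dead one.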
Control Flow in Cyclic Graphs (a.k.a. Loops)
Control flow in cyclic graphs is a straightforward extension of the above: Merge no longer needs all of its inputs to have executed before it executes; it just needs to see one live input, which it propagates to its output. Thus, a simple for (i = 0; i < 10; i++) loop looks like this:
In reality things are more complicated because TensorFlow graphs have a concept of “frames”, but that’s not relevant for this post.
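The same toy model extends to the loop: Merge fires as soon as one live input arrives, the initial value on loop entry and the back edge thereafter. A hedged sketch, driven iteratively rather than by a dataflow executor and ignoring frames entirely:

```python
DEAD = None  # stand-in for a dead value in this toy model

def switch(pred, value):
    """First output is live when pred is true, second when pred is false."""
    if pred is DEAD or value is DEAD:
        return DEAD, DEAD
    return (value, DEAD) if pred else (DEAD, value)

def merge(*inputs):
    """Propagate the first live input; dead only if every input is dead."""
    return next((v for v in inputs if v is not DEAD), DEAD)

def loop():
    """Emulate for (i = 0; i < 10; i++) with Switch and Merge."""
    i = merge(0, DEAD)  # first iteration: only the initial value is live
    trace = []
    while True:
        body, exit_val = switch(i < 10, i)
        if body is DEAD:        # predicate false: the loop body is dead
            return exit_val, trace
        trace.append(body)
        i = merge(DEAD, body + 1)  # back edge: only the incremented value is live

value, trace = loop()
assert value == 10
assert trace == list(range(10))
```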
A Problem with XLA Clusters
XLA is an optimizing compiler for TensorFlow graphs, and one way (but not the only way) to use XLA is by having TensorFlow automatically invoke XLA on eligible TensorFlow subgraphs6. This involves replacing these supported subgraphs with
XlaLaunch7 operations which, when executed by the TensorFlow graph executor, JIT compiles the subgraph using XLA and invokes the resultant executable. This method is called “XLA auto-clustering”.
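For reference, assuming a TensorFlow build with XLA support, auto-clustering is typically switched on globally via the TF_XLA_FLAGS environment variable (my_model.py is a placeholder script name, not something from this post):

```shell
# Ask TensorFlow to auto-cluster eligible subgraphs into XLA clusters.
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python my_model.py
```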
However, given what we’ve seen so far, auto-clustering can be problematic for graphs like these:
Legend: the nodes in blue boxes are all compilable by XLA while the nodes in white ellipses are not. A, B, C, X, Y are “normal” TensorFlow operations that follow the simple deadness propagation rules mentioned above.
If we cluster the nodes A, B and C into a single XLA cluster (which feels natural) then the clustered graph will look like this:
In the clustered graph (i.e. in the graph with the XLA cluster) both S and T are dead while in the pre-clustered graph only T was dead. This follows directly from the rules above:
- In the pre-transform graph all inputs to A, B and X are live, and therefore all outputs from X are live.
- In the post-transform graph at least one input to XLA Cluster is dead, and therefore all outputs from XLA Cluster are dead.
This difference in deadness will cause some nodes in the clustered graph not to execute even though they should have been executed. In other words, this is a miscompile.
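Continuing the toy deadness model, the miscompile can be shown concretely: clustering pools the inputs of A, B and C, so B’s dead input now poisons S as well. (The topology here is an assumption matching the figure described above, where A feeds S and B feeds T.)

```python
DEAD = None  # stand-in for a dead value in this toy model

def normal_op(*inputs):
    """Ordinary op: dead iff any input is dead; returns its single output."""
    return DEAD if any(v is DEAD for v in inputs) else "live"

# Pre-transform graph: A sees only live inputs, B sees a dead one.
S = normal_op("live")  # A's output: live
T = normal_op(DEAD)    # B's output: dead

# Post-transform graph: the XLA cluster {A, B, C} sees all of their
# inputs, including B's dead one, so *every* cluster output is dead.
cluster_output_S = normal_op("live", DEAD)
cluster_output_T = normal_op("live", DEAD)

assert S == "live" and T is DEAD                 # only T dead before clustering
assert cluster_output_S is DEAD                  # miscompile: S's liveness changed
assert cluster_output_T is DEAD
```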
Solution: Static Analysis
There are several ways of fixing this, but perhaps the most straightforward way is via a static analysis that can prove whether a TensorFlow node can be clustered safely. This static analysis maps each TensorFlow node to a predicate that is true if and only if the node is alive. For example, given the following graph, the predicate for Add will be “P0 & P1”:
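A sketch of that predicate assignment, assuming a graph in which the live outputs of two Switch nodes guarded by symbols P0 and P1 feed Add (a simplification that treats predicates as syntactic strings and ignores deadness flowing into the Switch itself):

```python
def predicate(node, graph):
    """graph maps node -> (kind, inputs); returns a formula string that is
    true iff the node is alive (toy model, no predicate simplification)."""
    kind, inputs = graph[node]
    if kind == "switch_live_output":
        # The live output of a Switch is alive exactly when its guarding
        # predicate symbol holds.
        (symbol,) = inputs
        return symbol
    # Ordinary op: alive iff every one of its inputs is alive.
    return " & ".join(predicate(i, graph) for i in inputs)

graph = {
    "Sw0_out": ("switch_live_output", ["P0"]),
    "Sw1_out": ("switch_live_output", ["P1"]),
    "Add": ("op", ["Sw0_out", "Sw1_out"]),
}
assert predicate("Add", graph) == "P0 & P1"
```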
Using this analysis we only cluster nodes that have identical liveness predicates. This ensures that all nodes in the cluster are either
- All dead in the pre-transform graph, in which case it is correct to kill all the outputs from the cluster.
- All alive in the pre-transform graph, in which case it is correct to propagate a live value to all the outputs from the cluster.
Comparing liveness predicates is necessarily conservative – the “leaves” of the predicates can be symbolic (so predicates can’t always be simplified to True or False) which makes comparing predicates NP-complete8.
For simplicity we implement the “all nodes have identical liveness” check a little differently – we implement it as “avoid clustering nodes that have inputs with possibly mismatching liveness”. This is equivalent to “all nodes have identical liveness” because XLA clusters are connected (but not strongly connected) and XLA does not support control flow operations like Switch and Merge.
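The conservative check can then be as simple as a syntactic comparison of the input predicates, trading completeness for tractability (a sketch of the idea, not the actual TensorFlow implementation):

```python
def safe_to_cluster(input_predicates):
    """input_predicates: liveness-predicate strings for every input feeding
    the candidate cluster.  Syntactic equality keeps the check cheap, at the
    cost of conservatively rejecting semantically equivalent predicates."""
    return len(set(input_predicates)) <= 1

assert safe_to_cluster(["P0 & P1", "P0 & P1"])      # identical -> safe
assert not safe_to_cluster(["P0 & P1", "P1 & P0"])  # equivalent, but rejected
```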
Yu, Y., Abadi, M., Barham, P., Brevdo, E., Burrows, M., Davis, A., Dean, J., Ghemawat, S., Harley, T., Hawkins, P. and Isard, M., 2018, April. Dynamic control flow in large-scale machine learning. In Proceedings of the Thirteenth EuroSys Conference (p. 18). ACM. ↩
This may seem insignificant but it means that optimizations that break data dependencies, like A * 0 => 0, are not generally correct as-is over TensorFlow graphs. ↩
There are some exceptions to this, but they’re not important in the context of this post. ↩
This is a very strained analogy for various reasons not relevant to this post. ↩
There is also a ControlTrigger operation that produces a live output irrespective of whether its inputs are dead or not, but it isn’t relevant for this post. ↩
Not all TensorFlow operations are supported by XLA, so, in general, some parts of the TensorFlow graph will still have to be executed by TensorFlow. ↩
Things are going to get a bit more complicated soon to support “lazy compilation” but this statement will still remain correct in essence. ↩
We can translate a 3SAT problem into the question “Can node X and node Y be clustered together?”, where X has a predicate equivalent to the 3SAT formula and Y has the trivial predicate “True”. Similarly, we can create a 3SAT instance asserting the equivalence of two predicates. ↩