Control Flow in TensorFlow & XLA's Auto-Clustering
In this post we’ll look at an interesting issue that crops up when auto-clustering TensorFlow graphs. I’ve deliberately focused more on the problem than on the solution: the possible solutions are, in my opinion, fairly obvious once the problem is clear.
Control flow in TensorFlow
First we need a high-level overview of how control flow is represented in TensorFlow graphs.
The canonical reference for control flow in TensorFlow is “Dynamic control flow in large-scale machine learning”^{1}, but we’ll do a quick-and-dirty partial overview for this post.
Control Flow in Acyclic Graphs
TensorFlow represents computations as directed graphs where nodes are operations (e.g. matrix multiply) and edges are data flowing between operations (e.g. dense N-dimensional arrays). Data dependencies constrain the producer to execute before the consumer^{2}, and there may be control edges between nodes to further constrain their execution order. TensorFlow operations can have multiple outputs, and can have side effects.
TensorFlow graphs represent control flow via “deadness”. During execution some nodes can be “dead” which, roughly speaking, means they’re not executed^{3}. The vast majority of TensorFlow operations obey the following rules:
- A node is dead if any of its inputs are dead.
- If a node is dead, all of its outputs are dead.
- If a node is alive, all of its outputs are alive.
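To make these three rules concrete, here is a minimal Python sketch. `DEAD` and `run_normal_op` are hypothetical names used only for illustration; TensorFlow's executor tracks deadness internally rather than through a sentinel value like this.

```python
from typing import Callable, Sequence

# Hypothetical sentinel standing in for a "dead" value; TensorFlow tracks
# deadness internally and does not expose such an object.
DEAD = object()


def run_normal_op(fn: Callable, inputs: Sequence[object], num_outputs: int = 1):
    """Apply the deadness rules for an ordinary operation.

    If any input is dead, the op does not run and every output is dead.
    Otherwise the op runs and every output is alive.
    """
    if any(x is DEAD for x in inputs):
        return [DEAD] * num_outputs
    result = fn(*inputs)
    # Normalize single-output ops to a one-element list of outputs.
    return list(result) if num_outputs > 1 else [result]


print(run_normal_op(lambda a, b: a + b, [2, 3]))                # [5]
print(run_normal_op(lambda a, b: a + b, [2, DEAD])[0] is DEAD)  # True
```

A multi-output op such as `run_normal_op(divmod, [DEAD, 2], num_outputs=2)` yields two dead outputs, matching the second rule.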
Terminology: above and elsewhere by “alive” I simply mean “not dead”.
There are some special operations that break these rules (otherwise we’d only have trivial control flow): TensorFlow has a `Switch` operation which, very roughly speaking, fills the role of a “conditional branch”, and a `Merge` operation which, again very roughly speaking, is like a “phi” node^{4} ^{5}. In terms of deadness:
- `Switch` takes two inputs: a predicate `pred` and a value `value`. It has two outputs, which TensorFlow calls `output_false` and `output_true`:
  - If `pred` is false then `output_false` is `value` and `output_true` is dead.
  - If `pred` is true then `output_true` is `value` and `output_false` is dead.
  - If any of the inputs to the `Switch` itself are dead then all outputs are dead.
- `Merge` takes N inputs and propagates one of the live inputs to its output. If all the inputs are dead then the `Merge` produces a dead value.
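The same rules for `Switch` and `Merge` can be written as a small Python sketch. `DEAD`, `switch`, `merge` and `SwitchOutputs` are hypothetical names; only the output names `output_false` and `output_true` follow TensorFlow's `Switch`. TensorFlow's real `Merge` also reports which input it forwarded (its `value_index` output), which this sketch omits.

```python
from typing import NamedTuple, Sequence

# Hypothetical sentinel for a dead value (TensorFlow tracks deadness internally).
DEAD = object()


class SwitchOutputs(NamedTuple):
    output_false: object  # carries `value` when pred is false
    output_true: object   # carries `value` when pred is true


def switch(pred, value) -> SwitchOutputs:
    # A dead input kills both outputs.
    if pred is DEAD or value is DEAD:
        return SwitchOutputs(DEAD, DEAD)
    if pred:
        return SwitchOutputs(output_false=DEAD, output_true=value)
    return SwitchOutputs(output_false=value, output_true=DEAD)


def merge(inputs: Sequence[object]):
    # Forward a live input if there is one; otherwise the output is dead.
    for x in inputs:
        if x is not DEAD:
            return x
    return DEAD


out = switch(True, 42)
print(out.output_true, out.output_false is DEAD)  # 42 True
print(merge([DEAD, 7]))                            # 7
print(merge([DEAD, DEAD]) is DEAD)                 # True
```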
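For example, an if-then-else diamond that computes `Condition ? (X - 1) : (X + 1)` feeds X into a `Switch` predicated on Condition, computes X - 1 from the `Switch`’s true output and X + 1 from its false output, and joins the two results with a `Merge`.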
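Here is that diamond as a self-contained Python sketch; the helper names are hypothetical. Exactly one arm sees a live value, so exactly one input to the `Merge` is live and that is the value it forwards.

```python
DEAD = object()  # hypothetical stand-in for a dead value


def switch(pred, value):
    """Return (output_false, output_true); the unselected output is dead."""
    if pred is DEAD or value is DEAD:
        return DEAD, DEAD
    return (DEAD, value) if pred else (value, DEAD)


def unary(fn, x):
    """An ordinary op: dead in, dead out."""
    return DEAD if x is DEAD else fn(x)


def merge(*inputs):
    """Forward the live input, or produce a dead value if all are dead."""
    live = [x for x in inputs if x is not DEAD]
    return live[0] if live else DEAD


def diamond(condition, x):
    """Condition ? (x - 1) : (x + 1), expressed with Switch and Merge."""
    x_false, x_true = switch(condition, x)
    then_branch = unary(lambda v: v - 1, x_true)   # dead unless condition holds
    else_branch = unary(lambda v: v + 1, x_false)  # dead if condition holds
    return merge(then_branch, else_branch)


print(diamond(True, 10))   # 9
print(diamond(False, 10))  # 11
```

If `Condition` itself is dead, both arms die and so does the `Merge`, which is exactly the deadness rule for `Switch` inputs.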
Control Flow in Cyclic Graphs (a.k.a. Loops)
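Control flow in cyclic graphs is a straightforward extension of the above: `Merge` no longer needs all of its inputs to have executed before it executes; it just needs to see one live input, which it propagates to its output. Thus, a simple `for (i = 0; i < 10; i++)` loop becomes a cycle: a `Merge` receives the initial value of i as well as the incremented value fed back from the loop body, and a `Switch` predicated on i < 10 routes i either into the body (which computes i + 1 and feeds it back to the `Merge`) or out of the loop.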
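A rough simulation of such a loop in Python, assuming the simplified model described here: it ignores frames and the extra ops (such as `Enter`, `Exit` and `NextIteration`) that real TensorFlow loops use, and `run_loop` is a hypothetical name. The `Merge` fires on whichever input is live: the initial value on the first iteration, the back edge afterwards.

```python
DEAD = object()  # hypothetical stand-in for a dead value


def switch(pred, value):
    """Return (output_false, output_true); the unselected output is dead."""
    if pred is DEAD or value is DEAD:
        return DEAD, DEAD
    return (DEAD, value) if pred else (value, DEAD)


def run_loop(limit=10):
    """Simulate the cycle Merge -> (i < limit) -> Switch -> body -> Merge.

    Returns the value that leaves the loop and the values seen by the body.
    """
    initial = 0
    back_edge = None  # nothing has flowed around the cycle yet
    body_values = []
    while True:
        # Merge: forward whichever input is live (initial value first,
        # then the value fed back from the previous iteration).
        i = initial if back_edge is None else back_edge
        exit_value, body_value = switch(i < limit, i)
        if body_value is DEAD:
            # Predicate false: the body side is dead, nothing flows
            # around the back edge, and the loop terminates.
            return exit_value, body_values
        body_values.append(body_value)
        back_edge = body_value + 1  # the i++ in the loop body


print(run_loop(10))  # (10, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
```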
In reality things are more complicated because TensorFlow graphs have a concept of “frames”, but that’s not relevant for this post.
A Problem with XLA Clusters
XLA is an optimizing compiler for TensorFlow graphs, and one way (but not the only way) to use XLA is to have TensorFlow automatically invoke XLA on eligible TensorFlow subgraphs^{6}. This involves replacing these supported subgraphs with `XlaLaunch`^{7} operations which, when executed by the TensorFlow graph executor, JIT-compile the subgraph using XLA and invoke the resulting executable. This method is called “XLA auto-clustering”.
However, given what we’ve seen so far, auto-clustering can be problematic for graphs like these:
Legend: the nodes in blue boxes are all compilable by XLA, while the nodes in white ellipses are not. A, B, C, X and Y are “normal” TensorFlow operations that follow the simple deadness propagation rules mentioned above.
If we cluster the nodes A, B and C into a single XLA cluster (which feels natural) then the clustered graph will look like this:
In the clustered graph (i.e. the graph with the XLA cluster) both S and T are dead, while in the pre-clustered graph only T was dead. This follows directly from the rules above:
- In the pre-transform graph all inputs to A, B and X are live, and therefore all outputs from X are live.
- In the post-transform graph at least one input to XLA Cluster is dead, and therefore all outputs from XLA Cluster are dead.
This difference in deadness prevents some nodes in the clustered graph from executing even though they should have executed. In other words, this is a miscompile.
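The same effect shows up in miniature with a hypothetical two-node stand-in (not the exact graph from the figure): node A reads only live inputs and produces S, while node C reads a value that may be dead and produces T. Run separately, only T dies. Fused into one cluster, which behaves like a single ordinary op, S dies too. All names here are illustrative.

```python
DEAD = object()  # hypothetical stand-in for a dead value


def run_op(fn, inputs, num_outputs=1):
    """An ordinary op (or an XLA cluster, which behaves like one op):
    if any input is dead, every output is dead."""
    if any(x is DEAD for x in inputs):
        return (DEAD,) * num_outputs
    result = fn(*inputs)
    return result if num_outputs > 1 else (result,)


def unclustered(x, maybe_dead):
    # Node A reads only x; node C reads the possibly-dead value.
    (s,) = run_op(lambda v: v * 2, [x])           # A produces S
    (t,) = run_op(lambda v: v + 1, [maybe_dead])  # C produces T
    return s, t


def clustered(x, maybe_dead):
    # A and C fused into one cluster: a single op that takes both inputs.
    return run_op(lambda a, c: (a * 2, c + 1), [x, maybe_dead], num_outputs=2)


print(unclustered(3, DEAD)[0])        # 6: S is alive
print(clustered(3, DEAD)[0] is DEAD)  # True: S is now dead
```

When every input is live the two versions agree; they diverge only when some input to the cluster is dead, which is why the problem is easy to miss.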
Solution: Static Analysis
There are several ways of fixing this, but perhaps the most straightforward way is via a static analysis that can prove whether a TensorFlow node can be clustered safely. This static analysis maps each TensorFlow node to a predicate that is true if and only if the node is alive. For example, given the following graph, the predicate for Add will be “P0 & P1”:
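A minimal sketch of such an analysis over an acyclic graph, assuming a toy predicate representation (`True`, a symbol, `("not", p)`, `("and", ...)`, `("or", ...)`) and treating each `Switch` predicate as a live symbolic leaf. This is an illustration, not TensorFlow's implementation; all names are hypothetical. Ordinary ops AND their inputs' predicates, `Merge` ORs them, and a `Switch` adds its predicate (or its negation) to each output.

```python
from typing import Dict, List, Tuple, Union

# A predicate is True, a symbol such as "P0", ("not", p),
# ("and", frozenset_of_preds) or ("or", frozenset_of_preds).
Pred = Union[bool, str, Tuple]


def mk_and(*ps: Pred) -> Pred:
    """Conjunction that drops True and flattens nested ANDs."""
    terms = set()
    for p in ps:
        if p is True:
            continue
        if isinstance(p, tuple) and p[0] == "and":
            terms |= p[1]
        else:
            terms.add(p)
    if not terms:
        return True
    if len(terms) == 1:
        return next(iter(terms))
    return ("and", frozenset(terms))


def mk_or(*ps: Pred) -> Pred:
    """Disjunction that short-circuits on True and flattens nested ORs."""
    terms = set()
    for p in ps:
        if p is True:
            return True
        if isinstance(p, tuple) and p[0] == "or":
            terms |= p[1]
        else:
            terms.add(p)
    if len(terms) == 1:
        return next(iter(terms))
    return ("or", frozenset(terms))


def liveness(nodes: List[dict]) -> Dict[str, Pred]:
    """Map every output to the predicate under which it is alive.

    `nodes` must be in topological order (no loops in this sketch).
    Switch outputs are named "<name>:true" and "<name>:false".
    """
    live: Dict[str, Pred] = {}
    for node in nodes:
        name = node["name"]
        ins = [live[i] for i in node.get("inputs", [])]
        if node["kind"] == "switch":
            sym = node["pred"]  # symbolic leaf for the predicate's runtime value
            live[name + ":true"] = mk_and(*ins, sym)
            live[name + ":false"] = mk_and(*ins, ("not", sym))
        elif node["kind"] == "merge":
            live[name] = mk_or(*ins)
        else:  # ordinary op: alive iff all inputs are alive
            live[name] = mk_and(*ins)
    return live


def show(p: Pred) -> str:
    """Render a predicate deterministically (terms sorted)."""
    if p is True:
        return "True"
    if isinstance(p, str):
        return p
    op, arg = p
    if op == "not":
        return "~" + show(arg)
    sep = " & " if op == "and" else " | "
    return "(" + sep.join(sorted(show(t) for t in arg)) + ")"


graph = [
    {"kind": "op", "name": "x"},
    {"kind": "switch", "name": "s0", "inputs": ["x"], "pred": "P0"},
    {"kind": "switch", "name": "s1", "inputs": ["x"], "pred": "P1"},
    {"kind": "op", "name": "Add", "inputs": ["s0:true", "s1:true"]},
]
print(show(liveness(graph)["Add"]))  # (P0 & P1)
```

Note that a `Merge` of `s0:true` and `s0:false` gets the predicate `(P0 | ~P0)`, which this simplifier does not reduce to `True`.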
Using this analysis, we only cluster nodes that have identical liveness predicates. This ensures that all nodes in the cluster are either:
- All dead in the pre-transform graph, in which case it is correct to kill all the outputs from the cluster.
- All alive in the pre-transform graph, in which case it is correct to propagate a live value to all the outputs from the cluster.
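The clustering rule itself is tiny once predicates are available. In this hypothetical sketch a predicate is a conjunction of literals stored as a `frozenset` (`frozenset()` meaning “always alive”), and two predicates count as identical only if they are structurally equal.

```python
from typing import Dict, FrozenSet, Iterable

# Hypothetical representation for this sketch: a liveness predicate is a
# conjunction of literals, stored as a frozenset; frozenset() means "True".
Predicate = FrozenSet[str]


def can_cluster(candidates: Iterable[str], liveness: Dict[str, Predicate]) -> bool:
    """Allow a cluster only if all candidates have the same predicate.

    Then, in every execution, the candidates are either all dead or all
    alive, so killing (or keeping) every cluster output is correct.
    The comparison is structural, hence conservative.
    """
    return len({liveness[n] for n in candidates}) <= 1


liveness = {
    "A": frozenset(),       # always alive
    "B": frozenset(),       # always alive
    "C": frozenset({"P"}),  # alive only when P holds
}
print(can_cluster(["A", "B"], liveness))       # True
print(can_cluster(["A", "B", "C"], liveness))  # False
```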
Comparing liveness predicates is necessarily conservative: the “leaves” of the predicates can be symbolic (so predicates can’t always be simplified to True or False), which makes comparing predicates exactly NP-hard^{8}.
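To see why an exact comparison is expensive, here is a brute-force equivalence check, meant only as an illustration and not as what TensorFlow does. It enumerates every assignment of the symbolic leaves, so its cost doubles with each additional symbol. It does recognize that `P | ~P` is always true, which a purely structural comparison would miss.

```python
from itertools import product
from typing import Callable, Dict, Sequence

Env = Dict[str, bool]


def equivalent(p: Callable[[Env], bool], q: Callable[[Env], bool],
               symbols: Sequence[str]) -> bool:
    """Exact equivalence by trying all 2**len(symbols) assignments."""
    for values in product([False, True], repeat=len(symbols)):
        env = dict(zip(symbols, values))
        if bool(p(env)) != bool(q(env)):
            return False
    return True


def always(env: Env) -> bool:
    return True


def p_or_not_p(env: Env) -> bool:
    return env["P"] or not env["P"]


def p_and_q(env: Env) -> bool:
    return env["P"] and env["Q"]


print(equivalent(p_or_not_p, always, ["P"]))    # True
print(equivalent(p_and_q, always, ["P", "Q"]))  # False
```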
For simplicity, we implement the “all nodes have identical liveness” check a little differently: we implement it as “avoid clustering nodes that have inputs with possibly mismatching liveness”. This is equivalent to “all nodes have identical liveness” because XLA clusters are connected (but not strongly connected) and XLA does not support control flow operations like `Switch` and `Merge`.
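A sketch of that per-node check, again with hypothetical names and the same toy `frozenset` predicates. It shows only the liveness condition; the real pass also has other constraints, such as which ops XLA supports. Here node `c` mixes an always-alive input with the true output of a `Switch` on `P`, so it is excluded, while `a` and `b` remain candidates.

```python
from typing import Dict, FrozenSet, List, Set

# Hypothetical representation: liveness predicates are conjunctions of
# literals stored as frozensets (frozenset() means "always alive").
Predicate = FrozenSet[str]


def inputs_have_matching_liveness(node: str, inputs: Dict[str, List[str]],
                                  liveness: Dict[str, Predicate]) -> bool:
    """True if every input of `node` has the same (structural) predicate."""
    return len({liveness[i] for i in inputs[node]}) <= 1


def clusterable(inputs: Dict[str, List[str]], liveness: Dict[str, Predicate],
                compilable: Set[str]) -> Set[str]:
    """Compilable nodes whose inputs cannot have mismatching liveness."""
    return {n for n in inputs
            if n in compilable and inputs_have_matching_liveness(n, inputs, liveness)}


inputs = {
    "x": [],
    "s_true": ["x"],       # true output of a Switch on predicate P
    "a": ["x"],
    "b": ["a"],
    "c": ["b", "s_true"],  # mixes an always-alive input with a P-guarded one
}
liveness = {
    "x": frozenset(), "s_true": frozenset({"P"}),
    "a": frozenset(), "b": frozenset(), "c": frozenset({"P"}),
}
print(sorted(clusterable(inputs, liveness, {"a", "b", "c"})))  # ['a', 'b']
```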

1. Yu, Y., Abadi, M., Barham, P., Brevdo, E., Burrows, M., Davis, A., Dean, J., Ghemawat, S., Harley, T., Hawkins, P. and Isard, M., 2018, April. Dynamic control flow in large-scale machine learning. In Proceedings of the Thirteenth EuroSys Conference (p. 18). ACM.

2. This may seem insignificant, but it means that optimizations that break data dependencies, like `A * 0` => `0`, are not generally correct as-is over TensorFlow graphs: if A is dead then A * 0 is dead, whereas the constant 0 that replaces it would be live.
3. There are some exceptions to this, but they’re not important in the context of this post.

4. This is a very strained analogy for various reasons not relevant to this post.

5. There is also a `ControlTrigger` operation that produces a live output irrespective of whether its inputs are dead or not, but it isn’t relevant for this post.
6. Not all TensorFlow operations are supported by XLA, so, in general, some parts of the TensorFlow graph will still have to be executed by TensorFlow.

7. Things are going to get a bit more complicated soon to support “lazy compilation”, but this statement will still remain correct in essence.

8. We can translate a 3-SAT problem into the question “Can node X and node Y be clustered together?”, where X has a predicate equivalent to the 3-SAT formula and Y has the trivial predicate “True”. Similarly, we can create a 3-SAT formula asserting the equivalence of two predicates.