XLA and non-XLA may differ by small floating-point tolerances, but under the same model, seed, input tensor, optimizer, precision policy, and two-worker MultiWorkerMirroredStrategy, the final loss and gradient norm should not diverge by this magnitude after only 20 steps.
Because both workers agree with each other inside each arm, the difference appears specifically between the jit_compile=False and jit_compile=True forward/backward paths.
## minimal_repro_mwms_xla.py
#!/usr/bin/env python3
"""Minimal two-worker TensorFlow reproducer for MultiWorkerMirroredStrategy (MWMS) XLA/non-XLA divergence.

Run this same file on two GPU machines at approximately the same time.

Machine 0:

    python minimal_repro_mwms_xla.py \
        --config tensorflow_official_repro_config.json \
        --worker-index 0 \
        --worker-addresses host0:61120,host1:62120

Machine 1:

    python minimal_repro_mwms_xla.py \
        --config tensorflow_official_repro_config.json \
        --worker-index 1 \
        --worker-addresses host0:61120,host1:62120
"""
from __future__ import annotations
import argparse
import json
import math
import os
import time
from pathlib import Path
from typing import Any
# Divergence thresholds: a metric is flagged when
#     |nonxla - xla| > ATOL + RTOL * max(|nonxla|, |xla|)
LOSS_ATOL: float = 0.15  # absolute slack on the final loss
LOSS_RTOL: float = 0.10  # relative slack on the final loss
GRAD_ATOL: float = 1.0  # absolute slack on the final gradient norm
GRAD_RTOL: float = 0.20  # relative slack on the final gradient norm
def parse_args() -> argparse.Namespace:
    """Parse the command line for one worker of the reproducer.

    Flags are registered in a fixed order from a declarative table; each entry is
    ``(flag, add_argument keyword options)``.

    Returns:
        The parsed namespace (``--worker-index`` and ``--worker-addresses`` are
        mandatory; everything else has a default).
    """
    argument_table: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--config", {"default": None, "help": "Optional JSON config for runtime parameters."}),
        ("--worker-index", {"type": int, "required": True}),
        ("--worker-addresses", {"required": True, "help": "host0:port0,host1:port1"}),
        ("--cuda-visible-devices", {"default": "0"}),
        ("--steps", {"type": int, "default": 20}),
        ("--per-replica-batch-size", {"type": int, "default": 2}),
        ("--input-size", {"type": int, "default": 224}),
        ("--num-classes", {"type": int, "default": 10}),
        ("--seed", {"type": int, "default": 42}),
        ("--learning-rate", {"type": float, "default": 0.1}),
        ("--weight-decay", {"type": float, "default": 0.01}),
        ("--precision-policy", {"default": "float32"}),
        ("--communication", {"choices": ["RING", "AUTO"], "default": "AUTO"}),
        ("--intra-op-threads", {"type": int, "default": 2}),
        ("--inter-op-threads", {"type": int, "default": 1}),
        ("--fail-on-divergence", {"action": "store_true"}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in argument_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def apply_config(args: argparse.Namespace) -> dict[str, Any]:
    """Overlay runtime parameters from the optional JSON config onto ``args``.

    Values present in the config take precedence over command-line values.
    ``args`` is mutated in place, and ``args.config`` is rewritten to the
    resolved absolute path of the config file.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        The parsed config dictionary, or ``{}`` when no ``--config`` was given.

    Raises:
        SystemExit: if the config file is not a JSON object, or if
            ``input_shape_per_replica`` describes a non-square image (the model
            is always built with square ``input_size x input_size`` inputs).
    """
    if not args.config:
        return {}
    path = Path(args.config).expanduser().resolve()
    cfg = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(cfg, dict):
        raise SystemExit(
            f"{path}: config must be a JSON object, got {type(cfg).__name__}"
        )
    args.config = str(path)

    args.num_classes = int(cfg.get("num_classes", args.num_classes))
    args.seed = int(cfg.get("seed", args.seed))
    args.learning_rate = float(cfg.get("learning_rate", args.learning_rate))
    args.weight_decay = float(cfg.get("weight_decay", args.weight_decay))
    args.precision_policy = str(cfg.get("precision_policy", args.precision_policy))

    # 0 is a meaningful TF thread-pool size ("let the runtime choose"), so only a
    # missing or null entry falls back to the command-line value; a truthiness
    # fallback (`value or default`) would silently discard an explicit 0.
    intra_threads = cfg.get("intra_op_parallelism_threads")
    if intra_threads is not None:
        args.intra_op_threads = int(intra_threads)
    inter_threads = cfg.get("inter_op_parallelism_threads")
    if inter_threads is not None:
        args.inter_op_threads = int(inter_threads)

    shape = cfg.get("input_shape_per_replica")
    if isinstance(shape, list) and len(shape) >= 3:
        # Layout is [batch, height, width, channels], matching the tensor the
        # training loop generates.
        height, width = int(shape[1]), int(shape[2])
        if height != width:
            raise SystemExit(
                f"{path}: input_shape_per_replica must be square, got {shape}"
            )
        args.per_replica_batch_size = int(shape[0])
        args.input_size = height
    else:
        args.per_replica_batch_size = int(
            cfg.get("per_replica_batch_size", args.per_replica_batch_size)
        )

    # The original training config may contain unrelated fields such as
    # max_steps=5. This reduced differential reproducer ignores them and keeps
    # the --steps value (20 by default) unless the config explicitly provides
    # steps_per_arm.
    if "steps_per_arm" in cfg:
        args.steps = int(cfg["steps_per_arm"])

    communication = str(
        cfg.get("communication")
        or cfg.get("communication_implementation")
        or args.communication
    ).upper()
    # Unsupported values are ignored and the command-line choice is kept.
    if communication in {"RING", "AUTO"}:
        args.communication = communication
    return cfg
def config_summary(args: argparse.Namespace, cfg: dict[str, Any]) -> dict[str, Any]:
    """Pick the config fields worth echoing in the ``environment`` log event.

    Args:
        args: Namespace whose ``config`` attribute holds the config file path.
        cfg: Parsed config dictionary (may be empty).

    Returns:
        ``{}`` when ``cfg`` is empty; otherwise the config file's base name plus
        selected fields (``None`` for fields the config does not define).
    """
    if not cfg:
        return {}
    summary: dict[str, Any] = {"config_file": Path(str(args.config)).name}
    for field in ("device", "distributed_strategy", "tf_strategy_class"):
        summary[field] = cfg.get(field)
    # Either spelling of the communication key is accepted; the long one wins.
    summary["communication_implementation"] = (
        cfg.get("communication_implementation") or cfg.get("communication")
    )
    for field in (
        "precision_policy",
        "optimizer",
        "seed",
        "per_replica_batch_size",
        "input_shape_per_replica",
        "learning_rate",
        "weight_decay",
        "steps_per_arm",
        "run_eagerly",
    ):
        summary[field] = cfg.get(field)
    return summary
def runtime_configuration(args: argparse.Namespace, workers: list[str]) -> dict[str, Any]:
    """Describe the effective run settings for the ``environment`` log event.

    Args:
        args: Runtime parameters after command-line parsing and config overlay.
        workers: The ``host:port`` worker addresses of the cluster.

    Returns:
        A JSON-serialisable dictionary of the settings both arms run with,
        including the divergence thresholds applied afterwards.
    """
    batch = int(args.per_replica_batch_size)
    size = int(args.input_size)
    num_classes = int(args.num_classes)
    return {
        "strategy": "MultiWorkerMirroredStrategy",
        "workers": workers,
        "num_workers": len(workers),
        # Assumed, not probed: the caller only checks that at least one GPU is
        # visible (CUDA_VISIBLE_DEVICES defaults to a single device).
        "gpus_per_worker": 1,
        "communication": str(args.communication),
        # Report the classifier width actually built rather than a fixed 10.
        "model": (
            "ResNet50(weights=None, include_top=False, pooling='avg') "
            f"+ Dense({num_classes})"
        ),
        "input": "synthetic normal tensor generated from a fixed seed",
        "steps_per_arm": int(args.steps),
        "per_replica_batch_size": batch,
        "effective_global_batch_size": batch * len(workers),
        "input_shape_per_replica": [batch, size, size, 3],
        "num_classes": num_classes,
        "seed": int(args.seed),
        "optimizer": "Adam",
        "learning_rate": float(args.learning_rate),
        "weight_decay": float(args.weight_decay),
        "precision_policy": str(args.precision_policy),
        "threading": {
            "intra_op_parallelism_threads": int(args.intra_op_threads),
            "inter_op_parallelism_threads": int(args.inter_op_threads),
        },
        "arms": [
            {"name": "nonxla", "forward_backward_jit_compile": False},
            {"name": "xla", "forward_backward_jit_compile": True},
        ],
        "optimizer_step_jit_compile": False,
        "comparison_thresholds": {
            "loss_atol": LOSS_ATOL,
            "loss_rtol": LOSS_RTOL,
            "grad_atol": GRAD_ATOL,
            "grad_rtol": GRAD_RTOL,
        },
    }
def setup_environment(args: argparse.Namespace) -> list[str]:
    """Validate the cluster arguments and export the process environment for TF.

    Must run before TensorFlow is imported: it sets ``CUDA_VISIBLE_DEVICES``
    (unconditionally), defaults ``TF_FORCE_GPU_ALLOW_GROWTH`` and
    ``TF_CPP_MIN_LOG_LEVEL`` (only if unset), and writes ``TF_CONFIG`` describing
    a two-worker cluster with this process as ``worker[args.worker_index]``.

    Args:
        args: Namespace with ``worker_addresses``, ``worker_index`` and
            ``cuda_visible_devices``.

    Returns:
        The list of stripped ``host:port`` worker addresses.

    Raises:
        SystemExit: if there are not exactly two addresses, an address is not a
            ``host:port`` pair with a valid port, the addresses are duplicated, or
            the worker index is not 0 or 1.
    """
    workers = [part.strip() for part in args.worker_addresses.split(",") if part.strip()]
    if len(workers) != 2:
        raise SystemExit("--worker-addresses must contain exactly two host:port entries")
    for address in workers:
        # rpartition keeps bracketed IPv6 hosts such as "[::1]:1234" intact.
        host, sep, port = address.rpartition(":")
        if not (sep and host and port.isascii() and port.isdigit() and 0 < int(port) < 65536):
            raise SystemExit(f"--worker-addresses entry {address!r} is not a valid host:port pair")
    if len(set(workers)) != len(workers):
        raise SystemExit("--worker-addresses entries must be distinct")
    if args.worker_index not in (0, 1):
        raise SystemExit("--worker-index must be 0 or 1")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_visible_devices)
    os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "1")
    os.environ["TF_CONFIG"] = json.dumps(
        {
            "cluster": {"worker": workers},
            "task": {"type": "worker", "index": int(args.worker_index)},
        },
        sort_keys=True,
    )
    return workers
def make_model(tf: Any, input_size: int, num_classes: int) -> Any:
    """Build a randomly initialised ResNet50 backbone with a linear head.

    Args:
        tf: The imported ``tensorflow`` module.
        input_size: Height and width of the square RGB input.
        num_classes: Number of output logits.

    Returns:
        A functional ``tf.keras.Model`` mapping images to unnormalised logits.
    """
    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        weights=None,
        pooling="avg",
        input_shape=(input_size, input_size, 3),
    )
    classifier = tf.keras.layers.Dense(num_classes)
    logits = classifier(backbone.output)
    return tf.keras.Model(backbone.input, logits)
def run_arm(tf: Any, strategy: Any, args: argparse.Namespace, *, jit_compile: bool) -> dict[str, float]:
    """Train a freshly initialised model for ``args.steps`` steps under ``strategy``.

    The global Keras seed is reset first, so every call builds the same initial
    weights and sees the same synthetic batches. Only ``jit_compile``, which is
    applied to the forward/backward function and not to the optimizer update,
    distinguishes one arm from another.

    Args:
        tf: The imported ``tensorflow`` module (imported lazily by the caller).
        strategy: The distribution strategy to build and train the model under.
        args: Runtime parameters (seed, shapes, optimizer settings, step count).
        jit_compile: Forwarded to ``tf.function`` for the forward/backward pass.

    Returns:
        ``{"loss": ..., "grad_norm": ...}`` measured during the final step,
        before that step's optimizer update.

        - ``loss`` is the global mean per-example loss. Per-replica values from
          ``tf.nn.compute_average_loss`` are already divided by the global batch
          size, so they are SUM-reduced.
        - ``grad_norm`` is the mean over replicas of each replica's local gradient
          norm, computed before cross-replica aggregation.
        - Both values are NaN when ``args.steps`` is not positive.
    """
    # Reset the global RNGs so each arm starts from identical initial weights.
    tf.keras.utils.set_random_seed(int(args.seed))
    tf.keras.mixed_precision.set_global_policy(str(args.precision_policy))
    with strategy.scope():
        model = make_model(tf, int(args.input_size), int(args.num_classes))
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=float(args.learning_rate),
            weight_decay=float(args.weight_decay),
        )
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )

    def grad_fn(x, y):
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            per_example_loss = loss_fn(y, logits)
            # Divides by the *global* batch size; per-replica results therefore
            # have to be summed across replicas to recover the mean loss.
            loss = tf.nn.compute_average_loss(per_example_loss)
        grads = tape.gradient(loss, model.trainable_variables)
        grad_norm = tf.linalg.global_norm([g for g in grads if g is not None])
        return loss, grads, grad_norm

    # Only the forward/backward pass is (optionally) XLA-compiled; the optimizer
    # update in step_fn always runs outside this function.
    compiled_grad_fn = tf.function(grad_fn, jit_compile=bool(jit_compile))

    def step_fn(x, y):
        loss, grads, grad_norm = compiled_grad_fn(x, y)
        # Hand apply_gradients a concrete list rather than a one-shot generator.
        grads_and_vars = [
            (g, v) for g, v in zip(grads, model.trainable_variables) if g is not None
        ]
        optimizer.apply_gradients(grads_and_vars)
        return loss, grad_norm

    batch = int(args.per_replica_batch_size)
    input_size = int(args.input_size)
    num_classes = int(args.num_classes)
    per_replica_loss = per_replica_grad_norm = None
    for step in range(int(args.steps)):
        # Seeded per step and independent of the worker index, so every replica
        # receives the same batch, and both arms see the same sequence.
        generator = tf.random.Generator.from_seed(int(args.seed) + step)
        x = generator.normal((batch, input_size, input_size, 3))
        y = tf.range(batch, dtype=tf.int32) % num_classes
        per_replica_loss, per_replica_grad_norm = strategy.run(step_fn, args=(x, y))

    if per_replica_loss is None:
        return {"loss": float("nan"), "grad_norm": float("nan")}
    # Only the final step is reported, so reduce once here instead of issuing
    # discarded cross-worker reductions on every step.
    final_loss = float(
        strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None).numpy()
    )
    final_grad_norm = float(
        strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norm, axis=None).numpy()
    )
    return {"loss": final_loss, "grad_norm": final_grad_norm}
def _metric_diverges(a: float, b: float, tolerance: float) -> bool:
    """Return True when two arm results for one metric disagree beyond ``tolerance``.

    Finite pairs are compared by absolute difference. A finite value paired with a
    NaN or infinity counts as divergent, because one arm blew up while the other
    did not. Two non-finite values are treated as agreeing, since both arms failed.
    """
    if math.isfinite(a) and math.isfinite(b):
        return abs(a - b) > tolerance
    return math.isfinite(a) != math.isfinite(b)


def compare(
    nonxla: dict[str, float],
    xla: dict[str, float],
    *,
    loss_atol: float | None = None,
    loss_rtol: float | None = None,
    grad_atol: float | None = None,
    grad_rtol: float | None = None,
) -> dict[str, Any]:
    """Decide whether the XLA arm diverged from the non-XLA arm.

    For each metric the tolerance is ``atol + rtol * max(|nonxla|, |xla|)``.

    Args:
        nonxla: ``{"loss", "grad_norm"}`` from the ``jit_compile=False`` arm.
        xla: ``{"loss", "grad_norm"}`` from the ``jit_compile=True`` arm.
        loss_atol, loss_rtol, grad_atol, grad_rtol: Optional overrides. Any
            omitted value defaults to the module-level ``LOSS_*`` or ``GRAD_*``
            constant.

    Returns:
        The raw values, absolute differences and tolerances for both metrics,
        plus per-metric ``loss_divergent``/``grad_divergent`` flags and the
        overall ``divergent`` flag.
    """
    loss_atol = LOSS_ATOL if loss_atol is None else float(loss_atol)
    loss_rtol = LOSS_RTOL if loss_rtol is None else float(loss_rtol)
    grad_atol = GRAD_ATOL if grad_atol is None else float(grad_atol)
    grad_rtol = GRAD_RTOL if grad_rtol is None else float(grad_rtol)

    loss_abs = abs(nonxla["loss"] - xla["loss"])
    grad_abs = abs(nonxla["grad_norm"] - xla["grad_norm"])
    loss_tol = loss_atol + loss_rtol * max(abs(nonxla["loss"]), abs(xla["loss"]))
    grad_tol = grad_atol + grad_rtol * max(abs(nonxla["grad_norm"]), abs(xla["grad_norm"]))
    loss_div = _metric_diverges(nonxla["loss"], xla["loss"], loss_tol)
    grad_div = _metric_diverges(nonxla["grad_norm"], xla["grad_norm"], grad_tol)
    return {
        "divergent": bool(loss_div or grad_div),
        "loss_divergent": bool(loss_div),
        "grad_divergent": bool(grad_div),
        "nonxla_loss": nonxla["loss"],
        "xla_loss": xla["loss"],
        "loss_abs_diff": loss_abs,
        "loss_tolerance": loss_tol,
        "nonxla_grad_norm": nonxla["grad_norm"],
        "xla_grad_norm": xla["grad_norm"],
        "grad_abs_diff": grad_abs,
        "grad_tolerance": grad_tol,
    }
def _emit(payload: dict[str, Any], *, stream: Any = None) -> None:
    """Print ``payload`` as a single sorted-key JSON line and flush immediately.

    ``stream=None`` selects standard output, matching ``print``'s default.
    """
    print(json.dumps(payload, sort_keys=True), file=stream, flush=True)


def main() -> int:
    """Run the non-XLA arm, then the XLA arm, and report whether they diverge.

    Emits JSON event lines on stdout: ``environment``, one ``arm`` event per arm,
    and a final ``summary``. Best-effort settings that TensorFlow refuses are
    reported as ``warning`` events on stderr instead of being silently ignored.

    Returns:
        Process exit code: ``2`` when ``--fail-on-divergence`` is set and the arms
        diverge, otherwise ``0``.

    Raises:
        RuntimeError: if TensorFlow sees no GPU on this worker.
    """
    import sys

    args = parse_args()
    source_config = apply_config(args)
    # TF_CONFIG and CUDA_VISIBLE_DEVICES must be set before TensorFlow is imported.
    workers = setup_environment(args)
    import tensorflow as tf

    worker_index = int(args.worker_index)

    def warn(message: str, error: Exception) -> None:
        # These settings are best-effort; keep running but make the failure visible.
        _emit(
            {
                "event": "warning",
                "worker_index": worker_index,
                "message": message,
                "error": str(error),
            },
            stream=sys.stderr,
        )

    try:
        tf.config.threading.set_intra_op_parallelism_threads(int(args.intra_op_threads))
        tf.config.threading.set_inter_op_parallelism_threads(int(args.inter_op_threads))
    except RuntimeError as err:
        # TensorFlow raises RuntimeError once its runtime is already initialised.
        warn("could not apply thread-pool settings", err)

    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        raise RuntimeError("This reproduction expects one visible GPU per worker.")
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            warn(f"could not enable memory growth on {gpu}", err)

    communication_impl = (
        tf.distribute.experimental.CommunicationImplementation.RING
        if str(args.communication).upper() == "RING"
        else tf.distribute.experimental.CommunicationImplementation.AUTO
    )
    strategy = tf.distribute.MultiWorkerMirroredStrategy(
        communication_options=tf.distribute.experimental.CommunicationOptions(
            implementation=communication_impl
        )
    )

    _emit(
        {
            "event": "environment",
            "worker_index": worker_index,
            "tensorflow_version": tf.__version__,
            "gpus": [str(gpu) for gpu in gpus],
            "tf_config": json.loads(os.environ["TF_CONFIG"]),
            "workers": workers,
            "runtime_configuration": runtime_configuration(args, workers),
            "source_config": config_summary(args, source_config),
        }
    )

    started = time.perf_counter()
    arm_results: dict[str, dict[str, float]] = {}
    # Non-XLA first, then XLA, both on the same strategy instance.
    for arm_name, use_xla in (("nonxla", False), ("xla", True)):
        arm_results[arm_name] = run_arm(tf, strategy, args, jit_compile=use_xla)
        _emit({"event": "arm", "worker_index": worker_index, "arm": arm_name, **arm_results[arm_name]})

    result = compare(arm_results["nonxla"], arm_results["xla"])
    _emit(
        {
            "event": "summary",
            "worker_index": worker_index,
            "status": "DIVERGENCE_REPRODUCED" if result["divergent"] else "not_reproduced",
            "wall_time_sec": round(time.perf_counter() - started, 3),
            **result,
        }
    )
    if args.fail_on_divergence and result["divergent"]:
        return 2
    return 0
if __name__ == "__main__":
    # Use main()'s integer status as the process exit code.
    exit_status = main()
    raise SystemExit(exit_status)
## tensorflow_official_repro_config.json
{
"arms": [
{
"forward_backward_jit_compile": false,
"name": "nonxla"
},
{
"forward_backward_jit_compile": true,
"name": "xla"
}
],
"communication_implementation": "AUTO",
"comparison_thresholds": {
"grad_atol": 1.0,
"grad_rtol": 0.2,
"loss_atol": 0.15,
"loss_rtol": 0.1
},
"device": "gpu",
"distributed_strategy": "multi_worker_mirrored",
"effective_global_batch_size": 4,
"framework": "tensorflow",
"input": "deterministic synthetic normal tensor",
"input_shape_per_replica": [
2,
224,
224,
3
],
"inter_op_parallelism_threads": 1,
"intra_op_parallelism_threads": 2,
"learning_rate": 0.1,
"local_device_count": 1,
"model": "resnet50",
"num_classes": 10,
"num_replicas_in_sync": 2,
"num_workers": 2,
"optimizer": "adam",
"optimizer_step_jit_compile": false,
"per_replica_batch_size": 2,
"precision_policy": "float32",
"run_eagerly": false,
"seed": 42,
"steps_per_arm": 20,
"tf_strategy_class": "MultiWorkerMirroredStrategy",
"weight_decay": 0.01,
"worker_addresses": [
"10.60.88.171:61120",
"10.60.210.23:62120"
]
}
Issue type
Bug
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
source
TensorFlow version
2.18.0
Custom code
Yes
OS platform and distribution
Linux 5.15.0-113-generic x86_64 glibc 2.35
Mobile device
No response
Python version
3.11.15
Bazel version
No response
GCC/compiler version
14.3.0
CUDA/cuDNN version
CUDA: 12.5.1 cuDNN: 9
GPU model and memory
Tesla V100S-PCIE-32GB / 32768 MiB
Current behavior?
Current behavior
Under the same two-worker `MultiWorkerMirroredStrategy` setup, both the non-XLA arm and the XLA arm finish successfully. Both workers agree with each other inside each arm, so this does not look like random cross-worker drift. However, after only 20 steps, the final loss and gradient norm from the `jit_compile=True` forward/backward arm differ significantly from the `jit_compile=False` forward/backward arm. One confirmed two-machine run of the minimal reproducer produced:
Expected behavior
XLA and non-XLA may differ by small floating-point tolerances. However, under the same model, seed, input tensor, optimizer, precision policy, and two-worker `MultiWorkerMirroredStrategy`, the final loss and gradient norm should not diverge by this magnitude after only 20 steps. Because both workers agree with each other inside each arm, the difference appears specifically between the `jit_compile=False` and `jit_compile=True` forward/backward paths.
Standalone code to reproduce the issue
Relevant log output