
TensorFlow MultiWorkerMirroredStrategy: XLA and non-XLA forward/backward produce divergent loss and gradient norm with Adam #120580

@whitephilomel

Issue type

Bug

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

source

TensorFlow version

2.18.0

Custom code

Yes

OS platform and distribution

Linux 5.15.0-113-generic x86_64 glibc 2.35

Mobile device

No response

Python version

3.11.15

Bazel version

No response

GCC/compiler version

14.3.0

CUDA/cuDNN version

CUDA: 12.5.1 cuDNN: 9

GPU model and memory

Tesla V100S-PCIE-32GB / 32768 MiB

Current behavior?

Under the same two-worker MultiWorkerMirroredStrategy setup, both the non-XLA arm and the XLA arm run to completion without errors. Within each arm, the two workers report identical loss and gradient norm values, so this does not look like random cross-worker drift.

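However, after only 20 steps, the final loss and gradient norm of the arm whose forward/backward pass uses jit_compile=True differ sharply from those of the jit_compile=False arm: the XLA loss is roughly 780 times the non-XLA loss, and the XLA gradient norm is roughly 91 times the non-XLA gradient norm.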

One confirmed two-machine run of the minimal reproducer produced:

| worker | non-XLA loss | XLA loss | loss abs diff | non-XLA grad norm | XLA grad norm | grad abs diff |
|---|---:|---:|---:|---:|---:|---:|
| worker 0 | 0.0021458889823406935 | 1.6746044158935547 | 1.672458526911214 | 0.02090447209775448 | 1.9025657176971436 | 1.881661245599389 |
| worker 1 | 0.0021458889823406935 | 1.6746044158935547 | 1.672458526911214 | 0.02090447209775448 | 1.9025657176971436 | 1.881661245599389 |
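
As a quick sanity check, the reported numbers can be recomputed in plain Python, with no TensorFlow involved. This sketch only transcribes the two logged rows, confirms that the workers logged identical values, and derives the absolute differences and the XLA/non-XLA ratios; the variable names are illustrative, not taken from the reproducer.

```python
# Values transcribed from the reported two-worker run; nothing is recomputed by TensorFlow.
reported = {
    "worker 0": {
        "nonxla_loss": 0.0021458889823406935,
        "xla_loss": 1.6746044158935547,
        "nonxla_grad_norm": 0.02090447209775448,
        "xla_grad_norm": 1.9025657176971436,
    },
    "worker 1": {
        "nonxla_loss": 0.0021458889823406935,
        "xla_loss": 1.6746044158935547,
        "nonxla_grad_norm": 0.02090447209775448,
        "xla_grad_norm": 1.9025657176971436,
    },
}

# Both workers logged bit-identical values, i.e. no cross-worker drift inside an arm.
workers_agree = reported["worker 0"] == reported["worker 1"]

row = reported["worker 0"]
loss_abs_diff = abs(row["xla_loss"] - row["nonxla_loss"])
grad_abs_diff = abs(row["xla_grad_norm"] - row["nonxla_grad_norm"])
loss_ratio = row["xla_loss"] / row["nonxla_loss"]
grad_ratio = row["xla_grad_norm"] / row["nonxla_grad_norm"]

print(f"workers agree: {workers_agree}")
print(f"loss:      |diff| = {loss_abs_diff:.6f}, XLA / non-XLA = {loss_ratio:.0f}x")
print(f"grad norm: |diff| = {grad_abs_diff:.6f}, XLA / non-XLA = {grad_ratio:.0f}x")
```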

Expected behavior

XLA and non-XLA results may differ within small floating-point tolerances. However, with the same model, seed, input tensor, optimizer, precision policy, and two-worker MultiWorkerMirroredStrategy, the final loss and gradient norm should not diverge by this magnitude after only 20 steps.

Because both workers agree with each other inside each arm, the difference appears specifically between the jit_compile=False and jit_compile=True forward/backward paths.
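
For reference, the reproducer's compare() flags a divergence when |a - b| > atol + rtol * max(|a|, |b|), using LOSS_ATOL = 0.15, LOSS_RTOL = 0.10, GRAD_ATOL = 1.0 and GRAD_RTOL = 0.20. For the reported values that gives a loss tolerance of about 0.317 against a difference of about 1.672, and a gradient-norm tolerance of about 1.381 against a difference of about 1.882. The stdlib-only sketch below applies the same rule; the helper name exceeds_tolerance is hypothetical and not part of the script.

```python
import math

# Thresholds hard-coded in minimal_repro_mwms_xla.py.
LOSS_ATOL, LOSS_RTOL = 0.15, 0.10
GRAD_ATOL, GRAD_RTOL = 1.0, 0.20


def exceeds_tolerance(a: float, b: float, atol: float, rtol: float) -> bool:
    """Same rule as compare(): flag when |a - b| > atol + rtol * max(|a|, |b|).

    Like compare(), a non-finite difference (for example a NaN loss) is not flagged.
    """
    diff = abs(a - b)
    tol = atol + rtol * max(abs(a), abs(b))
    return math.isfinite(diff) and diff > tol


# Reported non-XLA vs XLA values (identical on both workers).
loss_flagged = exceeds_tolerance(0.0021458889823406935, 1.6746044158935547, LOSS_ATOL, LOSS_RTOL)
grad_flagged = exceeds_tolerance(0.02090447209775448, 1.9025657176971436, GRAD_ATOL, GRAD_RTOL)

# A floating-point-sized discrepancy stays well inside the same loss tolerance.
tiny_flagged = exceeds_tolerance(1.6746044, 1.6746044 + 1e-5, LOSS_ATOL, LOSS_RTOL)

print(loss_flagged, grad_flagged, tiny_flagged)  # True True False
```

These tolerances are deliberately loose, so the check is not sensitive to ordinary XLA fusion or reordering noise; it only fires on gaps of the size reported here.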

Standalone code to reproduce the issue

## minimal_repro_mwms_xla.py

#!/usr/bin/env python3
"""Minimal two-worker TensorFlow reproducer for MWMS XLA/non-XLA divergence.

Run this same file on two GPU machines at approximately the same time.

Machine 0:
  python minimal_repro_mwms_xla.py \
    --config tensorflow_official_repro_config.json \
    --worker-index 0 \
    --worker-addresses host0:61120,host1:62120

Machine 1:
  python minimal_repro_mwms_xla.py \
    --config tensorflow_official_repro_config.json \
    --worker-index 1 \
    --worker-addresses host0:61120,host1:62120
"""

from __future__ import annotations

import argparse
import json
import math
import os
import time
from pathlib import Path
from typing import Any


LOSS_ATOL = 0.15
LOSS_RTOL = 0.10
GRAD_ATOL = 1.0
GRAD_RTOL = 0.20


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=None, help="Optional JSON config for runtime parameters.")
    parser.add_argument("--worker-index", type=int, required=True)
    parser.add_argument("--worker-addresses", required=True, help="host0:port0,host1:port1")
    parser.add_argument("--cuda-visible-devices", default="0")
    parser.add_argument("--steps", type=int, default=20)
    parser.add_argument("--per-replica-batch-size", type=int, default=2)
    parser.add_argument("--input-size", type=int, default=224)
    parser.add_argument("--num-classes", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--learning-rate", type=float, default=0.1)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--precision-policy", default="float32")
    parser.add_argument("--communication", choices=["RING", "AUTO"], default="AUTO")
    parser.add_argument("--intra-op-threads", type=int, default=2)
    parser.add_argument("--inter-op-threads", type=int, default=1)
    parser.add_argument("--fail-on-divergence", action="store_true")
    return parser.parse_args()


def apply_config(args: argparse.Namespace) -> dict[str, Any]:
    if not args.config:
        return {}

    path = Path(args.config).expanduser().resolve()
    cfg = json.loads(path.read_text(encoding="utf-8"))
    args.config = str(path)

    args.num_classes = int(cfg.get("num_classes", args.num_classes))
    args.seed = int(cfg.get("seed", args.seed))
    args.learning_rate = float(cfg.get("learning_rate", args.learning_rate))
    args.weight_decay = float(cfg.get("weight_decay", args.weight_decay))
    args.precision_policy = str(cfg.get("precision_policy", args.precision_policy))
    args.intra_op_threads = int(cfg.get("intra_op_parallelism_threads") or args.intra_op_threads)
    args.inter_op_threads = int(cfg.get("inter_op_parallelism_threads") or args.inter_op_threads)

    shape = cfg.get("input_shape_per_replica")
    if isinstance(shape, list) and len(shape) >= 3:
        args.per_replica_batch_size = int(shape[0])
        args.input_size = int(shape[1])
    else:
        args.per_replica_batch_size = int(
            cfg.get("per_replica_batch_size", args.per_replica_batch_size)
        )

    # The original training config may contain L1 admission fields such as
    # max_steps=5; they are ignored here. This reduced differential reproducer
    # runs --steps steps per arm (default 20) unless the config explicitly
    # provides steps_per_arm, which then takes precedence.
    if "steps_per_arm" in cfg:
        args.steps = int(cfg["steps_per_arm"])

    communication = str(
        cfg.get("communication")
        or cfg.get("communication_implementation")
        or args.communication
    ).upper()
    if communication in {"RING", "AUTO"}:
        args.communication = communication

    return cfg


def config_summary(args: argparse.Namespace, cfg: dict[str, Any]) -> dict[str, Any]:
    if not cfg:
        return {}
    return {
        "config_file": Path(str(args.config)).name,
        "device": cfg.get("device"),
        "distributed_strategy": cfg.get("distributed_strategy"),
        "tf_strategy_class": cfg.get("tf_strategy_class"),
        "communication_implementation": (
            cfg.get("communication_implementation") or cfg.get("communication")
        ),
        "precision_policy": cfg.get("precision_policy"),
        "optimizer": cfg.get("optimizer"),
        "seed": cfg.get("seed"),
        "per_replica_batch_size": cfg.get("per_replica_batch_size"),
        "input_shape_per_replica": cfg.get("input_shape_per_replica"),
        "learning_rate": cfg.get("learning_rate"),
        "weight_decay": cfg.get("weight_decay"),
        "steps_per_arm": cfg.get("steps_per_arm"),
        "run_eagerly": cfg.get("run_eagerly"),
    }


def runtime_configuration(args: argparse.Namespace, workers: list[str]) -> dict[str, Any]:
    return {
        "strategy": "MultiWorkerMirroredStrategy",
        "workers": workers,
        "num_workers": len(workers),
        "gpus_per_worker": 1,
        "communication": str(args.communication),
        "model": f"ResNet50(weights=None, include_top=False, pooling='avg') + Dense({int(args.num_classes)})",
        "input": "synthetic normal tensor generated from a fixed seed",
        "steps_per_arm": int(args.steps),
        "per_replica_batch_size": int(args.per_replica_batch_size),
        "effective_global_batch_size": int(args.per_replica_batch_size) * len(workers),
        "input_shape_per_replica": [
            int(args.per_replica_batch_size),
            int(args.input_size),
            int(args.input_size),
            3,
        ],
        "num_classes": int(args.num_classes),
        "seed": int(args.seed),
        "optimizer": "Adam",
        "learning_rate": float(args.learning_rate),
        "weight_decay": float(args.weight_decay),
        "precision_policy": str(args.precision_policy),
        "threading": {
            "intra_op_parallelism_threads": int(args.intra_op_threads),
            "inter_op_parallelism_threads": int(args.inter_op_threads),
        },
        "arms": [
            {"name": "nonxla", "forward_backward_jit_compile": False},
            {"name": "xla", "forward_backward_jit_compile": True},
        ],
        "optimizer_step_jit_compile": False,
        "comparison_thresholds": {
            "loss_atol": LOSS_ATOL,
            "loss_rtol": LOSS_RTOL,
            "grad_atol": GRAD_ATOL,
            "grad_rtol": GRAD_RTOL,
        },
    }


def setup_environment(args: argparse.Namespace) -> list[str]:
    workers = [part.strip() for part in args.worker_addresses.split(",") if part.strip()]
    if len(workers) != 2:
        raise SystemExit("--worker-addresses must contain exactly two host:port entries")
    if args.worker_index not in (0, 1):
        raise SystemExit("--worker-index must be 0 or 1")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_visible_devices)
    os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
    os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "1")
    os.environ["TF_CONFIG"] = json.dumps(
        {
            "cluster": {"worker": workers},
            "task": {"type": "worker", "index": int(args.worker_index)},
        },
        sort_keys=True,
    )
    return workers


def make_model(tf: Any, input_size: int, num_classes: int) -> Any:
    base = tf.keras.applications.ResNet50(
        weights=None,
        include_top=False,
        input_shape=(input_size, input_size, 3),
        pooling="avg",
    )
    out = tf.keras.layers.Dense(num_classes)(base.output)
    return tf.keras.Model(base.input, out)


def run_arm(tf: Any, strategy: Any, args: argparse.Namespace, *, jit_compile: bool) -> dict[str, float]:
    tf.keras.utils.set_random_seed(int(args.seed))
    tf.keras.mixed_precision.set_global_policy(str(args.precision_policy))

    with strategy.scope():
        model = make_model(tf, int(args.input_size), int(args.num_classes))
        optimizer = tf.keras.optimizers.Adam(
            learning_rate=float(args.learning_rate),
            weight_decay=float(args.weight_decay),
        )
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        )

    def grad_fn(x, y):
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            per_example_loss = loss_fn(y, logits)
            loss = tf.nn.compute_average_loss(per_example_loss)
        grads = tape.gradient(loss, model.trainable_variables)
        grad_norm = tf.linalg.global_norm([g for g in grads if g is not None])
        return loss, grads, grad_norm

    compiled_grad_fn = tf.function(grad_fn, jit_compile=bool(jit_compile))

    def step_fn(x, y):
        loss, grads, grad_norm = compiled_grad_fn(x, y)
        optimizer.apply_gradients(
            (g, v) for g, v in zip(grads, model.trainable_variables) if g is not None
        )
        return loss, grad_norm

    final_loss = float("nan")
    final_grad_norm = float("nan")
    batch = int(args.per_replica_batch_size)
    input_size = int(args.input_size)

    for step in range(int(args.steps)):
        generator = tf.random.Generator.from_seed(int(args.seed) + step)
        x = generator.normal((batch, input_size, input_size, 3))
        y = tf.range(batch, dtype=tf.int32) % int(args.num_classes)
        per_replica_loss, per_replica_grad_norm = strategy.run(step_fn, args=(x, y))
        final_loss = float(
            strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None).numpy()
        )
        final_grad_norm = float(
            strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norm, axis=None).numpy()
        )

    return {"loss": final_loss, "grad_norm": final_grad_norm}


def compare(nonxla: dict[str, float], xla: dict[str, float]) -> dict[str, Any]:
    loss_abs = abs(nonxla["loss"] - xla["loss"])
    grad_abs = abs(nonxla["grad_norm"] - xla["grad_norm"])
    loss_tol = LOSS_ATOL + LOSS_RTOL * max(abs(nonxla["loss"]), abs(xla["loss"]))
    grad_tol = GRAD_ATOL + GRAD_RTOL * max(abs(nonxla["grad_norm"]), abs(xla["grad_norm"]))
    loss_div = math.isfinite(loss_abs) and loss_abs > loss_tol
    grad_div = math.isfinite(grad_abs) and grad_abs > grad_tol
    return {
        "divergent": bool(loss_div or grad_div),
        "nonxla_loss": nonxla["loss"],
        "xla_loss": xla["loss"],
        "loss_abs_diff": loss_abs,
        "loss_tolerance": loss_tol,
        "nonxla_grad_norm": nonxla["grad_norm"],
        "xla_grad_norm": xla["grad_norm"],
        "grad_abs_diff": grad_abs,
        "grad_tolerance": grad_tol,
    }


def main() -> int:
    args = parse_args()
    source_config = apply_config(args)
    workers = setup_environment(args)

    import tensorflow as tf

    try:
        tf.config.threading.set_intra_op_parallelism_threads(int(args.intra_op_threads))
        tf.config.threading.set_inter_op_parallelism_threads(int(args.inter_op_threads))
    except RuntimeError:
        pass

    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        raise RuntimeError("This reproduction expects one visible GPU per worker.")
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            pass

    communication_impl = (
        tf.distribute.experimental.CommunicationImplementation.RING
        if str(args.communication).upper() == "RING"
        else tf.distribute.experimental.CommunicationImplementation.AUTO
    )
    strategy = tf.distribute.MultiWorkerMirroredStrategy(
        communication_options=tf.distribute.experimental.CommunicationOptions(
            implementation=communication_impl
        )
    )

    print(
        json.dumps(
            {
                "event": "environment",
                "worker_index": int(args.worker_index),
                "tensorflow_version": tf.__version__,
                "gpus": [str(gpu) for gpu in gpus],
                "tf_config": json.loads(os.environ["TF_CONFIG"]),
                "workers": workers,
                "runtime_configuration": runtime_configuration(args, workers),
                "source_config": config_summary(args, source_config),
            },
            sort_keys=True,
        ),
        flush=True,
    )

    started = time.perf_counter()
    nonxla = run_arm(tf, strategy, args, jit_compile=False)
    print(
        json.dumps(
            {"event": "arm", "worker_index": int(args.worker_index), "arm": "nonxla", **nonxla},
            sort_keys=True,
        ),
        flush=True,
    )

    xla = run_arm(tf, strategy, args, jit_compile=True)
    print(
        json.dumps(
            {"event": "arm", "worker_index": int(args.worker_index), "arm": "xla", **xla},
            sort_keys=True,
        ),
        flush=True,
    )

    result = compare(nonxla, xla)
    print(
        json.dumps(
            {
                "event": "summary",
                "worker_index": int(args.worker_index),
                "status": "DIVERGENCE_REPRODUCED" if result["divergent"] else "not_reproduced",
                "wall_time_sec": round(time.perf_counter() - started, 3),
                **result,
            },
            sort_keys=True,
        ),
        flush=True,
    )
    if args.fail_on_divergence and result["divergent"]:
        return 2
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
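
A note on the absolute values printed by run_arm: when tf.nn.compute_average_loss is called inside a replica without global_batch_size, it divides the local sum of per-example losses by the global batch size (per-replica batch size times num_replicas_in_sync), so the per-replica results are meant to be summed across replicas. The script combines them with ReduceOp.MEAN instead, so the logged loss is the global mean divided by the number of replicas (2 here). Both arms use the same reduction, so this scaling cannot by itself account for the XLA/non-XLA gap. The plain-Python sketch below illustrates the arithmetic with made-up per-example losses; its compute_average_loss is a simplified stand-in, not TensorFlow's implementation.

```python
def compute_average_loss(per_example_loss: list[float], num_replicas: int) -> float:
    """Simplified stand-in for tf.nn.compute_average_loss called in a replica with
    global_batch_size=None: the local sum divided by (local batch size * num_replicas)."""
    global_batch_size = len(per_example_loss) * num_replicas
    return sum(per_example_loss) / global_batch_size


# Hypothetical per-example losses: 2 replicas, per-replica batch size 2.
replica_losses = [[1.0, 3.0], [2.0, 2.0]]
num_replicas = len(replica_losses)

per_replica_loss = [compute_average_loss(losses, num_replicas) for losses in replica_losses]
global_mean = sum(map(sum, replica_losses)) / sum(map(len, replica_losses))

reduced_with_sum = sum(per_replica_loss)                  # like ReduceOp.SUM
reduced_with_mean = sum(per_replica_loss) / num_replicas  # like ReduceOp.MEAN

print(global_mean, reduced_with_sum, reduced_with_mean)  # 2.0 2.0 1.0
```

The same 1/global-batch scaling carries over to the per-replica gradients, whose norm grad_fn computes before the optimizer aggregates them across replicas.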



## tensorflow_official_repro_config.json

{
  "arms": [
    {
      "forward_backward_jit_compile": false,
      "name": "nonxla"
    },
    {
      "forward_backward_jit_compile": true,
      "name": "xla"
    }
  ],
  "communication_implementation": "AUTO",
  "comparison_thresholds": {
    "grad_atol": 1.0,
    "grad_rtol": 0.2,
    "loss_atol": 0.15,
    "loss_rtol": 0.1
  },
  "device": "gpu",
  "distributed_strategy": "multi_worker_mirrored",
  "effective_global_batch_size": 4,
  "framework": "tensorflow",
  "input": "deterministic synthetic normal tensor",
  "input_shape_per_replica": [
    2,
    224,
    224,
    3
  ],
  "inter_op_parallelism_threads": 1,
  "intra_op_parallelism_threads": 2,
  "learning_rate": 0.1,
  "local_device_count": 1,
  "model": "resnet50",
  "num_classes": 10,
  "num_replicas_in_sync": 2,
  "num_workers": 2,
  "optimizer": "adam",
  "optimizer_step_jit_compile": false,
  "per_replica_batch_size": 2,
  "precision_policy": "float32",
  "run_eagerly": false,
  "seed": 42,
  "steps_per_arm": 20,
  "tf_strategy_class": "MultiWorkerMirroredStrategy",
  "weight_decay": 0.01,
  "worker_addresses": [
    "10.60.88.171:61120",
    "10.60.210.23:62120"
  ]
}
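
Not every field in this JSON is consumed by the script. apply_config() reads only a subset to override the CLI defaults, and config_summary() echoes a few more for logging; the rest is descriptive. In particular, --worker-addresses must still be passed on the command line even though worker_addresses is listed here, and the thresholds actually applied are the hard-coded LOSS_ATOL/LOSS_RTOL/GRAD_ATOL/GRAD_RTOL constants, which happen to match comparison_thresholds. A small sketch that lists the keys apply_config() never reads, with both key sets transcribed by hand from the script and the config:

```python
# Top-level keys of tensorflow_official_repro_config.json, transcribed by hand.
CONFIG_KEYS = [
    "arms", "communication_implementation", "comparison_thresholds", "device",
    "distributed_strategy", "effective_global_batch_size", "framework", "input",
    "input_shape_per_replica", "inter_op_parallelism_threads",
    "intra_op_parallelism_threads", "learning_rate", "local_device_count", "model",
    "num_classes", "num_replicas_in_sync", "num_workers", "optimizer",
    "optimizer_step_jit_compile", "per_replica_batch_size", "precision_policy",
    "run_eagerly", "seed", "steps_per_arm", "tf_strategy_class", "weight_decay",
    "worker_addresses",
]

# Keys that apply_config() in minimal_repro_mwms_xla.py reads to override CLI defaults.
APPLY_CONFIG_KEYS = frozenset({
    "num_classes", "seed", "learning_rate", "weight_decay", "precision_policy",
    "intra_op_parallelism_threads", "inter_op_parallelism_threads",
    "input_shape_per_replica", "per_replica_batch_size", "steps_per_arm",
    "communication", "communication_implementation",
})


def unused_keys(config_keys: list[str]) -> list[str]:
    """Return the config keys that apply_config() never reads, sorted alphabetically."""
    return sorted(set(config_keys) - APPLY_CONFIG_KEYS)


ignored = unused_keys(CONFIG_KEYS)
print(len(ignored), ignored)
```

To check the real file instead of the hand-written list, json.loads(Path("tensorflow_official_repro_config.json").read_text()).keys() could be passed to unused_keys.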

Relevant log output

| worker | non-XLA loss | XLA loss | loss abs diff | non-XLA grad norm | XLA grad norm | grad abs diff |
|---|---:|---:|---:|---:|---:|---:|
| worker 0 | 0.0021458889823406935 | 1.6746044158935547 | 1.672458526911214 | 0.02090447209775448 | 1.9025657176971436 | 1.881661245599389 |
| worker 1 | 0.0021458889823406935 | 1.6746044158935547 | 1.672458526911214 | 0.02090447209775448 | 1.9025657176971436 | 1.881661245599389 |
