
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm) #34695

@ra-MANUJ-an


Reproduction

I am trying to finetune the Qwen2-0.5B model on some training data using a multi-GPU setup. The same code (given further below) works in a single-GPU setting (when I set CUDA_VISIBLE_DEVICES=0), but fails as soon as a second GPU is visible.
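The only difference between the two runs is which GPUs are visible to the process; roughly (the device IDs below are just what I use locally, and the variable is set before CUDA is initialized):

import os

# single-GPU run: trains fine
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# multi-GPU run: fails with the traceback below
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

The multi-GPU run produces the following traceback: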

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[18], line 4
      2 import torch
      3 torch.autograd.set_detect_anomaly(True)
----> 4 main()

Cell In[14], line 15, in main()
      8 trainer = Trainer(env_params=env_params,
      9                   model_params=model_params,
     10                   optimizer_params=optimizer_params,
     11                   trainer_params=trainer_params)
     13 copy_all_src(trainer.result_folder)
---> 15 trainer.run()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:92, in TSPTrainer.run(self)
     89 self.scheduler.step()
     91 # Train
---> 92 train_score, train_loss = self._train_one_epoch(epoch)
     93 self.result_log.append('train_score', epoch, train_score)
     94 self.result_log.append('train_loss', epoch, train_loss)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:151, in TSPTrainer._train_one_epoch(self, epoch)
    148 remaining = train_num_episode - episode
    149 batch_size = min(self.trainer_params['train_batch_size'], remaining)
--> 151 avg_score, avg_loss = self._train_one_batch(batch_size)
    152 score_AM.update(avg_score, batch_size)
    153 loss_AM.update(avg_loss, batch_size)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTrainerTransformer.py:193, in TSPTrainer._train_one_batch(self, batch_size)
    191 state, reward, done = self.env.pre_step()
    192 while not done:
--> 193     selected, prob = self.model.module(state)
    194     # shape: (batch, pomo)
    195     state, reward, done = self.env.step(selected)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:39, in TSPTransformer.forward(self, state)
     37     return self._init_sequence(batch_size, pomo_size)
     38 else:
---> 39     return self._continue_sequence(state, batch_size, pomo_size)

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:84, in TSPTransformer._continue_sequence(self, state, batch_size, pomo_size)
     81 state.ninf_mask = state.ninf_mask.to(self.device)
     83 # Get probabilities from decoder
---> 84 probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)
     86 # Select next node
     87 if self.training or self.model_params['eval_type'] == 'softmax':

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/neuralcombinatorialoptimization/NCO-master/NEW_py_ver/TSP/POMO/TSPTransformerModelQuant_b.py:185, in Decoder.forward(self, seq_so_far, inp_mask, ninf_mask)
    182 flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)
    184 # Get model outputs
--> 185 outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
    186 logits = outputs.logits
    188 # Get last valid position

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/peft_model.py:1644, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1642     with self._enable_peft_forward_hooks(**kwargs):
   1643         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1644         return self.base_model(
   1645             input_ids=input_ids,
   1646             attention_mask=attention_mask,
   1647             inputs_embeds=inputs_embeds,
   1648             labels=labels,
   1649             output_attentions=output_attentions,
   1650             output_hidden_states=output_hidden_states,
   1651             return_dict=return_dict,
   1652             **kwargs,
   1653         )
   1655 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1656 if attention_mask is not None:
   1657     # concat prompt attention mask

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1170, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1167 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1169 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1170 outputs = self.model(
   1171     input_ids=input_ids,
   1172     attention_mask=attention_mask,
   1173     position_ids=position_ids,
   1174     past_key_values=past_key_values,
   1175     inputs_embeds=inputs_embeds,
   1176     use_cache=use_cache,
   1177     output_attentions=output_attentions,
   1178     output_hidden_states=output_hidden_states,
   1179     return_dict=return_dict,
   1180     cache_position=cache_position,
   1181 )
   1183 hidden_states = outputs[0]
   1184 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:901, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    889     layer_outputs = self._gradient_checkpointing_func(
    890         decoder_layer.__call__,
    891         hidden_states,
   (...)
    898         position_embeddings,
    899     )
    900 else:
--> 901     layer_outputs = decoder_layer(
    902         hidden_states,
    903         attention_mask=causal_mask,
    904         position_ids=position_ids,
    905         past_key_value=past_key_values,
    906         output_attentions=output_attentions,
    907         use_cache=use_cache,
    908         cache_position=cache_position,
    909         position_embeddings=position_embeddings,
    910     )
    912 hidden_states = layer_outputs[0]
    914 if use_cache:

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:629, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    626 hidden_states = self.input_layernorm(hidden_states)
    628 # Self Attention
--> 629 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    630     hidden_states=hidden_states,
    631     attention_mask=attention_mask,
    632     position_ids=position_ids,
    633     past_key_value=past_key_value,
    634     output_attentions=output_attentions,
    635     use_cache=use_cache,
    636     cache_position=cache_position,
    637     position_embeddings=position_embeddings,
    638 )
    639 hidden_states = residual + hidden_states
    641 # Fully Connected

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:506, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings)
    495     return super().forward(
    496         hidden_states=hidden_states,
    497         attention_mask=attention_mask,
   (...)
    501         use_cache=use_cache,
    502     )
    504 bsz, q_len, _ = hidden_states.size()
--> 506 query_states = self.q_proj(hidden_states)
    507 key_states = self.k_proj(hidden_states)
    508 value_states = self.v_proj(hidden_states)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/peft/tuners/lora/layer.py:572, in Linear.forward(self, x, *args, **kwargs)
    570     result = self.base_layer(x, *args, **kwargs)
    571 else:
--> 572     result = self.base_layer(x, *args, **kwargs)
    573     torch_result_dtype = result.dtype
    574     for active_adapter in self.active_adapters:

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/second/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/second/lib/python3.10/site-packages/torch/nn/modules/linear.py:125, in Linear.forward(self, input)
    124 def forward(self, input: Tensor) -> Tensor:
--> 125     return F.linear(input, self.weight, self.bias)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
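For what it's worth, the error itself is the generic one PyTorch raises when a linear layer's weights and its input sit on different GPUs; a minimal standalone snippet (unrelated to my code, just to illustrate what is being complained about, assuming two visible GPUs):

import torch

layer = torch.nn.Linear(4, 3).to("cuda:1")  # weights on cuda:1
x = torch.randn(2, 4, device="cuda:0")      # input on cuda:0
layer(x)  # raises the same RuntimeError (cuda:0 vs cuda:1) from F.linear / addmm

In my case the mismatch happens inside q_proj of a Qwen2 attention block, as shown in the last frames above.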

The code that produces the above error is given below.

Trainer.py

import torch
from logging import getLogger
from torch.nn.parallel import DataParallel
from TSPEnvQuant import TSPEnv as Env
from TSPTransformerModelQuant_b import TSPTransformer as Model

from torch.optim import Adam as Optimizer
from torch.optim.lr_scheduler import MultiStepLR as Scheduler

from utils.utils import *


class TSPTrainer:
    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()
        self.result_log = LogData()

        # cuda
        USE_CUDA = self.trainer_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.trainer_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')

        # Main Components
        self.model = Model(**self.model_params)
        if USE_CUDA and torch.cuda.device_count() > 1:
            self.logger.info(f"Using {torch.cuda.device_count()} GPUs!")
            self.model = DataParallel(self.model)
        self.model = self.model.to(device)

        self.env = Env(**self.env_params)
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
        self.scheduler = Scheduler(self.optimizer, **self.optimizer_params['scheduler'])

        # Restore
        self.start_epoch = 1
        model_load = trainer_params['model_load']
        if model_load['enable']:
            checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
            checkpoint = torch.load(checkpoint_fullname, map_location=device)
            # Handle loading state dict for DataParallel
            if isinstance(self.model, DataParallel):
                # If saved model wasn't using DataParallel but current model is
                if not any(key.startswith('module.') for key in checkpoint['model_state_dict'].keys()):
                    new_state_dict = {'module.' + k: v for k, v in checkpoint['model_state_dict'].items()}
                    self.model.load_state_dict(new_state_dict)
                else:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If saved model was using DataParallel but current model isn't
                if any(key.startswith('module.') for key in checkpoint['model_state_dict'].keys()):
                    new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model_state_dict'].items()}
                    self.model.load_state_dict(new_state_dict)
                else:
                    self.model.load_state_dict(checkpoint['model_state_dict'])
            self.start_epoch = 1 + model_load['epoch']
            self.result_log.set_raw_data(checkpoint['result_log'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.last_epoch = model_load['epoch']-1
            self.logger.info('Saved Model Loaded !!')

        # utility
        self.time_estimator = TimeEstimator()

    def run(self):
        self.time_estimator.reset(self.start_epoch)
        for epoch in range(self.start_epoch, self.trainer_params['epochs']+1):
            self.logger.info('=================================================================')

            # LR Decay
            self.scheduler.step()

            # Train
            train_score, train_loss = self._train_one_epoch(epoch)
            self.result_log.append('train_score', epoch, train_score)
            self.result_log.append('train_loss', epoch, train_loss)

            ############################
            # Logs & Checkpoint
            ############################
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.trainer_params['epochs'])
            self.logger.info("Epoch {:3d}/{:3d}: Time Est.: Elapsed[{}], Remain[{}]".format(
                epoch, self.trainer_params['epochs'], elapsed_time_str, remain_time_str))

            all_done = (epoch == self.trainer_params['epochs'])
            model_save_interval = self.trainer_params['logging']['model_save_interval']
            img_save_interval = self.trainer_params['logging']['img_save_interval']

            if epoch > 1:  # save latest images, every epoch
                self.logger.info("Saving log_image")
                image_prefix = '{}/latest'.format(self.result_folder)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                    self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                    self.result_log, labels=['train_loss'])

            if all_done or (epoch % model_save_interval) == 0:
                self.logger.info("Saving trained_model")
                checkpoint_dict = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'result_log': self.result_log.get_raw_data()
                }
                torch.save(checkpoint_dict, '{}/checkpoint-{}.pt'.format(self.result_folder, epoch))

            if all_done or (epoch % img_save_interval) == 0:
                image_prefix = '{}/img/checkpoint-{}'.format(self.result_folder, epoch)
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_1'],
                                    self.result_log, labels=['train_score'])
                util_save_log_image_with_label(image_prefix, self.trainer_params['logging']['log_image_params_2'],
                                    self.result_log, labels=['train_loss'])

            if all_done:
                self.logger.info(" *** Training Done *** ")
                self.logger.info("Now, printing log array...")
                util_print_log_array(self.logger, self.result_log)

    def _train_one_epoch(self, epoch):

        score_AM = AverageMeter()
        loss_AM = AverageMeter()

        train_num_episode = self.trainer_params['train_episodes']
        episode = 0
        loop_cnt = 0
        while episode < train_num_episode:

            remaining = train_num_episode - episode
            batch_size = min(self.trainer_params['train_batch_size'], remaining)

            avg_score, avg_loss = self._train_one_batch(batch_size)
            score_AM.update(avg_score, batch_size)
            loss_AM.update(avg_loss, batch_size)

            episode += batch_size

            # Log First 10 Batch, only at the first epoch
            if epoch == self.start_epoch:
                loop_cnt += 1
                if loop_cnt <= 10:
                    self.logger.info('Epoch {:3d}: Train {:3d}/{:3d}({:1.1f}%)  Score: {:.4f},  Loss: {:.4f}'
                                     .format(epoch, episode, train_num_episode, 100. * episode / train_num_episode,
                                             score_AM.avg, loss_AM.avg))

        # Log Once, for each epoch
        self.logger.info('Epoch {:3d}: Train ({:3.0f}%)  Score: {:.4f},  Loss: {:.4f}'
                         .format(epoch, 100. * episode / train_num_episode,
                                 score_AM.avg, loss_AM.avg))

        return score_AM.avg, loss_AM.avg

    def _train_one_batch(self, batch_size):

        # Prep
        ###############################################
        self.model.train()
        self.env.load_problems(batch_size)
        reset_state, _, _ = self.env.reset()
        # Handle pre_forward for DataParallel
        if isinstance(self.model, DataParallel):
            print("Is DataParallel")
            self.model.module.pre_forward(reset_state)
        else:
            self.model.pre_forward(reset_state)

        prob_list = torch.zeros(size=(batch_size, self.env.pomo_size, 0))
        # shape: (batch, pomo, 0~problem)

        # POMO Rollout
        ###############################################
        state, reward, done = self.env.pre_step()
        while not done:
            selected, prob = self.model.module(state)
            # shape: (batch, pomo)
            state, reward, done = self.env.step(selected)
            prob_list = torch.cat((prob_list, prob[:, :, None]), dim=2)

        # Loss
        ###############################################
        advantage = reward - reward.float().mean(dim=1, keepdims=True)
        # shape: (batch, pomo)
        log_prob = prob_list.log().sum(dim=2)
        # size = (batch, pomo)
        loss = -advantage * log_prob  # Minus Sign: To Increase REWARD
        # shape: (batch, pomo)
        loss_mean = loss.mean()

        # Score
        ###############################################
        max_pomo_reward, _ = reward.max(dim=1)  # get best results from pomo
        score_mean = -max_pomo_reward.float().mean()  # negative sign to make positive value

        # Step & Return
        ###############################################
        self.model.zero_grad()
        loss_mean.backward()
        self.optimizer.step()
        return score_mean.item(), loss_mean.item()

Model.py

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers import AutoConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, TaskType
from typing import Optional, Dict, Any, Tuple

class TSPTransformer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.model_params = kwargs
        self.encoder = Encoder(**kwargs)
        self.embedding_size = kwargs.get('embedding_dim', 896)
        
        # Load the model with LoRA and 4-bit quantization if needed
        self.model = load_model(kwargs)
        self.decoder = Decoder(self.model, **kwargs)
        
        # Initialize state storage
        self.encoded_nodes = None
        self.seq_so_far = None
        self.input_mask = None
        self.t = None
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    def pre_forward(self, reset_state):
        """Initialize model state for new sequence"""
        self.encoded_nodes = self.encoder(reset_state.problems)
        self.problem_size = reset_state.problems.size(1)
        self.batch_size = reset_state.problems.size(0)

    def forward(self, state) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size = state.BATCH_IDX.size(0)
        pomo_size = state.BATCH_IDX.size(1)

        if state.current_node is None:
            return self._init_sequence(batch_size, pomo_size)
        else:
            return self._continue_sequence(state, batch_size, pomo_size)

    def _init_sequence(self, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize sequence state"""
        self.t = 0  # Start at 0 instead of -1
        
        # Create new tensors instead of modifying in place
        selected = torch.arange(pomo_size, device=self.device).expand(batch_size, pomo_size)
        prob = torch.ones(size=(batch_size, pomo_size), device=self.device)
        
        # Initialize sequence storage with proper dimensions
        self.seq_so_far = torch.zeros(
            (batch_size, pomo_size, self.problem_size, self.embedding_size),
            device=self.device
        )
        
        self.input_mask = torch.zeros(
            (batch_size, pomo_size, self.problem_size),
            dtype=torch.bool,
            device=self.device
        )
        
        return selected, prob

    def _continue_sequence(self, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Continue sequence generation"""
        # Get encoded representation of current node
        encoded_current = self._get_encoded_node(state.current_node)
        
        # Create new tensor for updated sequence
        new_seq = self.seq_so_far.clone()
        new_seq[:, :, self.t, :] = encoded_current
        self.seq_so_far = new_seq
        
        # Create new tensor for updated mask
        new_mask = self.input_mask.clone()
        new_mask[:, :, self.t] = True
        self.input_mask = new_mask
        
        # Move tensors to correct device
        self.seq_so_far = self.seq_so_far.to(self.device)
        self.input_mask = self.input_mask.to(self.device)
        state.ninf_mask = state.ninf_mask.to(self.device)
        
        # Get probabilities from decoder
        probs = self.decoder(self.seq_so_far, self.input_mask, state.ninf_mask)
        
        # Select next node
        if self.training or self.model_params['eval_type'] == 'softmax':
            selected, prob = self._sample_node(probs, state, batch_size, pomo_size)
        else:
            selected = probs.argmax(dim=2)
            prob = None
        
        self.t += 1
        return selected, prob

    def _get_encoded_node(self, node_indices: torch.Tensor) -> torch.Tensor:
        """Get encoded representation of nodes safely"""
        batch_size, pomo_size = node_indices.shape
        embedding_dim = self.encoded_nodes.size(2)
        
        # Create gathering indices
        gather_idx = node_indices[:, :, None].expand(batch_size, pomo_size, embedding_dim)
        gather_idx = gather_idx.to(self.encoded_nodes.device)
        
        # Gather encoded representations
        return self.encoded_nodes.gather(dim=1, index=gather_idx)

    def _sample_node(self, probs: torch.Tensor, state, batch_size: int, pomo_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample next node with retry logic"""
        max_attempts = 100
        for _ in range(max_attempts):
            # Reshape for sampling
            flat_probs = probs.reshape(batch_size * pomo_size, -1)
            
            # Sample indices
            selected = flat_probs.multinomial(1, replacement=True)
            selected = selected.reshape(batch_size, pomo_size)
            
            # Calculate probabilities
            prob = probs[state.BATCH_IDX, state.POMO_IDX, selected]
            prob = prob.reshape(batch_size, pomo_size)
            
            if (prob > 0).all():
                return selected, prob
        
        raise RuntimeError(f"Failed to sample valid nodes after {max_attempts} attempts")

class Encoder(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.embedding_dim = kwargs.get('embedding_dim', 896) - 1
        self.embed_layer = nn.Linear(2, self.embedding_dim)
        self.device = kwargs.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    
    def forward(self, problems: torch.Tensor) -> torch.Tensor:
        batch_size, problem_size = problems.shape[:2]
        
        # Create position encodings
        ids = torch.arange(problem_size, device=self.device).expand(batch_size, problem_size)
        
        # Embed coordinates
        embedded = self.embed_layer(problems.reshape(-1, 2))
        embedded = embedded.reshape(batch_size, problem_size, self.embedding_dim)
        
        # Concatenate position encodings
        return torch.cat([ids.unsqueeze(-1).float(), embedded], dim=-1)

class Decoder(nn.Module):
    def __init__(self, model: nn.Module, **kwargs):
        super().__init__()
        self.model = model
        self.problem_size = kwargs.get('problem_size', 20)
        self.use_lora = kwargs.get('use_lora', True)
        
        self._setup_model()
    
    def _setup_model(self):
        """Configure model architecture"""
        # Modify output size
        self.model.lm_head = nn.Linear(
            self.model.config.hidden_size,
            self.problem_size
        ).to(self.model.device)
        
        # Apply LoRA if requested
        if self.use_lora:
            lora_config = LoraConfig(
                r=4,
                lora_alpha=32,
                target_modules=["q_proj", "v_proj"],
                lora_dropout=0.1,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            self.model = get_peft_model(self.model, lora_config)
    
    def forward(self, seq_so_far: torch.Tensor, inp_mask: torch.Tensor, ninf_mask: torch.Tensor) -> torch.Tensor:
        batch_size, pomo_size, problem_size, embedding_dim = seq_so_far.shape
        
        # Reshape inputs
        flat_seq = seq_so_far.reshape(batch_size * pomo_size, problem_size, embedding_dim)
        flat_mask = inp_mask.reshape(batch_size * pomo_size, problem_size)
        
        # Get model outputs
        outputs = self.model(inputs_embeds=flat_seq, attention_mask=flat_mask)
        logits = outputs.logits
        
        # Get last valid position
        last_positions = flat_mask.sum(dim=1).long() - 1
        
        # Gather logits for last positions
        batch_indices = torch.arange(batch_size * pomo_size, device=logits.device)
        gathered_logits = logits[batch_indices, last_positions]
        
        # Reshape and apply mask
        logits = gathered_logits.reshape(batch_size, pomo_size, problem_size)
        masked_logits = logits + ninf_mask.float()
        
        # Return probabilities
        return torch.softmax(masked_logits, dim=2)

def load_model(config: Dict[str, Any]) -> nn.Module:
    """Load model with proper configuration"""
    # print(config)
    device = config.get('device', torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    
    if config.get('checkpoint_path'):
        # print('checkpoint_path')
        try:
            return PeftModel.from_pretrained(
                config['model_name'],
                config['checkpoint_path'],
                is_trainable=True
            ).to(device)
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Falling back to base model...")
    
    print(config)
    # print(config['use_4bit'])
    if config.get('use_4bit', True):
        print('use_4bit')
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_quant_type="nf4",
            llm_int8_threshold=6.0,
            bnb_4bit_use_double_quant=True,
        )
        # print(config['model_name'])
        # print(type(config['model_name']))
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config
        )
        model = prepare_model_for_kbit_training(model)
        model.config.use_cache = False
    else:
        # print('else')
        model = AutoModelForCausalLM.from_pretrained(
            config['model_name'],
            torch_dtype=torch.float32,
            trust_remote_code=True,
            device_map="auto",
        ).to(device)
    
    return model

Expected behavior
Expected behavior is that the model trains in a multi-GPU setting without throwing any errors. The same script works fine in a single-GPU setting, but throws the above error as soon as more than one GPU is used.
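From the last frames of the traceback, the device mismatch happens inside the Qwen2 attention projections. My current guess (unverified) is that load_model() loads the backbone with device_map="auto", which lets accelerate spread the layers over both GPUs, while the rest of my tensors (and the DataParallel wrapper) assume everything lives on cuda:0. A minimal sketch of what I plan to try next, pinning the whole quantized backbone to a single device instead of using "auto" (model_name is a placeholder for config['model_name']; this is only a sketch, not a verified fix):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training

model_name = "Qwen/Qwen2-0.5B"  # placeholder for config['model_name']

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map={"": 0},  # keep every layer on cuda:0 instead of letting "auto" shard across GPUs
)
model = prepare_model_for_kbit_training(model)
model.config.use_cache = False

If pinning the model to one device is indeed the only way to make this work, I would still like to know the recommended way to combine device_map="auto" (or any multi-GPU placement) with a training loop like the one above.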
