Implement CollabCollate
alex-snd committed Sep 1, 2022
1 parent 4939a4c commit 8f2cbb3
Showing 8 changed files with 155 additions and 42 deletions.
hubconf.py (2 changes: 1 addition & 1 deletion)
@@ -84,6 +84,6 @@ def collab_args() -> Dict[str, Any]:
'experiment_prefix': 'trecover',
'target_batch_size': 4096,
'min_noise': 0,
'max_noise': 1,
'max_noise': 1, # TODO model params

}
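This dictionary is the set of shared defaults every peer pulls at startup: sync_base_args() in arguments.py below fetches it through torch.hub. A minimal sketch of that fetch, mirroring the call in the diff:

import torch

# Load the collaborative defaults published via hubconf.py; force_reload=True
# matches how sync_base_args() invokes torch.hub.load in arguments.py below.
base_args = torch.hub.load('alex-snd/TRecover', 'collab_args',
                           force_reload=True, verbose=False)
print(base_args['min_noise'], base_args['max_noise'])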
src/trecover/train/collab/arguments.py (23 changes: 16 additions & 7 deletions)
@@ -7,7 +7,7 @@


def sync_base_args(args: Namespace) -> Namespace:
if args.no_args_sync:
if not args.sync_args:
return args

base_args = torch.hub.load('alex-snd/TRecover', 'collab_args', force_reload=True, verbose=False)
@@ -29,6 +29,17 @@ def sync_base_args(args: Namespace) -> Namespace:
return args


def get_sync_parser(add_help: bool = True) -> ArgumentParser:
parser = ArgumentParser('Synchronization arguments', add_help=add_help)

parser.add_argument('--sync-period', default=5, type=int,
help='Period (in collaborative steps) for arguments resynchronization')
parser.add_argument('--sync-args', action='store_true',
help='Sync base collaborative arguments with torch.hub')

return parser


def get_model_parser(add_help: bool = True) -> ArgumentParser:
parser = ArgumentParser('Model arguments', add_help=add_help)

@@ -123,7 +134,7 @@ def get_optimization_parser(add_help: bool = True) -> ArgumentParser:
help='Batch size that fits into accelerator memory')
parser.add_argument('--accumulate-batches', default=1, type=int,
help='Number of steps for gradients accumulation')
parser.add_argument('--target-batch-size', default=4096, type=int,
parser.add_argument('--target-batch-size', default=2048, type=int,
help='Perform optimizer step after all peers collectively accumulate this many samples')
parser.add_argument('--matchmaking-time', default=50, type=float,
help='Averaging group will wait for stragglers for at most this many seconds')
@@ -166,8 +177,6 @@ def get_dht_parser(add_help: bool = True) -> ArgumentParser:
help='Visible multiaddrs the host announces for external connections from other p2p instances')
parser.add_argument('--identity-path', type=Path,
help='Path to a pre-generated private key file. If defined, makes the peer ID deterministic')
parser.add_argument('--no-args-sync', action='store_true',
help='Do not sync base collaborative arguments with torch.hub')

return parser

@@ -215,6 +224,7 @@ def get_tune_parser(add_help: bool = True) -> ArgumentParser:
def get_auxiliary_parser(add_help: bool = True) -> ArgumentParser:
parser = ArgumentParser('Auxiliary arguments', add_help=add_help,
parents=[
get_sync_parser(add_help=False),
get_dht_parser(add_help=False),
get_model_parser(add_help=False),
get_data_parser(add_help=False),
@@ -239,14 +249,12 @@ def get_visualization_parser(add_help: bool = True) -> ArgumentParser:

parser.add_argument('--delimiter', default='', type=str,
help='Visualization columns delimiter')
parser.add_argument('--visualize-every-step', default=5, type=int,
parser.add_argument('--visualize-every-step', default=None, type=int,
help='Perform visualization once in this many global steps.')
parser.add_argument('--visualizer-refresh-period', default=10, type=float,
help='Period (in seconds) to check for visualization.')
parser.add_argument('--assist-in-averaging', action='store_true',
help='If True, this peer will facilitate averaging for other (training) peers')
parser.add_argument('--visualize', action='store_true',
help='If True, this peer will perform train progress visualization')

return parser

@@ -266,6 +274,7 @@ def get_monitor_parser(add_help: bool = True) -> ArgumentParser:
def get_train_parser(add_help: bool = True) -> ArgumentParser:
parser = ArgumentParser('Train loop arguments', add_help=add_help,
parents=[
get_sync_parser(add_help=False),
get_dht_parser(add_help=False),
get_model_parser(add_help=False),
get_data_parser(add_help=False),
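The parsers above compose through argparse's parents mechanism, which is why the new get_sync_parser() is built with add_help=False before being stacked into the auxiliary and train parsers. A runnable sketch of the same pattern with the new sync flags (parser titles copied from the diff, everything else illustrative):

from argparse import ArgumentParser

# Parent parsers disable -h/--help so the top-level parser can own it alone.
sync_parent = ArgumentParser('Synchronization arguments', add_help=False)
sync_parent.add_argument('--sync-period', default=5, type=int)
sync_parent.add_argument('--sync-args', action='store_true')

train_parser = ArgumentParser('Train loop arguments', parents=[sync_parent])
args = train_parser.parse_args(['--sync-args', '--sync-period', '10'])
assert args.sync_args and args.sync_period == 10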
src/trecover/train/collab/callback.py (13 changes: 9 additions & 4 deletions)
@@ -18,7 +18,8 @@ class CollabCheckpoint(Callback):
def __init__(self,
dht_manager: DHTManager,
statistics_expiration: float,
backup_every_step: int):
backup_every_step: int,
sync_period: Optional[int] = None):
self.dht_manager: DHTManager = dht_manager
self.wrapped_model: Optional[BaseModelWrapper] = None
self.collab_opt: Optional[CollaborativeOptimizer] = None
@@ -36,6 +37,7 @@ def __init__(self,
self.samples_per_second = 0
self.alive_peers = 0
self.backup_every_step = backup_every_step
self.sync_period = sync_period

def on_train_batch_end(self,
trainer: pl.Trainer,
@@ -47,13 +49,11 @@ def on_train_batch_end(self,
assert len(trainer.strategy.optimizers) == 1, 'Hivemind only supports training with one optimizer.'
self.collab_opt = trainer.strategy.collab_opt
self.last_reported_step = self.collab_opt.local_epoch
self.min_noise = self.collab_opt.args.min_noise
self.max_noise = self.collab_opt.args.max_noise

if not self.collab_opt.params_are_finite:
log.project_console.print('Model parameters are not finite', style='red', justify='right')
self.collab_opt.recover_state()
return # TODO reset accumulated metrics?
return

self.steps += 1
self.loss += outputs['loss'].item()
@@ -67,6 +67,9 @@
else:
log.project_console.print('Skip backup', style='yellow', justify='right')

if self.sync_period and current_step % self.sync_period == 0:
self.collab_opt.sync_collate()

self.last_reported_step = current_step

self.samples = self.collab_opt.local_samples_accumulated
@@ -75,6 +78,8 @@ def _report_metrics(self, step: int) -> None:
self.total_samples_processed += self.samples
self.samples_per_second = self.collab_opt.samples_per_second
self.lr = self.collab_opt.lr
self.min_noise = self.collab_opt.min_noise
self.max_noise = self.collab_opt.max_noise
self.alive_peers = self.collab_opt.num_peers

statistics = LocalMetrics(
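The callback now refreshes the collate once every sync_period collaborative steps and reports the noise range from the optimizer rather than from static arguments. The new gating reduces to a sketch like this, where sync_period is None whenever --sync-args is off:

from typing import Optional

def maybe_sync_collate(collab_opt, current_step: int, sync_period: Optional[int]) -> None:
    # A None (or zero) period disables resynchronization entirely.
    if sync_period and current_step % sync_period == 0:
        collab_opt.sync_collate()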
src/trecover/train/collab/entrypoint.py (13 changes: 9 additions & 4 deletions)
@@ -21,6 +21,8 @@
def monitor(cli_args: Optional[List[str]] = None) -> None:
args = arguments.sync_base_args(arguments.get_monitor_parser().parse_args(cli_args))

args.sync_args = True # TODO

if args.assist_in_averaging and args.client_mode:
log.project_console.print('Client-mode peers cannot assist in averaging', style='red')
return
@@ -30,7 +32,7 @@ def monitor(cli_args: Optional[List[str]] = None) -> None:
aux_opt = None
visualizer = None

if args.upload_state or args.assist_in_averaging or args.visualize:
if args.upload_state or args.assist_in_averaging or args.visualize_every_step:
aux_opt = AuxiliaryOptimizer(dht=dht_manager.dht,
wrapped_model=BaseModelWrapper(args),
args=args,
@@ -39,7 +41,7 @@ def monitor(cli_args: Optional[List[str]] = None) -> None:
if args.assist_in_averaging:
aux_opt.start_assistant()

if args.visualize:
if args.visualize_every_step:
visualizer = CollaborativeVisualizer(aux_opt=aux_opt,
delimiter=args.delimiter,
visualize_every_step=args.visualize_every_step,
@@ -71,7 +73,7 @@ def monitor(cli_args: Optional[List[str]] = None) -> None:
finally:
if aux_opt and args.assist_in_averaging:
aux_opt.finish(join=True)
if visualizer and args.visualize:
if visualizer and args.visualize_every_step:
visualizer.finish(join=True)

common_status.disable()
@@ -80,6 +82,8 @@ def monitor(cli_args: Optional[List[str]] = None) -> None:
def train(cli_args: Optional[List[str]] = None) -> None:
args = arguments.sync_base_args(arguments.get_train_parser().parse_args(cli_args))

args.sync_args = True # TODO

os.system('ulimit -n 16384')

if args.batch_size is None:
@@ -91,7 +95,8 @@

collab_checkpoint = CollabCheckpoint(dht_manager=dht_manager,
statistics_expiration=args.statistics_expiration,
backup_every_step=args.backup_every_step)
backup_every_step=args.backup_every_step,
sync_period=args.sync_period if args.sync_args else None)

trainer = pl.Trainer(default_root_dir=args.pl_registry,
max_epochs=args.n_epochs,
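Both entrypoints accept an argv-style list, so the new flags can also be exercised programmatically. A sketch, assuming every remaining train option has a usable default:

from trecover.train.collab import arguments

# Parse the train CLI with the new synchronization flags.
args = arguments.get_train_parser().parse_args(['--sync-args', '--sync-period', '10'])
print(args.sync_args, args.sync_period)  # True 10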
src/trecover/train/collab/optim.py (27 changes: 24 additions & 3 deletions)
@@ -520,6 +520,7 @@ def lr(self) -> float:
return self.opt.opt.param_groups[0]['lr']

@property
@atomic
def bandwidth(self) -> Optional[float]:
if not self.args.bandwidth:
try:
@@ -534,6 +535,16 @@ def bandwidth(self) -> Optional[float]:

return self.args.bandwidth

@property
@atomic
def min_noise(self) -> int:
return self.wrapped_model.collate.min_noise

@property
@atomic
def max_noise(self) -> int:
return self.wrapped_model.collate.max_noise

@atomic
def recover_state(self) -> None:
log.project_console.print('Trying to recover collab state...', style='yellow', justify='right')
@@ -565,6 +576,10 @@ def params_are_finite(self) -> bool:

return True

@atomic
def sync_collate(self) -> None:
self.wrapped_model.collate.sync(verbose=True)

@torch.no_grad()
@atomic
def sync_state(self) -> None:
@@ -692,23 +707,30 @@ def _assist_averaging_in_background(self) -> None:
except KeyboardInterrupt:
pass
finally:
self.stopped.set()
self.status.update('Is Stopped', style='yellow')
self.status.disable()

def _assistant_loop(self) -> None:
self.status.update('Pending...', style=self._status_style)

while not self.stopped.is_set():
try:
self.status.update('Pending...', style=self._status_style)
with self.transaction:
if self.stopped.is_set():
return

self.status.update('Assist in averaging...', style=self._status_style)

self._update_state_sharing_status_step()
self._check_finiteness_step()
self.opt.step()

if self._is_time_to_backup:
self._backup_step()

self.status.update('Pending...', style=self._status_style)

time.sleep(self.args.assist_refresh)

except KeyboardInterrupt:
@@ -725,7 +747,6 @@ def _update_state_sharing_status_step(self) -> None:
style='yellow',
justify='right'
)
self.status.update('Assist in averaging...', style=self._status_style)

elif self.allow_state_sharing and self.num_peers == 1 and self.num_client_peers == 1:
log.project_console.print(
@@ -782,7 +803,7 @@ def _backup_step(self) -> None:
self.last_reported_step = self.local_epoch

finally:
self.status.update('Pending...', style=self._status_style)
self.status.update('Assist in averaging...', style=self._status_style)

def start_assistant(self, attach: bool = False) -> None:
if self.args.client_mode:
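The atomic decorator stacked under @property above is not part of this diff. A minimal sketch consistent with its usage, assuming it only serializes access with a re-entrant lock (the real implementation elsewhere in trecover may differ):

import threading
from functools import wraps

_LOCK = threading.RLock()

def atomic(fn):
    # Hypothetical reconstruction: one shared lock keeps reads such as
    # min_noise/max_noise from interleaving with sync_collate().
    @wraps(fn)
    def wrapper(*args, **kwargs):
        with _LOCK:
            return fn(*args, **kwargs)
    return wrapper

class Example:
    @property
    @atomic
    def min_noise(self) -> int:  # decorator order matches the diff above
        return 0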
src/trecover/train/collab/visualization.py (6 changes: 5 additions & 1 deletion)
@@ -75,6 +75,8 @@ def stream(self) -> Generator[Tuple[int, List[Panel]], None, None]:
self.status.update('Need to synchronize this peer before visualization...')
self.aux_opt.sync_state()

self.aux_opt.sync_collate()

self.status.update(f'Perform visualization for {self.aux_opt.local_epoch - 1}-step')

self.steps_performance[self.aux_opt.local_epoch - 1] = self.aux_opt.wrapped_model.perform()
@@ -118,7 +120,7 @@ def _is_time_to_visualize(self) -> bool:
@property
def _need_to_sync(self) -> bool:
return (
self.aux_opt.local_epoch != self.aux_opt.global_epoch or
self.aux_opt.local_epoch != self.aux_opt.global_epoch or # TODO check without as-active-peer
self.aux_opt.original_allow_state_sharing and not self.aux_opt.allow_state_sharing
)

@@ -169,6 +171,8 @@ def _visualize_in_background(self) -> None:
except KeyboardInterrupt:
self.status.update('Stopping...', style='yellow')
finally:
self.stopped.set()

if self.steps_performance:
self.status.update(f'Trying to report {len(self.steps_performance)} delayed visualizations...',
style='yellow')
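stream() now refreshes the collate right after any state synchronization, so perform() renders a batch built with the same noise range the training peers currently use. Reduced to a hypothetical helper:

def prepare_visualization(aux_opt, need_to_sync: bool) -> None:
    # Weights first, collate second, mirroring the order enforced in stream().
    if need_to_sync:
        aux_opt.sync_state()
    aux_opt.sync_collate()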
src/trecover/train/collab/wrapper.py (32 changes: 24 additions & 8 deletions)
@@ -10,7 +10,7 @@
from torch.utils.data import DataLoader

from trecover.model import TRecover
from trecover.train.data import WikiDataset, StandardCollate
from trecover.train.data import WikiDataset, BaseCollate, StandardCollate, CollabCollate
from trecover.train.loss import CustomCrossEntropyLoss
from trecover.utils.train import transfer
from trecover.utils.transform import tensor_to_columns, tensor_to_target
@@ -24,8 +24,18 @@ def __init__(self, args: Namespace, *pl_args: Any, **pl_kwargs: Any):
self.model = TRecover(args.token_size, args.pe_max_len, args.n_layers, args.d_model,
args.n_heads, args.d_ff, args.dropout)
self.criterion = CustomCrossEntropyLoss(ignore_index=-1)
self.collate = StandardCollate(min_noise=args.min_noise, max_noise=args.max_noise)
self.batch_size = args.batch_size
self._collate = None

@property
def collate(self) -> BaseCollate:
if self._collate is None:
if self.args.sync_args:
self._collate = CollabCollate()
else:
self._collate = StandardCollate(min_noise=self.args.min_noise, max_noise=self.args.max_noise)

return self._collate

def forward(self, batch: Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor]
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], Tensor, Tensor]:
@@ -66,16 +76,19 @@ def perform(self) -> List[Tuple[List[str], List[str], List[str]]]:
return performance

def performance_dataloader(self) -> DataLoader:
return self._create_dataloader(self.args.vis_files, self.args.vis_dataset_size, batch_size=self.batch_size or 1)
return self._create_dataloader(files=self.args.vis_files,
dataset_size=self.args.vis_dataset_size,
batch_size=self.batch_size or 1,
num_workers=self.args.n_workers)

def _create_dataloader(self, files: Path, dataset_size: int, batch_size: int) -> DataLoader:
def _create_dataloader(self, files: Path, dataset_size: int, batch_size: int, num_workers: int) -> DataLoader:
files = [files / file for file in files.iterdir()]
dataset = WikiDataset(datafiles=files, min_threshold=self.args.min_threshold,
max_threshold=self.args.max_threshold, dataset_size=dataset_size)

return dataset.create_dataloader(batch_size=batch_size,
collate=self.collate,
num_workers=self.args.n_workers)
num_workers=num_workers)


class PeerModelWrapper(BaseModelWrapper):
@@ -111,10 +124,12 @@ def validation_step(self, batch: Tuple[Tensor, Tensor, Tensor, Optional[Tensor],
return {'loss': loss, 'accuracy': accuracy}

def train_dataloader(self) -> DataLoader:
return self._create_dataloader(self.args.train_files, self.args.train_dataset_size, self.batch_size)
return self._create_dataloader(self.args.train_files, self.args.train_dataset_size, self.batch_size,
self.args.n_workers)

def val_dataloader(self) -> DataLoader:
return self._create_dataloader(self.args.val_files, self.args.val_dataset_size, self.batch_size)
return self._create_dataloader(self.args.val_files, self.args.val_dataset_size, self.batch_size,
self.args.n_workers)


class FullModelWrapper(PeerModelWrapper):
@@ -136,4 +151,5 @@ def test_step(self, batch: Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optio
return {'loss': loss, 'accuracy': accuracy}

def test_dataloader(self) -> DataLoader:
return self._create_dataloader(self.args.test_files, self.args.test_dataset_size, self.batch_size)
return self._create_dataloader(self.args.test_files, self.args.test_dataset_size, self.batch_size,
self.args.n_workers)
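CollabCollate itself lives in src/trecover/train/data.py, the one changed file whose diff is not rendered on this page. Judging only from how it is used here (a no-argument constructor, a sync(verbose=...) method, and min_noise/max_noise read by the optimizer), a plausible sketch rather than the actual implementation:

import torch

from trecover.train.data import StandardCollate  # assumed base; the real class extends BaseCollate


class CollabCollateSketch(StandardCollate):
    # Hypothetical reconstruction: a collate whose noise range tracks the
    # shared collab_args published through hubconf.py (first diff above).

    def __init__(self) -> None:
        args = torch.hub.load('alex-snd/TRecover', 'collab_args', verbose=False)
        super().__init__(min_noise=args['min_noise'], max_noise=args['max_noise'])

    def sync(self, verbose: bool = False) -> None:
        # Re-fetch the shared arguments so all peers shift their noise range together.
        args = torch.hub.load('alex-snd/TRecover', 'collab_args',
                              force_reload=True, verbose=False)
        self.min_noise, self.max_noise = args['min_noise'], args['max_noise']
        if verbose:
            print(f'Collate noise range synced to [{self.min_noise}, {self.max_noise}]')

With something like this in place, the lazy collate property above selects the hub-synced collate whenever --sync-args is active and falls back to the fixed-range StandardCollate otherwise.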