-
Notifications
You must be signed in to change notification settings - Fork 35
Description
Hi,
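I have a classification model trained on inputs of shape `batch_size x N (=46) x channels x height x width`. How can I adapt TRAK to work with this? I get an error inside the `featurize` function itself. Do I have to modify `in_dims`? The traceback is below; it ends in `get_output`, where the model's output is a tuple rather than a tensor.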
/opt/conda/lib/python3.8/site-packages/trak/gradient_computers.py in compute_per_sample_grad(self, batch)
148
149 # map over batch dimensions (hence 0 for each batch dimension, and None for model params)
--> 150 grads = torch.func.vmap(
151 grads_loss,
152 in_dims=(None, None, None, *([0] * len(batch))),
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in wrapped(*args, **kwargs)
432
433 # If chunk_size is not specified.
--> 434 return _flat_vmap(
435 func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs
436 )
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in _flat_vmap(func, batch_size, flat_in_dims, flat_args, args_spec, out_dims, randomness, **kwargs)
617 try:
618 batched_inputs = _create_batched_inputs(flat_in_dims, flat_args, vmap_level, args_spec)
--> 619 batched_outputs = func(*batched_inputs, **kwargs)
620 return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
621 finally:
/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
1378 @wraps(func)
1379 def wrapper(*args, **kwargs):
-> 1380 results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
1381 if has_aux:
1382 grad, (_, aux) = results
/opt/conda/lib/python3.8/site-packages/torch/_functorch/vmap.py in fn(*args, **kwargs)
37 def fn(*args, **kwargs):
38 with torch.autograd.graph.disable_saved_tensors_hooks(message):
---> 39 return f(*args, **kwargs)
40 return fn
41
/opt/conda/lib/python3.8/site-packages/torch/_functorch/eager_transforms.py in wrapper(*args, **kwargs)
1243 tree_map(partial(_create_differentiable, level=level), diff_args)
1244
-> 1245 output = func(*args, **kwargs)
1246 if has_aux:
1247 if not (isinstance(output, tuple) and len(output) == 2):
/opt/conda/lib/python3.8/site-packages/trak/modelout_functions.py in get_output(model, weights, buffers, image, label)
138 """
139 logits = ch.func.functional_call(model, (weights, buffers), image.unsqueeze(0))
--> 140 bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
141 logits_correct = logits[bindex, label.unsqueeze(0)]
142
AttributeError: 'tuple' object has no attribute 'shape'