[New Feature] Support SAM 2.1 (#59)
* support sam 2.1 * refine config path and ckpt path * update README
This commit is contained in:
502
training/optimizer.py
Normal file
502
training/optimizer.py
Normal file
@@ -0,0 +1,502 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import fnmatch
|
||||
import inspect
|
||||
import itertools
|
||||
import logging
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import hydra
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import DictConfig
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Optimizer:
    """Couples an optimizer with per-param-group option schedulers.

    ``schedulers`` (when given) is indexed in lockstep with
    ``optimizer.param_groups``: entry ``i`` maps an option name (a key of
    ``optimizer.defaults``, e.g. "lr") to a callable producing the value to
    write into ``param_groups[i]`` for that option.
    """

    def __init__(self, optimizer, schedulers=None) -> None:
        self.optimizer = optimizer
        self.schedulers = schedulers
        self._validate_optimizer_schedulers()
        # Write the initial scheduled values into the param groups.
        self.step_schedulers(0.0, 0)

    def _validate_optimizer_schedulers(self):
        """Assert that each scheduled option is listed in the optimizer defaults."""
        if self.schedulers is None:
            return
        for group_schedulers in self.schedulers:
            for option in group_schedulers:
                assert option in self.optimizer.defaults, (
                    "Optimizer option "
                    f"{option} not found in {self.optimizer}. Valid options are "
                    f"{self.optimizer.defaults.keys()}"
                )

    @staticmethod
    def _call_accepts_step(candidate) -> bool:
        """Tell whether ``candidate.__call__`` declares a parameter named "step"."""
        return "step" in inspect.signature(candidate.__call__).parameters

    def step_schedulers(self, where: float, step: int) -> None:
        """Evaluate every scheduler and store the result in its param group.

        Args:
            where: Forwarded to each scheduler (positionally, or as ``where=``
                for schedulers that also take ``step``).
            step: Forwarded as ``step=`` to schedulers whose ``__call__``
                declares a ``step`` parameter.
        """
        if self.schedulers is None:
            return
        for idx, group in enumerate(self.optimizer.param_groups):
            for option, scheduler in self.schedulers[idx].items():
                # A wrapper such as ValueScaler hides the wrapped signature
                # behind *args/**kwargs, so also look at its ``scheduler``.
                wants_step = self._call_accepts_step(scheduler) or (
                    hasattr(scheduler, "scheduler")
                    and self._call_accepts_step(scheduler.scheduler)
                )
                if wants_step:
                    value = scheduler(step=step, where=where)
                else:
                    value = scheduler(where)
                group[option] = value

    def step(self, where, step, closure=None):
        """Refresh the scheduled options, then run one optimizer step."""
        self.step_schedulers(where, step)
        return self.optimizer.step(closure)

    def zero_grad(self, *args, **kwargs):
        """Forward to the wrapped optimizer's ``zero_grad``."""
        return self.optimizer.zero_grad(*args, **kwargs)
|
||||
|
||||
|
||||
def set_default_parameters(
    scheduler_cfgs: List[DictConfig], all_parameter_names: Set[str]
) -> None:
    """Assign the leftover parameters to the "default" scheduler, in place.

    Args:
        scheduler_cfgs: A list of scheduler configs for one option. Each config
            exposes ``parameter_names``: either a set of names it applies to, or
            None to mark it as the default. At most one config may be the
            default. If none is, a plain ``{"parameter_names": ...}`` entry
            (without a scheduler) is appended to the list.
        all_parameter_names: Names of all the parameters to consider.
    """
    explicit_sets = [
        cfg.parameter_names for cfg in scheduler_cfgs if cfg.parameter_names is not None
    ]
    if explicit_sets:
        # Whatever no config claimed explicitly falls to the default.
        default_params = all_parameter_names - set.union(*explicit_sets)
    else:
        default_params = set(all_parameter_names)

    default_count = 0
    for cfg in scheduler_cfgs:
        if cfg.parameter_names is not None:
            continue
        cfg.parameter_names = default_params
        default_count += 1
    assert default_count <= 1, "Only one scheduler per option can be default"

    if default_count == 0:
        # No default scheduler specified, add a default, but without any scheduler
        # for that option
        scheduler_cfgs.append({"parameter_names": default_params})
|
||||
|
||||
|
||||
def name_constraints_to_parameters(
    param_constraints: List[Set[str]], named_parameters: Dict[str, Tensor]
) -> List[torch.nn.Parameter]:
    """Select the parameters whose names satisfy every constraint set.

    The result holds the parameter objects (the mapping's values), not names,
    in the iteration order of ``named_parameters``.

    Args:
        param_constraints: A list, with each element being a set of allowed
            parameter names.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        The values of ``named_parameters`` whose key appears in _all_ of the
        constraint sets.
    """
    allowed_names = set.intersection(*param_constraints)
    selected = []
    for name, parameter in named_parameters.items():
        if name in allowed_names:
            selected.append(parameter)
    return selected
|
||||
|
||||
|
||||
def map_scheduler_cfgs_to_param_groups(
    all_scheduler_cfgs: Iterable[List[Dict]],
    named_parameters: Dict[str, Tensor],
) -> Tuple[List[Dict[Any, Any]], List[Dict[str, List[torch.nn.Parameter]]]]:
    """Build one param group per non-empty combination of scheduler configs.

    Each element of ``all_scheduler_cfgs`` lists the configs for one optimizer
    option (like "lr" or "weight_decay"), every config carrying the set of
    parameter names it applies to. The cartesian product across options is
    walked; each combination whose name sets share at least one parameter
    becomes a param group, paired with the schedulers of that combination.

    Args:
        all_scheduler_cfgs: All the scheduler configs covering every option.
        named_parameters: Mapping from a parameter name to the parameter itself.

    Returns:
        Tuple of lists of schedulers and param_groups, where schedulers[i]
        applies to param_groups[i].
    """
    schedulers = []
    param_groups = []
    for cfg_combo in itertools.product(*all_scheduler_cfgs):
        constraints = [cfg["parameter_names"] for cfg in cfg_combo]
        group_params = name_constraints_to_parameters(constraints, named_parameters)
        if not group_params:
            # This combination of configs shares no parameter.
            continue
        group_schedulers = {}
        for cfg in cfg_combo:
            # Default entries without a scheduler carry no "option" key.
            if "option" in cfg:
                group_schedulers[cfg["option"]] = cfg["scheduler"]
        schedulers.append(group_schedulers)
        param_groups.append({"params": group_params})
    return schedulers, param_groups
|
||||
|
||||
|
||||
def validate_param_group_params(param_groups: List[Dict], model: nn.Module):
    """Check that the param groups are non-overlapping and cover all the parameters.

    Args:
        param_groups: List of all param groups; each is a dict whose "params"
            entry holds the parameters of that group.
        model: Model to validate against. The check ensures that all the model
            parameters are part of param_groups.

    Raises:
        AssertionError: If a parameter is repeated inside a group, if two groups
            share a parameter, or if the union of the groups differs from the
            set of model parameters.
    """
    for pg in param_groups:
        # no param should be repeated within a group
        assert len(pg["params"]) == len(set(pg["params"]))

    parameters = [set(param_group["params"]) for param_group in param_groups]
    model_parameters = {parameter for _, parameter in model.named_parameters()}

    # Disjointness is symmetric, so each unordered pair needs checking only once.
    for p1, p2 in itertools.combinations(parameters, 2):
        assert p1.isdisjoint(p2), "Scheduler generated param_groups should be disjoint"

    # ``set().union(...)`` also works with zero groups, so an empty
    # ``param_groups`` reaches the assertion instead of raising TypeError.
    covered_parameters = set().union(*parameters)
    assert covered_parameters == model_parameters, (
        "Scheduler generated param_groups must include all parameters of the model."
        f" Found {len(covered_parameters)} params whereas model has"
        f" {len(model_parameters)} params"
    )
|
||||
|
||||
|
||||
def unix_module_cls_pattern_to_parameter_names(
    filter_module_cls_names: Optional[List[str]],
    module_cls_to_param_names: Dict[Type, Set[str]],
) -> Set[str]:
    """Returns param names which pass the filters specified in filter_module_cls_names.

    Args:
        filter_module_cls_names: A list of filter strings containing class names, like
            ["torch.nn.LayerNorm", "torch.nn.BatchNorm2d"]. None or an empty
            list means "no class filter" and yields an empty set.
        module_cls_to_param_names: Mapping from module classes to the parameter names
            they contain. See `get_module_cls_to_param_names`.

    Returns:
        A new set with the union of the parameter names of every listed class.

    Raises:
        AssertionError: If a listed class is not a key of
            ``module_cls_to_param_names`` or maps to no parameter names.
    """
    if not filter_module_cls_names:
        # Covers both None and an empty list; the latter used to crash in
        # ``set.union(*[])``.
        return set()

    allowed_parameter_names = set()
    for module_cls_name in filter_module_cls_names:
        module_cls = hydra.utils.get_class(module_cls_name)
        if module_cls not in module_cls_to_param_names:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not "
                "match any classes in the model"
            )
        matching_parameters = module_cls_to_param_names[module_cls]
        # Raised explicitly (not via ``assert``) so the check survives ``-O``,
        # like the class-lookup check above.
        if len(matching_parameters) == 0:
            raise AssertionError(
                f"module_cls_name {module_cls_name} does not contain any parameters in the model"
            )
        logging.info(
            "Matches for module_cls_name [%s]: %s ",
            module_cls_name,
            matching_parameters,
        )
        allowed_parameter_names.update(matching_parameters)
    return allowed_parameter_names
|
||||
|
||||
|
||||
def unix_param_pattern_to_parameter_names(
    filter_param_names: Optional[List[str]],
    parameter_names: Iterable[str],
) -> Set[str]:
    """Returns param names which pass the filters specified in filter_param_names.

    Args:
        filter_param_names: A list of unix-style filter strings with optional
            wildcards, like ["block.2.*", "block.2.linear.weight"]. None or an
            empty list means "no name filter" and yields an empty set.
        parameter_names: The candidate parameter names that the patterns are
            matched against (with ``fnmatch.filter``).

    Returns:
        A new set with every name matched by at least one pattern.

    Raises:
        AssertionError: If a pattern matches none of ``parameter_names``.
    """
    if not filter_param_names:
        # Covers both None and an empty list; the latter used to crash in
        # ``set.union(*[])``.
        return set()

    allowed_parameter_names = set()
    for param_name in filter_param_names:
        matching_parameters = set(fnmatch.filter(parameter_names, param_name))
        # Raised explicitly (not via ``assert``) so the check survives ``-O``.
        if len(matching_parameters) < 1:
            raise AssertionError(
                f"param_name {param_name} does not match any parameters in the model"
            )
        logging.info(
            "Matches for param_name [%s]: %s", param_name, matching_parameters
        )
        allowed_parameter_names.update(matching_parameters)
    return allowed_parameter_names
|
||||
|
||||
|
||||
def _unix_pattern_to_parameter_names(
    scheduler_cfg: DictConfig,
    parameter_names: Set[str],
    module_cls_to_param_names: Dict[Type, str],
) -> Union[None, Set[str]]:
    """Resolve the filters of one scheduler config into parameter names.

    Args:
        scheduler_cfg: The config for the scheduler; its optional "param_names"
            and "module_cls_names" entries are the filters.
        parameter_names: The set of all parameter names which will be filtered
        module_cls_to_param_names: Mapping from module classes to the parameter
            names they contain, used for the "module_cls_names" filter.

    Returns:
        None when the config has neither filter key; otherwise the union of the
        names selected by the name patterns and by the module-class filter.
    """
    if not ("param_names" in scheduler_cfg or "module_cls_names" in scheduler_cfg):
        return None

    names_from_patterns = unix_param_pattern_to_parameter_names(
        scheduler_cfg.get("param_names"), parameter_names
    )
    names_from_classes = unix_module_cls_pattern_to_parameter_names(
        scheduler_cfg.get("module_cls_names"), module_cls_to_param_names
    )
    return names_from_patterns.union(names_from_classes)
|
||||
|
||||
|
||||
def get_module_cls_to_param_names(
    model: nn.Module, param_allowlist: Set[str] = None
) -> Dict[Type, str]:
    """Map each module class found in the model to the parameter names it owns.

    A parameter is attributed only to its immediate parent module
    (``named_parameters(recurse=False)``); enclosing modules do not count.
    Every module class encountered gets an entry, possibly with an empty set.

    Args:
        model: Model to iterate over
        param_allowlist: If specified, only these param names will be processed
    """
    module_cls_to_params = {}
    for module_name, module in model.named_modules():
        owned_names = module_cls_to_params.setdefault(type(module), set())
        for param_name, _ in module.named_parameters(recurse=False):
            full_param_name = get_full_parameter_name(module_name, param_name)
            if param_allowlist is not None and full_param_name not in param_allowlist:
                continue
            owned_names.add(full_param_name)
    return module_cls_to_params
|
||||
|
||||
|
||||
def construct_optimizer(
    model: torch.nn.Module,
    optimizer_conf: Any,
    options_conf: Mapping[str, List] = None,
    param_group_modifiers_conf: List[Callable] = None,
    param_allowlist: Optional[Set[str]] = None,
    validate_param_groups=True,
) -> Optimizer:
    """Build the wrapped optimizer (and its option schedulers) from configs.

    Without ``options_conf`` the optimizer is instantiated directly on the
    allow-listed parameters and no schedulers are attached. Otherwise the
    scheduler configs of every option are resolved to parameter names, turned
    into param groups, and the optimizer is instantiated on those groups.

    Args:
        model: Model whose parameters are optimized.
        optimizer_conf: Hydra config consisting a partial torch optimizer like SGD or
            ADAM, still missing the params argument which this function provides to
            produce the final optimizer
        options_conf: Hydra config which instantiates to a mapping from an
            optimizer option name to the list of scheduler configs for it.
        param_group_modifiers_conf: Optional user specified functions which can modify
            the final scheduler configs before the optimizer's param groups are built
        param_allowlist: The parameters to optimize. Parameters which are not part of
            this allowlist will be skipped. Defaults to all model parameters.
        validate_param_groups: If enabled, validates that the produced param_groups
            don't overlap and cover all the model parameters.
    """
    if param_allowlist is None:
        param_allowlist = {name for name, _ in model.named_parameters()}

    named_parameters = {}
    for name, param in model.named_parameters():
        if name in param_allowlist:
            named_parameters[name] = param

    if not options_conf:
        optimizer = hydra.utils.instantiate(optimizer_conf, named_parameters.values())
        return Optimizer(optimizer)

    all_parameter_names = set(named_parameters)
    module_cls_to_all_param_names = get_module_cls_to_param_names(
        model, param_allowlist
    )

    scheduler_cfgs_per_option = hydra.utils.instantiate(options_conf)
    all_scheduler_cfgs = []
    for option, scheduler_cfgs in scheduler_cfgs_per_option.items():
        for config in scheduler_cfgs:
            config.option = option
            config.parameter_names = _unix_pattern_to_parameter_names(
                config, all_parameter_names, module_cls_to_all_param_names
            )
        # Parameters not claimed by any config of this option go to its default.
        set_default_parameters(scheduler_cfgs, all_parameter_names)
        all_scheduler_cfgs.append(scheduler_cfgs)

    for modifier_conf in param_group_modifiers_conf or ():
        modifier = hydra.utils.instantiate(modifier_conf)
        all_scheduler_cfgs = modifier(scheduler_cfgs=all_scheduler_cfgs, model=model)

    schedulers, param_groups = map_scheduler_cfgs_to_param_groups(
        all_scheduler_cfgs, named_parameters
    )
    if validate_param_groups:
        validate_param_group_params(param_groups, model)
    optimizer = hydra.utils.instantiate(optimizer_conf, param_groups)
    return Optimizer(optimizer, schedulers)
|
||||
|
||||
|
||||
def get_full_parameter_name(module_name, param_name):
    """Join a module name and a parameter name with a dot.

    An empty ``module_name`` yields ``param_name`` unchanged.
    """
    return param_name if module_name == "" else f"{module_name}.{param_name}"
|
||||
|
||||
|
||||
class GradientClipper:
    """
    Callable that clips a model's gradient norm with ``nn.utils.clip_grad_norm_``.

    A ``max_norm`` of None turns the call into a no-op.
    """

    def __init__(self, max_norm: float = 1.0, norm_type: int = 2):
        assert max_norm is None or isinstance(max_norm, (int, float))
        self.max_norm = None if max_norm is None else float(max_norm)
        self.norm_type = norm_type

    def __call__(self, model: nn.Module):
        """Clip the gradients of ``model.parameters()`` in place; returns None."""
        if self.max_norm is not None:
            nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=self.max_norm, norm_type=self.norm_type
            )
|
||||
|
||||
|
||||
class ValueScaler:
    """Wraps a scheduler and multiplies every value it returns by a constant."""

    def __init__(self, scheduler, mult_val: float):
        self.scheduler = scheduler
        self.mult_val = mult_val

    def __call__(self, *args, **kwargs):
        """Call the wrapped scheduler with the same arguments and scale its result."""
        return self.scheduler(*args, **kwargs) * self.mult_val
|
||||
|
||||
|
||||
def rgetattr(obj, rattrs: str = None):
    """
    Resolve a dotted attribute path on an object.

    ``rattrs`` has the form 'attr1.attr2', which returns obj.attr1.attr2.
    When ``rattrs`` is None, ``obj`` itself is returned.
    """
    if rattrs is None:
        return obj
    current = obj
    for attr_name in rattrs.split("."):
        current = getattr(current, attr_name)
    return current
|
||||
|
||||
|
||||
def layer_decay_param_modifier(
    scheduler_cfgs: List[List[Dict]],
    model,
    layer_decay_value: float,
    layer_decay_min: Optional[float] = None,
    apply_to: Optional[str] = None,
    overrides: List[Dict] = (),
) -> List[List[Dict]]:
    """Split every "lr" scheduler config into per-layer configs with scaled values.

    Args
        - scheduler_cfgs: a list of lists of scheduler configs.
            Each config is a dict-like object with the following structure
            {
                "scheduler": <some scheduler callable>
                "option": <value> possible options are "lr", "weight_decay" etc.
                "parameter_names": Set of str indicating param names that this scheduler applies to
            }
        - model: a model that implements a method `get_layer_id` that maps a parameter name to an
            integer and a method `get_num_layers`.
            Alternatively, use apply_to argument to select a specific component of the model.
        - layer_decay_value: float; the scale of layer `i` is
            `layer_decay_value ** (get_num_layers() + 1 - i)`.
        - layer_decay_min: min val for layer decay
        - apply_to: optional dotted attribute path selecting the component of the model to apply
            the layer decay modifier to. Only parameter names starting with this string are
            decayed; the others keep the scale of the last layer index. When None, the modifier
            applies to the whole model and every parameter name.
        - overrides: to manually override lr for specific patterns. Is a list of dicts. Each dict
            has keys "pattern", "value". Only consulted for parameters the modifier applies to;
            the first matching pattern wins.
    Returns
        - scheduler_configs: same structure as the input. Configs whose option is not "lr" are
            passed through untouched; every "lr" config is replaced by one config per layer id
            (or per override pattern), each wrapping the original scheduler in a `ValueScaler`.
    """
    model = rgetattr(model, apply_to)
    num_layers = model.get_num_layers() + 1
    # Index ``num_layers`` (the last entry) has exponent 0, i.e. a scale of 1.
    layer_decays = [
        layer_decay_value ** (num_layers - i) for i in range(num_layers + 1)
    ]
    if layer_decay_min is not None:
        layer_decays = [max(val, layer_decay_min) for val in layer_decays]

    final_scheduler_cfgs = []
    # scheduler_cfgs is a list of lists
    for scheduler_cfg_group in scheduler_cfgs:
        curr_cfg_group = []
        # scheduler_cfg_group is a list of dictionaries
        for scheduler_cfg in scheduler_cfg_group:
            if scheduler_cfg["option"] != "lr":
                curr_cfg_group.append(scheduler_cfg)
                continue
            # Need sorted so that the list of parameter names is deterministic and consistent
            # across re-runs of this job. Else it was causing issues with loading the optimizer
            # state during a job restart
            parameter_names = sorted(scheduler_cfg["parameter_names"])

            # Only want one cfg group per layer
            layer_cfg_groups = {}
            for param_name in parameter_names:
                layer_id = num_layers
                this_scale = layer_decays[layer_id]
                # ``apply_to`` defaults to None; ``str.startswith(None)`` raises
                # TypeError, so None is treated as "applies to every parameter".
                if apply_to is None or param_name.startswith(apply_to):
                    layer_id = model.get_layer_id(param_name)
                    this_scale = layer_decays[layer_id]
                    # Overrides
                    for override in overrides:
                        if fnmatch.fnmatchcase(param_name, override["pattern"]):
                            this_scale = float(override["value"])
                            # Group overridden params by pattern, not by layer.
                            layer_id = override["pattern"]
                            break

                if layer_id not in layer_cfg_groups:
                    layer_cfg_groups[layer_id] = {
                        "option": scheduler_cfg["option"],
                        "scheduler": ValueScaler(
                            scheduler_cfg["scheduler"], this_scale
                        ),
                        "parameter_names": {param_name},
                    }
                else:
                    layer_cfg_groups[layer_id]["parameter_names"].add(param_name)

            curr_cfg_group.extend(layer_cfg_groups.values())

        final_scheduler_cfgs.append(curr_cfg_group)
    return final_scheduler_cfgs
|
Reference in New Issue
Block a user