ysautoml.data.dsbn
functional
ysautoml.data.dsbn.convert_and_wrap
ysautoml.data.dsbn.convert_and_wrap(**kwargs)
Convert all BatchNorm2d layers in a model to DSBN2d (Domain-Specific BatchNorm) and set the initial mode. DSBN maintains two sets of BN statistics: BN_S for the source domain and BN_T for the target/augmented domain.
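To make the two-branch idea concrete, here is a minimal pure-Python sketch of how a domain-specific BN layer could route a batch according to its mode. It is not the library's DSBN2d. The names ToyDSBN and _ToyBN are hypothetical, and the sketch normalizes a 2-D batch (samples by features) instead of NCHW tensors, has no learnable affine parameters, and always runs in training mode (batch statistics plus a running-average update).

```python
from typing import List

Batch = List[List[float]]  # rows are samples, columns are features (channels)


class _ToyBN:
    """Per-feature batch normalization in training mode, without affine parameters."""

    def __init__(self, num_features: int, momentum: float = 0.1, eps: float = 1e-5) -> None:
        self.momentum = momentum
        self.eps = eps
        self.running_mean = [0.0] * num_features
        self.running_var = [1.0] * num_features

    def __call__(self, x: Batch) -> Batch:
        n = len(x)
        cols = list(zip(*x))  # transpose: one tuple per feature
        means = [sum(c) / n for c in cols]
        variances = [sum((v - m) ** 2 for v in c) / n for c, m in zip(cols, means)]
        # Update this branch's running statistics with an exponential moving average.
        for j, (m, var) in enumerate(zip(means, variances)):
            self.running_mean[j] = (1 - self.momentum) * self.running_mean[j] + self.momentum * m
            self.running_var[j] = (1 - self.momentum) * self.running_var[j] + self.momentum * var
        # Normalize each value with the current batch's statistics.
        return [
            [(v - m) / (var + self.eps) ** 0.5 for v, m, var in zip(row, means, variances)]
            for row in x
        ]


class ToyDSBN:
    """Hypothetical stand-in for DSBN2d: two independent BN branches selected by mode."""

    def __init__(self, num_features: int) -> None:
        self.bn_s = _ToyBN(num_features)  # source-domain statistics (BN_S)
        self.bn_t = _ToyBN(num_features)  # target/augmented statistics (BN_T)
        self.mode = 1

    def __call__(self, x: Batch) -> Batch:
        if self.mode == 1:
            return self.bn_s(x)
        if self.mode == 2:
            return self.bn_t(x)
        if self.mode == 3:
            if len(x) % 2 != 0:
                raise ValueError("mode=3 requires an even batch size")
            half = len(x) // 2
            # First half goes through BN_S, second half through BN_T, then re-concatenate.
            return self.bn_s(x[:half]) + self.bn_t(x[half:])
        raise ValueError(f"unknown DSBN mode: {self.mode}")
```

With layer.mode = 3, a batch such as [[1.0], [3.0], [10.0], [30.0]] has its first two rows normalized with source statistics and its last two with target statistics, and each branch's running statistics are updated only from its own half.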
Parameters
model_or_name (str or nn.Module): Model name built by the model factory (e.g. "resnet18_cifar") or an existing model instance.
dataset (str, default "CIFAR10"): Dataset tag (reserved for model factory use).
num_classes (int, default 10): Number of classes.
use_aug (bool, default False): If True and mode=None, the initial mode is set to 2 (Target BN); otherwise it is set to 1 (Source BN).
mode (int or None, default None): DSBN mode.
  1: use Source BN (BN_S)
  2: use Target BN (BN_T)
  3: split-half (first half of the batch uses BN_S, second half uses BN_T); the batch size must be even.
  If None, the mode is inferred from use_aug.
device (str, default "0"): Device spec (sets CUDA_VISIBLE_DEVICES and moves the model).
export_path (str or None): If given, saves model.state_dict() to this path.
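The rule for the initial mode can be read as the hypothetical helper below. It assumes that an explicitly passed mode takes precedence over use_aug; the range check on mode is an addition of this sketch, not documented behavior.

```python
from typing import Optional


def resolve_initial_mode(use_aug: bool = False, mode: Optional[int] = None) -> int:
    """Hypothetical helper mirroring the documented rule for the initial DSBN mode."""
    if mode is not None:
        # Assumption: an explicit mode always wins; validating 1/2/3 is this sketch's choice.
        if mode not in (1, 2, 3):
            raise ValueError(f"mode must be 1, 2, or 3, got {mode}")
        return mode
    # mode=None: infer from use_aug (BN_T for augmented/target data, otherwise BN_S).
    return 2 if use_aug else 1
```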
Returns
nn.Module: DSBN-converted model with mode set.
ysautoml.data.dsbn.train_with_dsbn
ysautoml.data.dsbn.train_with_dsbn(**kwargs)
Train a DSBN-converted model using either separate or mixed batch training.
Separate mode (mixed_batch=False): Each iteration uses one source batch (mode=1) and one target batch (mode=2).
Mixed mode (mixed_batch=True): Each batch contains source and target samples concatenated. The model runs with mode=3, applying BN_S to the first half and BN_T to the second half.
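The two batching strategies differ only in how batches are routed to the BN branches. The sketch below uses hypothetical names (make_mixed_batch, dsbn_schedule), with plain lists standing in for tensors and DataLoaders, and yields (mode, batch) pairs in the order a training loop would consume them. Requiring equal-sized source and target halves is this sketch's assumption, so that split-half routing lines up with the domain boundary.

```python
from typing import Iterable, Iterator, List, Optional, Tuple

Sample = float  # stand-in for an image tensor
Batch = List[Sample]


def make_mixed_batch(src: Batch, tgt: Batch) -> Batch:
    """Concatenate source then target so split-half (mode=3) routes each half correctly."""
    if len(src) != len(tgt):
        raise ValueError("source and target halves must have equal size for mode=3")
    return src + tgt


def dsbn_schedule(
    source_batches: Iterable[Batch],
    target_batches: Optional[Iterable[Batch]] = None,
    mixed_batch: bool = False,
) -> Iterator[Tuple[int, Batch]]:
    """Yield (dsbn_mode, batch) pairs in the order a training loop would see them."""
    if mixed_batch:
        # The "source" loader already yields mixed source+target batches.
        for batch in source_batches:
            if len(batch) % 2 != 0:
                raise ValueError("mixed batches must have an even size")
            yield 3, batch
    else:
        if target_batches is None:
            raise ValueError("separate mode needs a target loader")
        # This sketch stops at the shorter of the two loaders.
        for src, tgt in zip(source_batches, target_batches):
            yield 1, src  # source batch through BN_S
            yield 2, tgt  # target batch through BN_T
```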
Parameters
model (nn.Module): DSBN-converted model.
train_loader_source (DataLoader): Source domain loader. In mixed mode, pass a mixed loader here.
train_loader_target (DataLoader, optional): Target loader (required for separate mode).
epochs (int, default 1): Number of training epochs.
lr (float, default 0.1): Learning rate.
mixed_batch (bool, default False): If True, expects mixed batches and uses split-half mode (mode=3).
device (str, default "cuda"): Training device.
log_interval (int, default 10): Print the loss every log_interval steps.
Returns
dict containing:
logs (list): Training logs per step.
  Separate mode: (epoch, step, (loss_source, loss_target))
  Mixed mode: (epoch, step, loss)
final_acc (float): Final accuracy, measured on train_loader_source.
state_dict (OrderedDict): Trained model parameters.
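As a usage sketch, both documented log shapes can be consumed by one helper, applied to result["logs"] where result is the dict returned by train_with_dsbn. The helper mean_loss_per_epoch is hypothetical, and averaging loss_source and loss_target into a single number is an arbitrary choice made here for illustration.

```python
from collections import defaultdict
from typing import Dict, List, Tuple, Union

LogEntry = Tuple[int, int, Union[float, Tuple[float, float]]]


def mean_loss_per_epoch(logs: List[LogEntry]) -> Dict[int, float]:
    """Average the logged loss per epoch, accepting both documented log shapes."""
    totals: Dict[int, float] = defaultdict(float)
    counts: Dict[int, int] = defaultdict(int)
    for epoch, _step, loss in logs:
        if isinstance(loss, (tuple, list)):
            # Separate mode: (loss_source, loss_target); average the two domains.
            value = (loss[0] + loss[1]) / 2
        else:
            # Mixed mode: one scalar loss per step.
            value = float(loss)
        totals[epoch] += value
        counts[epoch] += 1
    return {epoch: totals[epoch] / counts[epoch] for epoch in totals}
```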
Examples