ysautoml.data.dsbn

Functional

ysautoml.data.dsbn.convert_and_wrap

ysautoml.data.dsbn.convert_and_wrap(**kwargs)

Convert every BatchNorm2d layer in a model to DSBN2d (Domain-Specific Batch Normalization) and set the initial mode. Each DSBN2d layer maintains two sets of BN statistics: BN_S for the source domain and BN_T for the target (augmented) domain.
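The conversion described above can be sketched in plain PyTorch. This is not ysautoml's implementation: only the class name DSBN2d and the mode numbers come from this page, while the attribute names `bn_s`, `bn_t` and `mode`, the helpers `convert_bn_to_dsbn` and `set_dsbn_mode`, and the choice to initialize both branches from the original BN layer are assumptions made for illustration.

```python
import torch
import torch.nn as nn


class DSBN2d(nn.Module):
    """Sketch of a domain-specific BatchNorm2d with one BN branch per domain.

    mode 1 -> bn_s (source), mode 2 -> bn_t (target/aug),
    mode 3 -> bn_s on the first half of the batch, bn_t on the second half.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__()
        kwargs = dict(eps=eps, momentum=momentum, affine=affine,
                      track_running_stats=track_running_stats)
        self.bn_s = nn.BatchNorm2d(num_features, **kwargs)  # BN_S (hypothetical name)
        self.bn_t = nn.BatchNorm2d(num_features, **kwargs)  # BN_T (hypothetical name)
        self.mode = 1

    def forward(self, x):
        if self.mode == 1:
            return self.bn_s(x)
        if self.mode == 2:
            return self.bn_t(x)
        if self.mode == 3:
            if x.size(0) % 2 != 0:
                raise ValueError("mode 3 (split-half) needs an even batch size")
            half = x.size(0) // 2
            return torch.cat([self.bn_s(x[:half]), self.bn_t(x[half:])], dim=0)
        raise ValueError(f"unknown DSBN mode: {self.mode}")


def convert_bn_to_dsbn(module):
    """Recursively replace every nn.BatchNorm2d child with a DSBN2d (in place)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            dsbn = DSBN2d(child.num_features, child.eps, child.momentum,
                          child.affine, child.track_running_stats)
            # Assumption: both branches start from the original BN's weights and statistics.
            dsbn.bn_s.load_state_dict(child.state_dict())
            dsbn.bn_t.load_state_dict(child.state_dict())
            setattr(module, name, dsbn)
        elif not isinstance(child, DSBN2d):
            convert_bn_to_dsbn(child)
    return module


def set_dsbn_mode(model, mode):
    """Set the same mode on every DSBN2d layer in the model."""
    for m in model.modules():
        if isinstance(m, DSBN2d):
            m.mode = mode
    return model


model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)),
)
model = set_dsbn_mode(convert_bn_to_dsbn(model), mode=3)
out = model(torch.randn(4, 3, 16, 16))  # first 2 samples -> BN_S, last 2 -> BN_T
```

Keeping the two branches as separate BatchNorm2d modules means each domain accumulates its own running mean and variance, while all convolution weights stay shared.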

Parameters

  • model_or_name (str or nn.Module): Model name (factory-built, e.g. "resnet18_cifar") or an existing model instance.

  • dataset (str, default "CIFAR10"): Dataset tag (reserved for model factory use).

  • num_classes (int, default 10): Number of classes.

  • use_aug (bool, default False): Used only when mode is None. If True, the initial mode is 2 (Target BN); if False, it is 1 (Source BN).

  • mode (int or None, default None): DSBN mode. If None, it is inferred from use_aug.

    • 1: use Source BN (BN_S)

    • 2: use Target BN (BN_T)

    • 3: split-half (BN_S for the first half of the batch, BN_T for the second half). The batch size must be even.

  • device (str, default "0"): Device specification. Sets CUDA_VISIBLE_DEVICES to this value and moves the model to that device.

  • export_path (str or None): If given, saves model.state_dict() to this path.
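The rules that pick the initial mode from use_aug and mode can be written as a small pure-Python helper. `resolve_initial_mode` is a hypothetical name, and the assumption that an explicit mode always takes precedence over use_aug is inferred from the parameter descriptions rather than stated by the library.

```python
from typing import Optional


def resolve_initial_mode(use_aug: bool = False, mode: Optional[int] = None) -> int:
    """Hypothetical helper: choose the initial DSBN mode for convert_and_wrap."""
    if mode is None:
        # No explicit mode: use_aug=True selects Target BN (2), otherwise Source BN (1).
        return 2 if use_aug else 1
    # Assumption: an explicit mode overrides use_aug.
    if mode not in (1, 2, 3):
        raise ValueError(f"mode must be 1, 2, 3 or None, got {mode!r}")
    return mode


print(resolve_initial_mode(use_aug=True))  # 2
```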

Returns

  • nn.Module: DSBN-converted model with mode set.


ysautoml.data.dsbn.train_with_dsbn

ysautoml.data.dsbn.train_with_dsbn(**kwargs)

Train a DSBN-converted model using either separate-batch or mixed-batch training.

  • Separate mode (mixed_batch=False): Each iteration uses one source batch (mode=1) and one target batch (mode=2).

  • Mixed mode (mixed_batch=True): Each batch is a concatenation of source and target samples. The model runs with mode=3, applying BN_S to the first half of the batch and BN_T to the second half.
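The two training schemes above can be sketched as one iteration each. This is an illustration, not the body of train_with_dsbn: the helpers `set_mode`, `separate_step` and `mixed_step`, the `ToyDSBN` stand-in layer, the cross-entropy loss, and summing the source and target losses into a single update in separate mode are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def set_mode(model, mode):
    # Assumption: DSBN layers expose a writable integer `mode` attribute.
    for m in model.modules():
        if hasattr(m, "mode"):
            m.mode = mode


def separate_step(model, optimizer, src_batch, tgt_batch):
    """Separate mode: source batch through BN_S (mode 1), target batch through BN_T (mode 2)."""
    (xs, ys), (xt, yt) = src_batch, tgt_batch
    optimizer.zero_grad()
    set_mode(model, 1)
    loss_s = F.cross_entropy(model(xs), ys)
    set_mode(model, 2)
    loss_t = F.cross_entropy(model(xt), yt)
    # Assumption: the two losses are summed into a single optimizer update.
    (loss_s + loss_t).backward()
    optimizer.step()
    return loss_s.item(), loss_t.item()


def mixed_step(model, optimizer, batch):
    """Mixed mode: one [source half | target half] batch with split-half mode 3."""
    x, y = batch
    if x.size(0) % 2 != 0:
        raise ValueError("mixed batches must have an even size")
    optimizer.zero_grad()
    set_mode(model, 3)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


class ToyDSBN(nn.Module):
    """Toy stand-in for a DSBN layer on flat features (one BatchNorm1d per domain)."""

    def __init__(self, n):
        super().__init__()
        self.bn_s = nn.BatchNorm1d(n)
        self.bn_t = nn.BatchNorm1d(n)
        self.mode = 1

    def forward(self, x):
        if self.mode == 1:
            return self.bn_s(x)
        if self.mode == 2:
            return self.bn_t(x)
        half = x.size(0) // 2
        return torch.cat([self.bn_s(x[:half]), self.bn_t(x[half:])])


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), ToyDSBN(8), nn.ReLU(), nn.Linear(8, 3))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
src = (torch.randn(6, 4), torch.randint(0, 3, (6,)))
tgt = (torch.randn(6, 4), torch.randint(0, 3, (6,)))

loss_s, loss_t = separate_step(model, opt, src, tgt)
loss_mix = mixed_step(model, opt, (torch.cat([src[0], tgt[0]]), torch.cat([src[1], tgt[1]])))
```

In both schemes every BN branch sees only its own domain; the difference is whether the two domains share one forward pass (mixed) or use two (separate).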

Parameters

  • model (nn.Module): DSBN-converted model.

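  • train_loader_source (DataLoader): Source domain loader. In mixed mode, pass a loader that yields concatenated source+target batches here (source samples in the first half, target samples in the second).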

  • train_loader_target (DataLoader, optional): Target loader (required for separate mode).

  • epochs (int, default 1): Number of training epochs.

  • lr (float, default 0.1): Learning rate.

  • mixed_batch (bool, default False): If True, expects mixed batches and uses split-half mode.

  • device (str, default "cuda"): Training device.

  • log_interval (int, default 10): Print loss every log_interval steps.
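For mixed mode, train_loader_source must yield batches laid out as [source half | target half]. One way to build such a DataLoader is a paired dataset plus a custom collate_fn. `PairedDataset` and `mixed_collate` are hypothetical helpers, and cycling the target set with `i % len(target)` is just one assumed pairing strategy.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class PairedDataset(Dataset):
    """Hypothetical helper: item i pairs source sample i with target sample i % len(target)."""

    def __init__(self, source, target):
        self.source, self.target = source, target

    def __len__(self):
        return len(self.source)

    def __getitem__(self, i):
        return self.source[i], self.target[i % len(self.target)]


def mixed_collate(pairs):
    # Put all source samples first and all target samples second, matching the
    # split-half layout of mode 3. Every batch therefore has an even size.
    src = [p[0] for p in pairs]
    tgt = [p[1] for p in pairs]
    xs = torch.stack([x for x, _ in src + tgt])
    ys = torch.tensor([int(y) for _, y in src + tgt])
    return xs, ys


source = TensorDataset(torch.randn(10, 3, 8, 8), torch.randint(0, 10, (10,)))
target = TensorDataset(torch.randn(7, 3, 8, 8), torch.randint(0, 10, (7,)))
mixed_loader = DataLoader(PairedDataset(source, target), batch_size=4,
                          shuffle=True, collate_fn=mixed_collate)
x, y = next(iter(mixed_loader))  # 4 source + 4 target samples
```

Note that batch_size here counts pairs, so each mixed batch holds twice that many samples.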

Returns

  • dict containing:

    • logs (list): Training logs per step.

      • Separate mode: (epoch, step, (loss_source, loss_target))

      • Mixed mode: (epoch, step, loss)

    • final_acc (float): Final accuracy, measured on train_loader_source.

    • state_dict (OrderedDict): Trained model parameters.
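Because the entries in logs have different shapes in the two modes, code that consumes them has to handle both. The sketch below averages the logged losses per epoch; `summarize_logs` is a hypothetical helper and assumes the logged losses are plain Python floats.

```python
from statistics import mean


def summarize_logs(logs):
    """Average logged losses per epoch for either log format.

    Separate mode entries: (epoch, step, (loss_source, loss_target))
    Mixed mode entries:    (epoch, step, loss)
    """
    per_epoch = {}
    for epoch, _step, loss in logs:
        per_epoch.setdefault(epoch, []).append(loss)
    summary = {}
    for epoch, losses in per_epoch.items():
        if isinstance(losses[0], (tuple, list)):
            # Separate mode: keep source and target averages apart.
            summary[epoch] = (mean(p[0] for p in losses), mean(p[1] for p in losses))
        else:
            summary[epoch] = mean(losses)
    return summary


print(summarize_logs([(0, 0, 2.0), (0, 10, 1.0)]))  # {0: 1.5}
```

With a result dict returned by train_with_dsbn, this would be called as `summarize_logs(result["logs"])`.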


Examples


Training log
