> For the complete documentation index, see [llms.txt](https://deeplearning4j.konduit.ai/llms.txt). Markdown versions of documentation pages are available by appending `.md` to page URLs; this page is available as [Markdown](https://deeplearning4j.konduit.ai/en-1.0.0-rewrite/deeplearning4j/peft-and-rl.md).

# PEFT & RL Alignment Training

Deeplearning4j 1.0.0-rewrite adds a complete suite of tools for fine-tuning large pretrained models without retraining all parameters, aligning language model outputs using reinforcement learning feedback, training with reduced numerical precision, and curating datasets for domain adaptation.

***

## 1. Overview

The following capability groups are covered in this page:

| Area                   | Key Classes                                           | What It Provides                                                     |
| ---------------------- | ----------------------------------------------------- | -------------------------------------------------------------------- |
| PEFT                   | `PeftModel`, `PeftModelFactory`, `LoraLayer`          | 13 adapter types; freeze base weights, train only adapter parameters |
| RL Alignment           | `GRPOTrainer`, `DPOTrainer`, `PPOTrainer`, and 7 more | Human-preference and reward-signal alignment                         |
| Training Pipelines     | `SFTTrainingPipeline`, `RLAlignmentPipeline`          | End-to-end supervised and RLHF workflows                             |
| Mixed Precision        | `FP8ScaleManager`, `LossScaler`                       | FP8 forward/backward, dynamic loss scaling                           |
| 8-bit Adam             | `Adam8bit`, `Adam8bitUpdater`                         | \~4x optimizer-state memory reduction                                |
| Knowledge Distillation | `DistillationTrainer`                                 | Teacher/student training with KL, attention, and feature losses      |
| Dataset Curation       | `TextDeduplicator`, `SequencePacker`, and 18 more     | Deduplication, contamination removal, curriculum, bin-packing        |
| Transfer Learning      | `TransferLearning`, `TransferLearningHelper`          | Layer freezing and head replacement on existing models               |

Every `SameDiff` model also exposes these convenience methods (parameter types shown for reference):

```java
sd.applyPeft(PeftConfig config);
sd.getTrainableParameters();
sd.printTrainableParameters();
sd.distillFrom(SameDiff teacher, DistillationConfig config);
sd.fitGRPO(GRPOConfig config, MultiDataSetIterator data);
sd.saveAdapters(Path dir);
sd.loadAdapters(Path dir);
```

***

## 2. Maven Dependencies

```xml
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-rewrite</version>
</dependency>
<!-- PEFT and RL alignment trainers -->
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-peft</artifactId>
    <version>1.0.0-rewrite</version>
</dependency>
<!-- 8-bit Adam and FP8 mixed precision -->
<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-cuda-12.x</artifactId>
    <version>1.0.0-rewrite</version>
</dependency>
```

***

## 3. PEFT Methods

### 3.1 PeftType Enum

Every PEFT method is identified by a `PeftType` constant:

| PeftType        | Description                                                             |
| --------------- | ----------------------------------------------------------------------- |
| `LORA`          | Low-Rank Adaptation — low-rank A/B matrices added to linear projections |
| `QLORA`         | LoRA on NF4-quantized base weights                                      |
| `DORA`          | Weight-Decomposed LoRA with learned magnitude vector                    |
| `LOHA`          | Low-rank Hadamard product adaptation                                    |
| `LOKR`          | Low-rank Kronecker product adaptation                                   |
| `IA3`           | Infused Adapter by Inhibiting and Amplifying Inner Activations          |
| `VERA`          | Vector-based Random Matrix Adaptation                                   |
| `PREFIX_TUNING` | Trainable key/value prefix tokens prepended to each layer               |
| `PROMPT_TUNING` | Soft prompt tokens prepended to the input embedding                     |
| `ADAPTER`       | Small bottleneck feed-forward modules inserted between layers           |
| `LOFTQ`         | LoRA initialized to minimize NF4 quantization error                     |
| `ADALORA`       | Adaptive rank allocation across layers                                  |
| `DYLORA`        | Dynamic rank search during training                                     |

### 3.2 PeftModel

`PeftModel` wraps an existing `SameDiff` base model and injects adapter layers without modifying the base graph:

```java
import org.deeplearning4j.peft.PeftModel;
import org.deeplearning4j.peft.config.LoraConfig;
import org.nd4j.autodiff.samediff.SameDiff;
import java.nio.file.Paths;
import java.util.List;

SameDiff baseModel = SameDiff.load(Paths.get("llama-7b.fb"), true);

LoraConfig config = LoraConfig.builder()
    .r(16)
    .loraAlpha(32)
    .loraDropout(0.05)
    .targetModules(List.of("q_proj", "v_proj"))
    .biasMode(BiasMode.NONE)
    .initLoraWeights(InitLoraWeights.KAIMING)
    .build();

PeftModel peft = PeftModelFactory.fromConfig(baseModel, config);

// Print how many parameters are actually trained
peft.printTrainableParameters();
// Output (Llama-7B, r=16 on q_proj and v_proj):
// trainable params: 8,388,608 || all params: 6,746,804,224 || trainable%: 0.1243
```
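The trainable count reported by `printTrainableParameters()` is plain arithmetic: LoRA adds `r × in` parameters for A and `out × r` for B on every targeted projection. The sketch below reproduces it under stated assumptions (Llama-7B shapes: 32 decoder layers, hidden size 4096, 6,738,415,616 base parameters, and an "all params" total that includes the adapters). `LoraParamCount` is an illustrative name, not DL4J API.

```java
/** Illustrative LoRA parameter arithmetic; not part of the DL4J API. */
final class LoraParamCount {

    /** Parameters added by one LoRA adapter on an (in -> out) linear projection. */
    static long loraParams(long in, long out, long r) {
        long a = r * in;   // lora_A: r x in
        long b = out * r;  // lora_B: out x r
        return a + b;
    }

    public static void main(String[] args) {
        long layers = 32, hidden = 4096, r = 16;   // assumed Llama-7B shapes
        long modulesPerLayer = 2;                  // q_proj and v_proj
        long baseParams = 6_738_415_616L;          // assumed Llama-7B base parameter count

        long trainable = layers * modulesPerLayer * loraParams(hidden, hidden, r);
        long all = baseParams + trainable;         // "all params" includes the adapters
        double pct = 100.0 * trainable / all;
        System.out.printf("trainable params: %,d || all params: %,d || trainable%%: %.4f%n",
                trainable, all, pct);
    }
}
```

Doubling `r` or the number of target modules doubles the trainable count.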

Key `PeftModel` capabilities:

| Method                                       | Description                                                           |
| -------------------------------------------- | --------------------------------------------------------------------- |
| `printTrainableParameters()`                 | Logs trainable vs total parameter count and percentage                |
| `getTrainableParameterCount()`               | Returns the count of adapter (non-frozen) parameters                  |
| `addAdapter(String name, PeftConfig config)` | Inject an additional named adapter onto the same base model           |
| `setActiveAdapter(String name)`              | Switch which adapter is active for inference                          |
| `mergeAndUnload()`                           | Merge adapter weights into base weights and return a plain `SameDiff` |
| `saveAdapters(Path dir)`                     | Serialize only adapter weights (small files)                          |
| `loadAdapters(Path dir)`                     | Restore previously saved adapter weights                              |

Base model parameters are frozen automatically when a `PeftModel` is created. Only the adapter parameters injected into the modules named in `targetModules` remain trainable variables in the computation graph.

### 3.3 PeftModelFactory

`PeftModelFactory.fromConfig(SameDiff baseSd, PeftConfig config)` inspects `config.getPeftType()` and dispatches to the correct adapter injection strategy:

```java
// Dispatches on config.getPeftType()
PeftModel loraModel    = PeftModelFactory.fromConfig(baseSd, loraConfig);     // LORA
PeftModel qloraModel   = PeftModelFactory.fromConfig(baseSd, qloraConfig);    // QLORA
PeftModel adapterModel = PeftModelFactory.fromConfig(baseSd, adapterConfig);  // ADAPTER
```

### 3.4 LoraLayer

`LoraLayer` is the core adapter unit for all LoRA variants. It wraps an existing linear op (matrix multiply) with low-rank side-path A and B matrices:

```
output = base_weight @ x + (lora_B @ lora_A @ x) * (alpha / r)
```

Fields tracked per layer:

| Field     | Description                                     |
| --------- | ----------------------------------------------- |
| `r`       | Rank of the factorization                       |
| `alpha`   | Scaling factor; effective scale = alpha / r     |
| `dropout` | Dropout probability applied to the adapter path |

`mergeWeights()` adds `(lora_B @ lora_A) * (alpha / r)` directly into the base weight matrix, eliminating the side path and returning a standard linear op with no inference overhead.
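A dependency-free sketch of the forward pass and of the merge, using plain `double[][]` arrays instead of SameDiff variables (the class and method names are hypothetical). It checks numerically that the unmerged side path and the merged weight produce the same output. B is random here only so the check is non-trivial; LoRA normally initializes B to zero.

```java
import java.util.Random;

/** Plain-Java sketch of a LoRA side path and of weight merging; not DL4J code. */
final class LoraMergeSketch {

    /** y = M x for a row-major matrix M of shape rows x cols. */
    static double[] matVec(double[][] m, double[] x) {
        double[] y = new double[m.length];
        for (int i = 0; i < m.length; i++)
            for (int j = 0; j < x.length; j++)
                y[i] += m[i][j] * x[j];
        return y;
    }

    /** Unmerged forward pass: W x + (B (A x)) * (alpha / r); dropout omitted (inference). */
    static double[] forward(double[][] w, double[][] a, double[][] b, double alpha, double[] x) {
        double scale = alpha / a.length;              // a.length == r
        double[] y = matVec(w, x);
        double[] side = matVec(b, matVec(a, x));      // low-rank path: x -> A x -> B (A x)
        for (int i = 0; i < y.length; i++) y[i] += side[i] * scale;
        return y;
    }

    /** Merged weight W + (B A) * (alpha / r): one matVec then replaces both paths. */
    static double[][] merge(double[][] w, double[][] a, double[][] b, double alpha) {
        int r = a.length;
        double scale = alpha / r;
        double[][] merged = new double[w.length][w[0].length];
        for (int i = 0; i < w.length; i++)
            for (int j = 0; j < w[0].length; j++) {
                double ba = 0;
                for (int k = 0; k < r; k++) ba += b[i][k] * a[k][j];
                merged[i][j] = w[i][j] + ba * scale;
            }
        return merged;
    }

    static double[][] randomMatrix(int rows, int cols, Random rnd) {
        double[][] m = new double[rows][cols];
        for (double[] row : m)
            for (int j = 0; j < cols; j++) row[j] = rnd.nextGaussian();
        return m;
    }

    /** Largest absolute difference between unmerged and merged outputs on a random input. */
    static double mergeError(long seed) {
        Random rnd = new Random(seed);
        int in = 6, out = 5, r = 2;
        double[][] w = randomMatrix(out, in, rnd);   // base weight: out x in
        double[][] a = randomMatrix(r, in, rnd);     // lora_A: r x in
        double[][] b = randomMatrix(out, r, rnd);    // lora_B: out x r
        double[] x = randomMatrix(1, in, rnd)[0];
        double[] unmerged = forward(w, a, b, 4.0, x);
        double[] merged = matVec(merge(w, a, b, 4.0), x);
        double maxDiff = 0;
        for (int i = 0; i < out; i++) maxDiff = Math.max(maxDiff, Math.abs(unmerged[i] - merged[i]));
        return maxDiff;
    }

    public static void main(String[] args) {
        System.out.println("max |unmerged - merged| = " + mergeError(42));
    }
}
```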

### 3.5 LoraAdapterCache

`LoraAdapterCache` provides thread-safe, LRU-evicted in-memory caching of adapter weights for multi-adapter serving scenarios:

```java
import org.deeplearning4j.peft.cache.LoraAdapterCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.Map;

LoraAdapterCache cache = new LoraAdapterCache(10);   // maxAdapters

// Load adapter into cache keyed by (modelId, adapterName)
cache.put("llama-7b", "customer-a-adapter", adapterWeights);

// Retrieve
Map<String, INDArray> weights = cache.get("llama-7b", "customer-a-adapter");
```

The LRU policy evicts the least-recently-used adapter when the cache is full, making it suitable for serving many fine-tuned variants from a single base model without loading all adapters into GPU memory simultaneously.
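The eviction policy can be sketched with the JDK's `LinkedHashMap` in access order. This illustrates LRU semantics under the assumption that adapters are keyed by `(modelId, adapterName)`; it is not the `LoraAdapterCache` source. In practice the value type would be the `Map<String, INDArray>` of adapter weights.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Minimal thread-safe LRU cache keyed by (modelId, adapterName); illustrative only. */
final class LruAdapterCache<V> {
    private final Map<String, V> map;

    LruAdapterCache(int maxAdapters) {
        // accessOrder = true: iteration order runs from least to most recently accessed
        this.map = new LinkedHashMap<String, V>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, V> eldest) {
                return size() > maxAdapters;   // evict the LRU entry once over capacity
            }
        };
    }

    private static String key(String modelId, String adapterName) {
        return modelId + "\u0000" + adapterName;   // NUL separator avoids ambiguous concatenations
    }

    synchronized void put(String modelId, String adapterName, V weights) {
        map.put(key(modelId, adapterName), weights);
    }

    synchronized V get(String modelId, String adapterName) {
        return map.get(key(modelId, adapterName));   // a hit refreshes recency
    }

    synchronized int size() {
        return map.size();
    }

    public static void main(String[] args) {
        LruAdapterCache<String> cache = new LruAdapterCache<>(2);
        cache.put("llama-7b", "customer-a-adapter", "weights-a");
        cache.put("llama-7b", "customer-b-adapter", "weights-b");
        cache.get("llama-7b", "customer-a-adapter");               // b is now least recently used
        cache.put("llama-7b", "customer-c-adapter", "weights-c");  // over capacity: evicts b
        System.out.println(cache.get("llama-7b", "customer-b-adapter"));  // prints null
    }
}
```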

### 3.6 LoftQInitializer

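`LoftQInitializer` produces LoRA A/B initializations that minimize the error introduced by NF4 quantization. Rather than fitting the quantization error once, LoftQ alternates between quantizing and a rank-r fit for `numIter` iterations, so that the dequantized base weight plus the adapter approximates the original weight:

```
A, B = 0
repeat numIter times:
    Q    = nf4_quantize(base_weight - B @ A)
    B, A = rank_r_svd(base_weight - dequantize(Q))
# result: dequantize(Q) + B @ A ≈ base_weight
```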

```java
import org.deeplearning4j.peft.init.LoftQInitializer;

LoftQInitializer init = new LoftQInitializer(4, 5);   // numBits = 4, numIter = 5
init.initialize(loraLayer, baseWeight);
```

Use `initLoraWeights(InitLoraWeights.LOFTQ)` in `LoraConfig` to apply this automatically via `PeftModelFactory`.

***

## 4. LoRA Deep Dive

### 4.1 LoraConfig

```java
import org.deeplearning4j.peft.config.LoraConfig;

LoraConfig config = LoraConfig.builder()
    .r(16)                                              // rank
    .loraAlpha(32)                                      // scaling; effective_scale = alpha/r = 2.0
    .loraDropout(0.05)                                  // dropout on adapter path
    .targetModules(List.of("q_proj", "k_proj",
                           "v_proj", "o_proj"))         // layers to adapt
    .biasMode(BiasMode.NONE)                            // NONE | ALL | LORA_ONLY
    .initLoraWeights(InitLoraWeights.KAIMING)           // GAUSSIAN | KAIMING | LOFTQ
    .build();
```

`biasMode` controls which bias vectors receive gradients:

| BiasMode    | Effect                                                |
| ----------- | ----------------------------------------------------- |
| `NONE`      | No bias vectors are trained (default)                 |
| `ALL`       | All bias vectors (base + adapter layers) are trained  |
| `LORA_ONLY` | Only bias vectors in LoRA-targeted layers are trained |

### 4.2 QLoraConfig

`QLoraConfig` extends `LoraConfig` and adds NF4 quantization of the base weights. The base model is stored in 4-bit NF4 and stays frozen; only the LoRA adapter is trained, with computation carried out in `computeDtype` (BF16 in this example):

```java
import org.deeplearning4j.peft.config.QLoraConfig;
import org.nd4j.linalg.api.buffer.DataType;

QLoraConfig config = QLoraConfig.builder()
    .r(64)
    .loraAlpha(16)
    .loraDropout(0.1)
    .targetModules(List.of("q_proj", "v_proj"))
    .quantBits(4)                   // NF4 quantization
    .doubleQuant(true)              // quantize the quantization constants as well
    .computeDtype(DataType.BFLOAT16)
    .build();
```

### 4.3 AdaLoraConfig — Adaptive Rank Allocation

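AdaLoRA starts every adapter at rank `initR` and, every `deltaT` steps, prunes the least important rank components until the average rank reaches `targetR`. Importance is scored on an SVD-style parameterization of each update and smoothed with an exponential moving average (`betaK`). An orthogonality penalty, weighted by `orthogonalRegCoeff`, keeps the A and B factors close to orthogonal so that pruning a component behaves like dropping a singular value: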

```java
import org.deeplearning4j.peft.config.AdaLoraConfig;

AdaLoraConfig config = AdaLoraConfig.builder()
    .initR(12)                  // starting rank
    .targetR(4)                 // final average rank budget
    .deltaT(200)                // steps between rank updates
    .betaK(0.85)                // EMA coefficient for importance scores
    .totalSteps(10000)
    .orthogonalRegCoeff(0.1)    // coefficient for orthogonal regularization loss
    .targetModules(List.of("q_proj", "v_proj"))
    .build();
```
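A minimal sketch of the importance-driven pruning step. It assumes the AdaLoRA paper's scoring simplified to the sensitivity `|lambda_i * grad_i|` of each rank component, smoothed by an EMA with `betaK`; the paper's uncertainty term and budget schedule are omitted, and every name is illustrative rather than DL4J API. With two target modules at `initR = 12` and `targetR = 4`, the global budget would be 2 × 4 = 8 of the 24 components.

```java
import java.util.Arrays;

/** Illustrative AdaLoRA-style importance smoothing and budget pruning; not DL4J code. */
final class AdaLoraPruneSketch {

    /** First-order sensitivity of each rank component: |lambda_i * grad_i|. */
    static double[] sensitivity(double[] lambda, double[] grad) {
        double[] s = new double[lambda.length];
        for (int i = 0; i < s.length; i++) s[i] = Math.abs(lambda[i] * grad[i]);
        return s;
    }

    /** In-place EMA: ema = betaK * ema + (1 - betaK) * current. */
    static void updateEma(double[] ema, double[] current, double betaK) {
        for (int i = 0; i < ema.length; i++) ema[i] = betaK * ema[i] + (1 - betaK) * current[i];
    }

    /** Keeps the `budget` highest-scoring components across all adapters (ties may keep extra). */
    static boolean[] topBudgetMask(double[] ema, int budget) {
        boolean[] keep = new boolean[ema.length];
        if (budget <= 0) return keep;
        if (budget >= ema.length) {
            Arrays.fill(keep, true);
            return keep;
        }
        double[] sorted = ema.clone();
        Arrays.sort(sorted);                                // ascending
        double threshold = sorted[sorted.length - budget];  // budget-th largest score
        for (int i = 0; i < ema.length; i++) keep[i] = ema[i] >= threshold;
        return keep;
    }

    public static void main(String[] args) {
        double[] ema = new double[4];
        updateEma(ema, sensitivity(new double[]{1, 2, 1, 3}, new double[]{1, 2, 2, 1}), 0.85);
        System.out.println(Arrays.toString(topBudgetMask(ema, 2)));  // [false, true, false, true]
    }
}
```

Components whose mask entry is `false` would have their singular value set to zero until the next update every `deltaT` steps.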

### 4.4 DyLoraConfig — Dynamic Rank

DyLoRA trains a single adapter that can serve any rank in `[minR, maxR]` by randomly sampling a rank each forward pass. This produces an adapter where you pick the serving rank at inference time based on latency vs. quality trade-offs:

```java
import org.deeplearning4j.peft.config.DyLoraConfig;

DyLoraConfig config = DyLoraConfig.builder()
    .minR(1)
    .maxR(16)
    .loraAlpha(16)
    .targetModules(List.of("q_proj", "v_proj"))
    .build();
```

### 4.5 Other Adapter Configs

| Config Class         | Key Parameters                                                                                          |
| -------------------- | ------------------------------------------------------------------------------------------------------- |
| `DoraConfig`         | `useDoraDecomposition`, `magnitudeLr` (learning rate for the magnitude vector)                          |
| `LohaConfig`         | Hadamard rank `r`, `lohaAlpha`, `targetModules`                                                         |
| `LokrConfig`         | Kronecker factor sizes, `alpha`, `targetModules`                                                        |
| `IA3Config`          | `targetModulesForFeedforward`, scales feedforward activations only                                      |
| `VeraConfig`         | `projectionRank`, per-layer scaling vectors; shared random projection matrices are not stored per-layer |
| `LoftQConfig`        | `numBits`, `numIter` for quantization error minimization                                                |
| `PrefixTuningConfig` | `numVirtualTokens`, `encoderHiddenSize`                                                                 |
| `PromptTuningConfig` | `numVirtualTokens`, `promptTuningInit` (RANDOM or TEXT)                                                 |
| `AdapterConfig`      | `bottleneckDim`, `nonLinearity`, placement (AFTER\_ATTN, AFTER\_FF, BOTH)                               |

### 4.6 VeRA — Vector-based Random Matrix Adaptation

VeRA (ADR 0077) shares a single pair of frozen random projection matrices across all targeted layers. Only per-layer learned scaling vectors `d` and `b` are trained, making it dramatically more parameter-efficient than standard LoRA.

The update formula per layer is:

```
delta_W = lambda * diag(d) @ B_shared @ diag(b) @ A_shared
```

where `A_shared` (shape `r × in`) and `B_shared` (shape `out × r`) are the same frozen random matrices for every layer, and `r` is `projectionRank`. Only `d` (one entry per output row) and `b` (one entry per rank dimension) are trainable.

```java
import org.nd4j.autodiff.samediff.config.VeraConfig;

VeraConfig veraConfig = VeraConfig.builder()
    .projectionRank(64)               // rank of the shared random matrices
    .targetModules(List.of("q_proj", "v_proj", "k_proj", "o_proj"))
    .loraDropout(0.05)
    .build();

PeftModel peft = PeftModelFactory.fromConfig(baseModel, veraConfig);
peft.printTrainableParameters();
// Trainable params are only the d and b vectors: out_features + projectionRank per adapted module
```

VeRA requires orders of magnitude fewer trainable parameters than LoRA at the same rank, making it well suited for scenarios where adapter storage or communication is a bottleneck.
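To make the savings concrete, the sketch below evaluates the VeRA update formula with plain arrays and compares per-module trainable counts for a 4096 × 4096 projection at rank 64. Shapes follow the formula above (`d` has one entry per output row, `b` one per rank dimension). `VeraSketch` and its methods are hypothetical names, not DL4J API.

```java
/** Illustrative VeRA update and parameter counts; not DL4J code. */
final class VeraSketch {

    /**
     * delta_W = lambda * diag(d) @ B_shared @ diag(b) @ A_shared
     * aShared: r x in, bShared: out x r (frozen, shared by all layers)
     * d: length out, b: length r (trained per layer)
     */
    static double[][] delta(double[][] aShared, double[][] bShared,
                            double[] d, double[] b, double lambda) {
        int out = bShared.length, r = aShared.length, in = aShared[0].length;
        double[][] delta = new double[out][in];
        for (int i = 0; i < out; i++)
            for (int j = 0; j < in; j++) {
                double s = 0;
                for (int k = 0; k < r; k++) s += bShared[i][k] * b[k] * aShared[k][j];
                delta[i][j] = lambda * d[i] * s;
            }
        return delta;
    }

    /** LoRA trains both A (r x in) and B (out x r) for each adapted projection. */
    static long loraTrainable(long in, long out, long r) { return r * (in + out); }

    /** VeRA trains only d (out) and b (r) for each adapted projection. */
    static long veraTrainable(long out, long r) { return out + r; }

    public static void main(String[] args) {
        long in = 4096, out = 4096, r = 64;
        System.out.println("LoRA per module: " + loraTrainable(in, out, r));  // 524288
        System.out.println("VeRA per module: " + veraTrainable(out, r));      // 4160
    }
}
```

At this shape the ratio is roughly 126×, about two orders of magnitude.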

### 4.7 DyLoRA — Dynamic Rank Training

DyLoRA (ADR 0077) trains a single adapter over the rank range `[minRank, r]`. On each training step a rank is sampled uniformly from that range and only the leading components up to that rank are used. After training, the same adapter checkpoint can be served at any rank in the range, letting you trade quality against latency at deployment time.

```java
import org.nd4j.autodiff.samediff.config.DyLoraConfig;

DyLoraConfig dyloraConfig = DyLoraConfig.builder()
    .r(16)                            // maximum rank (upper bound of the range)
    .minRank(1)                       // minimum rank to sample during training
    .loraAlpha(16)
    .loraDropout(0.05)
    .targetModules(List.of("q_proj", "v_proj"))
    .build();

PeftModel peft = PeftModelFactory.fromConfig(baseModel, dyloraConfig);
// Train normally — rank is sampled each step
peft.fit(trainIterator, numEpochs);

// At inference, truncate to the rank that fits your latency budget
// (e.g., rank 4: keep only the first 4 rows of A and the first 4 columns of B)
```
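Rank sampling and truncation can be sketched in plain Java, assuming DyLoRA's convention of keeping the leading rows of A and the leading columns of B. `DyLoraSketch` is illustrative, and the scale for a truncated rank is left as a parameter because this page does not specify how DL4J rescales it.

```java
import java.util.Arrays;
import java.util.Random;

/** Illustrative DyLoRA-style rank sampling and truncation; not DL4J code. */
final class DyLoraSketch {

    /** Samples a rank uniformly from [minRank, maxRank], inclusive. */
    static int sampleRank(Random rnd, int minRank, int maxRank) {
        return minRank + rnd.nextInt(maxRank - minRank + 1);
    }

    /**
     * Truncated update using only the first `rank` rows of A (r x in) and the
     * first `rank` columns of B (out x r), multiplied by the given scale.
     */
    static double[][] truncatedDelta(double[][] a, double[][] b, int rank, double scale) {
        int out = b.length, in = a[0].length;
        double[][] delta = new double[out][in];
        for (int i = 0; i < out; i++)
            for (int j = 0; j < in; j++) {
                double s = 0;
                for (int k = 0; k < rank; k++) s += b[i][k] * a[k][j];
                delta[i][j] = s * scale;
            }
        return delta;
    }

    public static void main(String[] args) {
        Random rnd = new Random(0);
        int trainRank = sampleRank(rnd, 1, 16);   // one draw per training step
        System.out.println("rank sampled for this step: " + trainRank);

        double[][] a = {{1, 0}, {0, 1}};          // r = 2, in = 2
        double[][] b = {{1, 2}, {3, 4}};          // out = 2, r = 2
        System.out.println(Arrays.deepToString(truncatedDelta(a, b, 1, 1.0)));  // [[1.0, 0.0], [3.0, 0.0]]
    }
}
```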

### 4.8 LoRA+ — Differential Learning Rates

LoRA+ (ADR 0077) applies a higher learning rate to the B matrix than to the A matrix. B sits on the output side of the adapter and benefits from faster updates, while keeping A at the base learning rate keeps training stable. Set `loraLrRatioB` in `LoraConfig`; `PeftModelFactory` then registers the per-variable learning-rate multiplier for the B matrices, which `TrainingConfig` applies during training:

```java
import org.nd4j.autodiff.samediff.config.LoraConfig;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.linalg.learning.config.Adam;

LoraConfig loraPlusConfig = LoraConfig.builder()
    .r(16)
    .loraAlpha(32)
    .loraDropout(0.05)
    .targetModules(List.of("q_proj", "v_proj"))
    .loraLrRatioB(16.0)    // B matrix gets 16x the base learning rate
    .build();

PeftModel peft = PeftModelFactory.fromConfig(baseModel, loraPlusConfig);

// TrainingConfig picks up per-variable LR multipliers injected during adapter creation
TrainingConfig trainConfig = TrainingConfig.builder()
    .updater(new Adam(2e-5))           // base LR for A matrices
    // learningRateMultipliers for B variables are set automatically by PeftModelFactory
    .dataSetFeatureMapping("input_ids")
    .dataSetLabelMapping("labels")
    .build();
```
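The effect of `loraLrRatioB` on a single step can be sketched with plain SGD. DL4J would apply the multiplier inside its configured updater (Adam above), so this shows only the learning-rate split; the names are hypothetical.

```java
import java.util.Arrays;

/** Illustrative LoRA+ step with plain SGD; not DL4J's updater. */
final class LoraPlusSketch {

    static void sgdStep(double[] params, double[] grads, double lr) {
        for (int i = 0; i < params.length; i++) params[i] -= lr * grads[i];
    }

    /** A is updated with the base learning rate, B with baseLr * lrRatioB. */
    static void loraPlusStep(double[] a, double[] gradA, double[] b, double[] gradB,
                             double baseLr, double lrRatioB) {
        sgdStep(a, gradA, baseLr);
        sgdStep(b, gradB, baseLr * lrRatioB);
    }

    public static void main(String[] args) {
        double[] a = {1.0}, b = {1.0};
        loraPlusStep(a, new double[]{1.0}, b, new double[]{1.0}, 2e-5, 16.0);
        System.out.println("A: " + Arrays.toString(a) + "  B: " + Arrays.toString(b));
    }
}
```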

### 4.9 PiSSA and OLoRA Initialization

Both PiSSA and OLoRA (ADR 0077) produce better starting points for LoRA adapters than random initialization by exploiting the structure of the base weight matrix.

**PiSSA** (Principal Singular-value and Singular-vector Adaptation) runs a truncated SVD on the base weight and initializes A and B from the top-r singular triplets. The base weight is then replaced by its residual `W_residual = W - B @ A`, so `W_residual + B @ A` equals the original `W` and the model's outputs are unchanged when training starts.

**OLoRA** performs the same residual correction but uses QR decomposition instead of SVD, which is cheaper for wide matrices.

Select either via `initLoraWeights`:

```java
// PiSSA initialization (SVD-based)
LoraConfig pissaConfig = LoraConfig.builder()
    .r(16)
    .loraAlpha(16)
    .targetModules(List.of("q_proj", "v_proj"))
    .initLoraWeights("pissa")    // triggers SVD-based init in LoraLayer
    .build();

// OLoRA initialization (QR-based, faster for wide matrices)
LoraConfig oloraConfig = LoraConfig.builder()
    .r(16)
    .loraAlpha(16)
    .targetModules(List.of("q_proj", "v_proj"))
    .initLoraWeights("olora")    // triggers QR-based init in LoraLayer
    .build();

PeftModel peft = PeftModelFactory.fromConfig(baseModel, pissaConfig);
```

Because initialization overwrites the base weight with its residual, the base checkpoint on its own no longer reproduces the original model. Keep the residual base weight and the adapter together, or call `mergeAndUnload()` before saving to obtain a standalone model file.
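The residual identity can be checked with a small sketch. It swaps in power iteration with deflation for the truncated SVD so the example needs no linear-algebra library, and splits each singular value as `sqrt(sigma)` between A and B, following the PiSSA paper. The class and method names are hypothetical; this is not the `LoraLayer` implementation.

```java
import java.util.Random;

/** Illustrative PiSSA-style init: power iteration with deflation stands in for truncated SVD. */
final class PissaSketch {

    /** a: r x in, b: out x r, residual: out x in, with residual + b @ a == w up to rounding. */
    static final class Init {
        final double[][] a, b, residual;
        Init(double[][] a, double[][] b, double[][] residual) {
            this.a = a; this.b = b; this.residual = residual;
        }
    }

    static Init init(double[][] w, int r, int iters, long seed) {
        int out = w.length, in = w[0].length;
        double[][] res = new double[out][];
        for (int i = 0; i < out; i++) res[i] = w[i].clone();
        double[][] a = new double[r][in];
        double[][] b = new double[out][r];
        Random rnd = new Random(seed);
        for (int k = 0; k < r; k++) {
            // Power iteration on R^T R gives the top right singular vector v of the residual R.
            double[] v = new double[in];
            for (int j = 0; j < in; j++) v[j] = rnd.nextGaussian();
            scale(v, 1.0 / norm(v));
            for (int t = 0; t < iters; t++) {
                double[] next = mulT(res, mul(res, v));   // R^T (R v)
                double n = norm(next);
                if (n == 0) break;
                v = next;
                scale(v, 1.0 / n);
            }
            double[] u = mul(res, v);                     // R v = sigma * unit left vector
            double sigma = norm(u);
            if (sigma == 0) break;                        // residual exhausted; rest stays zero
            double s = Math.sqrt(sigma);
            for (int j = 0; j < in; j++) a[k][j] = s * v[j];            // row k of A: sqrt(sigma) v^T
            for (int i = 0; i < out; i++) b[i][k] = s * u[i] / sigma;   // column k of B: sqrt(sigma) u
            for (int i = 0; i < out; i++)                               // deflate: R -= B[:,k] A[k,:]
                for (int j = 0; j < in; j++) res[i][j] -= b[i][k] * a[k][j];
        }
        return new Init(a, b, res);
    }

    static double[] mul(double[][] m, double[] x) {    // m x
        double[] y = new double[m.length];
        for (int i = 0; i < m.length; i++)
            for (int j = 0; j < x.length; j++) y[i] += m[i][j] * x[j];
        return y;
    }

    static double[] mulT(double[][] m, double[] y) {   // m^T y
        double[] x = new double[m[0].length];
        for (int i = 0; i < m.length; i++)
            for (int j = 0; j < x.length; j++) x[j] += m[i][j] * y[i];
        return x;
    }

    static double norm(double[] v) {
        double s = 0;
        for (double e : v) s += e * e;
        return Math.sqrt(s);
    }

    static void scale(double[] v, double f) {
        for (int i = 0; i < v.length; i++) v[i] *= f;
    }

    /** Max |residual + B A - W|; should be ~0 by construction of the residual. */
    static double reconstructionError(double[][] w, Init init) {
        double max = 0;
        for (int i = 0; i < w.length; i++)
            for (int j = 0; j < w[0].length; j++) {
                double ba = 0;
                for (int k = 0; k < init.a.length; k++) ba += init.b[i][k] * init.a[k][j];
                max = Math.max(max, Math.abs(init.residual[i][j] + ba - w[i][j]));
            }
        return max;
    }

    public static void main(String[] args) {
        double[][] w = {{4, 2, 0}, {2, 1, 0}};   // rank-1 matrix
        System.out.println("reconstruction error: " + reconstructionError(w, init(w, 1, 20, 0L)));
    }
}
```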

### 4.10 BitFit — Bias-Only Fine-Tuning

BitFit (ADR 0077) is the simplest PEFT method: every parameter in the base model is frozen except the bias vectors. It is implemented as a method on `PeftModel` rather than a separate config:

```java
// Freeze everything, then unfreeze only bias vectors
PeftModel peft = PeftModelFactory.fromConfig(baseModel,
    LoraConfig.builder().r(0).targetModules(List.of()).build());  // no-op LoRA
peft.applyBitFit();  // unfreezes all bias-named variables

peft.printTrainableParameters();
// trainable params: only the bias vectors; the count depends on the architecture
// (LLaMA-style models have no biases in their linear projections)
```

BitFit is appropriate for small datasets where larger adapters overfit, and for establishing a low-cost baseline before committing to a full PEFT experiment.

### 4.11 Multi-Adapter Serving

Multiple adapters may be attached to the same base model simultaneously. Switch between them without reloading the base weights:

```java
PeftModel peft = PeftModelFactory.fromConfig(baseModel, loraConfigA);

// Attach a second adapter while keeping the first
peft.addAdapter("task-b", loraConfigB);

// Switch to task-b for inference
peft.setActiveAdapter("task-b");
INDArray output = peft.output(input);

// Switch back to the first adapter ("task-a" is assumed to be its registered name)
peft.setActiveAdapter("task-a");
```

`LoraAdapterCache` manages hot-swapping in multi-tenant serving: the cache holds pre-loaded adapter weight maps and swaps them into the model graph on demand with LRU eviction.

### 4.12 Merge and Unload

After training, merge the adapter into the base weights to get a standard model with no inference overhead:

```java
// Merge adapter weights into base weights, return plain SameDiff
SameDiff mergedModel = peft.mergeAndUnload();

// Save as a normal model — no adapter overhead at inference
mergedModel.save(Paths.get("merged-model.fb"), true);
```

Under the hood, `mergeAndUnload()` calls `LoraLayer.mergeWeights()` on every adapted layer, which adds `(B @ A) * (alpha / r)` to the base weight in place, then removes the adapter variables from the graph.

***

## 5. RL Alignment Trainers

All RL alignment trainers operate on a `SameDiff` policy model and are configured with builder-pattern config objects. They implement a common interface that exposes a `train(MultiDataSetIterator)` method.

### 5.1 GRPOTrainer — Group Relative Policy Optimization

For each prompt, GRPO generates `groupSize` completions, scores them with a reward function, computes z-score normalized advantages within the group, and optimizes a clipped surrogate objective:

```
L = -min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv) + kl_coeff * KL(pi || ref)
```

where `ratio = pi(token) / pi_old(token)`.

SameDiff placeholders used by GRPOTrainer:

| Placeholder           | Shape          | Description                                       |
| --------------------- | -------------- | ------------------------------------------------- |
| `_grpo_tokens`        | `[batch, seq]` | Token IDs for generated completions               |
| `_grpo_old_log_probs` | `[batch, seq]` | Log probabilities from the old policy             |
| `_grpo_advantages`    | `[batch]`      | Z-score normalized group advantages               |
| `_grpo_ref_log_probs` | `[batch, seq]` | Log probabilities from the frozen reference model |

```java
import org.deeplearning4j.rl.GRPOTrainer;
import org.deeplearning4j.rl.config.GRPOConfig;

GRPOConfig grpoConfig = GRPOConfig.builder()
    .groupSize(8)            // completions per prompt
    .clipEpsilon(0.2)        // ratio clipping threshold
    .klCoeff(0.04)           // KL penalty coefficient
    .maxNewTokens(256)
    .temperature(0.9)
    .build();

GRPOTrainer trainer = new GRPOTrainer(peftModel, refModel, rewardFn, grpoConfig);
trainer.train(promptIterator);
```
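The two numerical pieces of GRPO, group-relative advantages and the clipped per-token surrogate, can be sketched as plain functions. This assumes a population standard deviation and a small `eps` in the normalization (implementations differ on both) and omits the KL term; the names are hypothetical, not `GRPOTrainer` internals.

```java
import java.util.Arrays;

/** Illustrative GRPO building blocks; not DL4J's GRPOTrainer. */
final class GrpoSketch {

    /** Z-score rewards within one prompt's group: (r - mean) / (std + eps), population std. */
    static double[] groupAdvantages(double[] rewards, double eps) {
        double mean = 0;
        for (double r : rewards) mean += r;
        mean /= rewards.length;
        double var = 0;
        for (double r : rewards) var += (r - mean) * (r - mean);
        double std = Math.sqrt(var / rewards.length);
        double[] adv = new double[rewards.length];
        for (int i = 0; i < rewards.length; i++) adv[i] = (rewards[i] - mean) / (std + eps);
        return adv;
    }

    /** Per-token clipped surrogate without the KL term: -min(ratio * adv, clip(ratio) * adv). */
    static double clippedLoss(double logProb, double oldLogProb, double adv, double clipEpsilon) {
        double ratio = Math.exp(logProb - oldLogProb);   // pi(token) / pi_old(token)
        double clipped = Math.max(1 - clipEpsilon, Math.min(1 + clipEpsilon, ratio));
        return -Math.min(ratio * adv, clipped * adv);
    }

    public static void main(String[] args) {
        double[] adv = groupAdvantages(new double[]{1, 0, 1, 0}, 1e-8);   // two good, two bad
        System.out.println(Arrays.toString(adv));
        System.out.println(clippedLoss(Math.log(1.5), 0.0, 1.0, 0.2));    // ratio clipped to 1.2
    }
}
```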

### 5.2 DPOTrainer — Direct Preference Optimization

DPO eliminates the need for a separate reward model by directly optimizing on preference pairs (chosen vs. rejected):

```
L = -log(sigmoid(beta * (log_pi_chosen - log_pi_ref_chosen
                        - log_pi_rejected + log_pi_ref_rejected)))
```

```java
import org.deeplearning4j.rl.DPOTrainer;
import org.deeplearning4j.rl.config.DPOConfig;

DPOConfig dpoConfig = DPOConfig.builder()
    .beta(0.1)               // temperature for preference scaling
    .labelSmoothing(0.0)
    .lossType(DPOLossType.SIGMOID)
    .build();

DPOTrainer trainer = new DPOTrainer(peftModel, refModel, dpoConfig);
trainer.train(preferenceDataIterator);  // iterator yields (prompt, chosen, rejected) triples
```
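For a single preference pair the DPO loss reduces to a few lines once the four sequence log-probabilities are known. A sketch with `labelSmoothing = 0` and the sigmoid loss; the names are hypothetical, not `DPOTrainer` API.

```java
/** Illustrative DPO loss for one preference pair; not DL4J's DPOTrainer. */
final class DpoLossSketch {

    /** Numerically stable log(sigmoid(x)). */
    static double logSigmoid(double x) {
        return x >= 0 ? -Math.log1p(Math.exp(-x)) : x - Math.log1p(Math.exp(x));
    }

    /** Inputs are summed log-probabilities of each full completion under policy and reference. */
    static double loss(double logPiChosen, double logRefChosen,
                       double logPiRejected, double logRefRejected, double beta) {
        double margin = (logPiChosen - logRefChosen) - (logPiRejected - logRefRejected);
        return -logSigmoid(beta * margin);
    }

    public static void main(String[] args) {
        // The policy has moved toward the chosen completion relative to the reference
        System.out.println(loss(-10.0, -12.0, -15.0, -14.0, 0.1));
    }
}
```

When policy and reference agree the margin is zero and the loss is `log 2`; it falls as the policy favors the chosen completion more than the reference does.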

### 5.3 PPOTrainer — Proximal Policy Optimization

PPO uses a value network to estimate baselines and clips the policy update ratio:

```java
import org.deeplearning4j.rl.PPOTrainer;
import org.deeplearning4j.rl.config.PPOConfig;

PPOConfig ppoConfig = PPOConfig.builder()
    .clipEpsilon(0.2)
    .valueClipEpsilon(0.2)
    .entropyCoeff(0.01)
    .valueCoeff(0.5)
    .numEpochs(4)            // PPO epochs per rollout
    .miniBatchSize(8)
    .build();

PPOTrainer trainer = new PPOTrainer(policyModel, valueModel, rewardFn, ppoConfig);
trainer.train(promptIterator);
```

### 5.4 DAPOTrainer — Decoupled Clip and Dynamic Sampling Policy Optimization

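DAPO decouples the lower and upper clipping bounds, clipping the probability ratio to `[1 - clipEpsilonLow, 1 + clipEpsilonHigh]`, and uses dynamic sampling to focus the rollout budget on informative prompts, filtering out groups in which every completion receives the same reward (and therefore zero advantage):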

```java
import org.deeplearning4j.rl.DAPOTrainer;

DAPOConfig dapoConfig = DAPOConfig.builder()
    .clipEpsilonHigh(0.3)    // upper bound: ratio is clipped at 1 + 0.3
    .clipEpsilonLow(0.1)     // lower bound: ratio is clipped at 1 - 0.1
    .dynamicSamplingTemp(0.8)
    .groupSize(8)
    .build();

DAPOTrainer trainer = new DAPOTrainer(peftModel, refModel, rewardFn, dapoConfig);
```
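A sketch of the two DAPO ingredients as formulated in the DAPO paper: an asymmetric clip range `[1 - epsLow, 1 + epsHigh]` applied to every token, and a filter that drops prompt groups whose rewards are all equal. The names are hypothetical; `dynamicSamplingTemp` has no counterpart here because its exact role is not described on this page.

```java
/** Illustrative DAPO ingredients; not DL4J's DAPOTrainer. */
final class DapoSketch {

    /** Per-token loss: -min(ratio * adv, clip(ratio, 1 - epsLow, 1 + epsHigh) * adv). */
    static double clippedLoss(double ratio, double adv, double epsLow, double epsHigh) {
        double clipped = Math.max(1 - epsLow, Math.min(1 + epsHigh, ratio));
        return -Math.min(ratio * adv, clipped * adv);
    }

    /** Dynamic-sampling filter: a group whose rewards are all equal has zero advantage. */
    static boolean keepGroup(double[] rewards) {
        for (double r : rewards)
            if (r != rewards[0]) return true;
        return false;
    }

    public static void main(String[] args) {
        System.out.println(clippedLoss(1.5, 1.0, 0.1, 0.3));        // upper bound binds
        System.out.println(keepGroup(new double[]{1, 1, 1, 1}));    // false
    }
}
```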

### 5.5 DrGRPOTrainer — Variance-Reduced GRPO

DrGRPO adds advantage whitening (zero mean, unit variance across the full batch rather than per-group) to reduce gradient variance:

```java
import org.deeplearning4j.rl.DrGRPOTrainer;

DrGRPOConfig config = DrGRPOConfig.builder()
    .groupSize(8)
    .clipEpsilon(0.2)
    .klCoeff(0.02)
    .whitenAdvantages(true)      // global whitening, not per-group z-score
    .build();

DrGRPOTrainer trainer = new DrGRPOTrainer(peftModel, refModel, rewardFn, config);
```

### 5.6 KTOTrainer — Kahneman-Tversky Optimization

KTO applies prospect theory loss functions that model asymmetric human preferences (losses are felt more strongly than gains):

```java
import org.deeplearning4j.rl.KTOTrainer;

KTOConfig config = KTOConfig.builder()
    .beta(0.1)
    .desirabilityThreshold(0.5)
    .undesirabilityWeight(1.0)
    .desirabilityWeight(1.0)
    .build();

KTOTrainer trainer = new KTOTrainer(peftModel, refModel, config);
```

### 5.7 ORPOTrainer — Odds Ratio Preference Optimization

ORPO eliminates the reference model by adding an odds-ratio penalty directly to the SFT cross-entropy loss:

```java
import org.deeplearning4j.rl.ORPOTrainer;

ORPOConfig config = ORPOConfig.builder()
    .lambda(0.1)             // weight of the OR penalty term
    .build();

ORPOTrainer trainer = new ORPOTrainer(peftModel, config);
// No reference model needed
trainer.train(preferenceIterator);
```
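ORPO's penalty compares the odds `p / (1 - p)` of the chosen and rejected completions. A minimal sketch, assuming the length-averaged log-likelihood used in the ORPO paper and a precomputed SFT loss; the names are hypothetical, not `ORPOTrainer` API.

```java
/** Illustrative ORPO loss for one preference pair; not DL4J's ORPOTrainer. */
final class OrpoSketch {

    /** Numerically stable log(sigmoid(x)). */
    static double logSigmoid(double x) {
        return x >= 0 ? -Math.log1p(Math.exp(-x)) : x - Math.log1p(Math.exp(x));
    }

    /** log(p / (1 - p)) computed from log p; requires log p < 0. */
    static double logOdds(double logP) {
        return logP - Math.log1p(-Math.exp(logP));
    }

    /**
     * sftLoss: cross-entropy on the chosen completion.
     * avgLogP*: length-averaged log-likelihood of each completion.
     */
    static double loss(double sftLoss, double avgLogPChosen, double avgLogPRejected, double lambda) {
        double orPenalty = -logSigmoid(logOdds(avgLogPChosen) - logOdds(avgLogPRejected));
        return sftLoss + lambda * orPenalty;
    }

    public static void main(String[] args) {
        System.out.println(loss(1.2, -0.4, -1.1, 0.1));
    }
}
```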

### 5.8 SimPOTrainer — Simple Preference Optimization

SimPO removes the reference model and uses length-normalized rewards to avoid verbosity bias:

```java
import org.deeplearning4j.rl.SimPOTrainer;

SimPOConfig config = SimPOConfig.builder()
    .beta(2.5)
    .gamma(1.4)              // reward margin threshold
    .lengthNormalize(true)
    .build();

SimPOTrainer trainer = new SimPOTrainer(peftModel, config);
```
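SimPO's implicit reward is the average log-probability per token scaled by `beta`, and `gamma` is the target margin between chosen and rejected rewards. A sketch under those assumptions (taken from the SimPO paper), with hypothetical names:

```java
/** Illustrative SimPO loss for one preference pair; not DL4J's SimPOTrainer. */
final class SimpoSketch {

    /** Numerically stable log(sigmoid(x)). */
    static double logSigmoid(double x) {
        return x >= 0 ? -Math.log1p(Math.exp(-x)) : x - Math.log1p(Math.exp(x));
    }

    /** Length-normalized implicit reward: (beta / |y|) * log pi(y | x). */
    static double reward(double sumLogP, int length, double beta) {
        return beta * sumLogP / length;
    }

    /** -log sigmoid(r_chosen - r_rejected - gamma); no reference model is involved. */
    static double loss(double sumLogPChosen, int lenChosen,
                       double sumLogPRejected, int lenRejected, double beta, double gamma) {
        double margin = reward(sumLogPChosen, lenChosen, beta)
                      - reward(sumLogPRejected, lenRejected, beta);
        return -logSigmoid(margin - gamma);
    }

    public static void main(String[] args) {
        // A longer rejected answer is not favored just for being long: rewards are per token
        System.out.println(loss(-10.0, 10, -40.0, 20, 2.5, 1.4));
    }
}
```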

### 5.9 GSPOTrainer — Grouped Sampling Policy Optimization

GSPO uses sampling-based grouping of completions and optimizes group-level preference signals:

```java
import org.deeplearning4j.rl.GSPOTrainer;

GSPOConfig config = GSPOConfig.builder()
    .groupSize(4)
    .samplingTemperature(1.0)
    .klCoeff(0.01)
    .build();
```

### 5.10 VlmGRPOTrainer — GRPO for Vision-Language Models

`VlmGRPOTrainer` extends GRPO for multimodal models. The reward function receives both the generated text and the conditioning image:

```java
import org.deeplearning4j.rl.VlmGRPOTrainer;

VlmGRPOConfig config = VlmGRPOConfig.builder()
    .groupSize(4)
    .clipEpsilon(0.2)
    .imageRewardCoeff(0.5)       // weight for image-conditioned reward
    .textRewardCoeff(0.5)
    .build();

VlmGRPOTrainer trainer = new VlmGRPOTrainer(vlmPeftModel, refModel,
                                             imageTextRewardFn, config);
trainer.train(visionLanguageIterator);
```

### 5.11 RewardModelTrainer

`RewardModelTrainer` trains a scalar reward model from preference data. It supports three reward function implementations:

| Class                     | Description                                                     |
| ------------------------- | --------------------------------------------------------------- |
| `CompositeRewardFunction` | Weighted sum of multiple reward signals                         |
| `RuleBasedRewardFunction` | Hand-crafted heuristics (format compliance, length penalties)   |
| `SameDiffRewardFunction`  | Trainable neural reward model implemented as a `SameDiff` graph |

```java
import org.deeplearning4j.rl.RewardModelTrainer;

CompositeRewardFunction rewardFn = new CompositeRewardFunction()
    .add(new RuleBasedRewardFunction()
             .penalizeLength(maxLen, penalty)
             .requireJsonFormat(0.5), 0.3)                  // weight 0.3
    .add(new SameDiffRewardFunction(rewardModel), 0.7);  // weight 0.7

RewardModelTrainer rmTrainer = new RewardModelTrainer(rewardFn, rmConfig);
rmTrainer.train(preferenceIterator);
```

***

## 6. Training Pipelines

### 6.1 SFTTrainingPipeline

`SFTTrainingPipeline` handles the full supervised fine-tuning loop: data loading, loss computation, gradient accumulation, optimizer stepping, checkpointing, and optional evaluation:

```java
import org.deeplearning4j.training.SFTTrainingPipeline;
import org.deeplearning4j.training.config.SFTConfig;

SFTConfig sftConfig = SFTConfig.builder()
    .maxSeqLength(2048)
    .batchSize(4)
    .gradientAccumulationSteps(8)    // effective batch = 4 * 8 = 32
    .numEpochs(3)
    .learningRate(2e-4)
    .warmupSteps(100)
    .saveEveryNSteps(500)
    .checkpointDir(Paths.get("/checkpoints/sft"))
    .evalEveryNSteps(200)
    .build();

SFTTrainingPipeline pipeline = new SFTTrainingPipeline(peftModel, sftConfig);
pipeline.train(trainIterator, evalIterator);
```

Internally, `SFTTrainingPipeline` uses `GradientAccumulator` to split each logical batch into micro-batches, accumulating gradients over `gradientAccumulationSteps` calls to `SameDiff.execBackwards()` before applying a single optimizer update.

### 6.2 RLAlignmentPipeline

`RLAlignmentPipeline` coordinates the full RLHF loop:

1. Freeze reference model (copy of base policy before RL training)
2. Rollout generation — sample completions from current policy
3. Reward scoring — score with `RewardModelTrainer` or a `RewardFunction`
4. Advantage computation — normalize per group or globally
5. Trainer update — call the configured RL trainer (GRPO, PPO, DPO, etc.)

```java
import org.deeplearning4j.training.RLAlignmentPipeline;
import org.deeplearning4j.training.config.RLAlignmentConfig;

RLAlignmentConfig rlConfig = RLAlignmentConfig.builder()
    .trainer(TrainerType.GRPO)
    .grpoConfig(grpoConfig)
    .rewardFunction(compositeRewardFn)
    .rolloutBatchSize(16)
    .numRlSteps(1000)
    .klCoeff(0.04)
    .saveEveryNSteps(100)
    .checkpointDir(Paths.get("/checkpoints/rl"))
    .build();

RLAlignmentPipeline rlPipeline = new RLAlignmentPipeline(peftModel, rlConfig);
rlPipeline.train(promptIterator);
```

### 6.3 GradientAccumulator

```java
import org.deeplearning4j.training.GradientAccumulator;

GradientAccumulator accumulator = new GradientAccumulator(accumulationSteps /* 8 */);

for (MultiDataSet microBatch : microBatches) {
    accumulator.accumulate(sd, microBatch);   // calls sd.execBackwards() per micro-batch
}
accumulator.applyUpdate(sd, optimizer);       // applies accumulated gradients once
accumulator.reset();
```
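Averaging micro-batch gradients reproduces the full-batch gradient when micro-batches are equal-sized and the loss is a mean. The sketch below shows this on a one-parameter linear model with mean-squared error; it is a plain-Java stand-in with hypothetical names, not the `GradientAccumulator` API.

```java
/** Illustrative gradient accumulation on y_hat = w * x with MSE loss; not DL4J code. */
final class GradAccumSketch {

    /** d/dw of mean((w * x - y)^2) over one micro-batch. */
    static double grad(double w, double[] x, double[] y) {
        double g = 0;
        for (int i = 0; i < x.length; i++) g += 2 * (w * x[i] - y[i]) * x[i];
        return g / x.length;
    }

    private double sum = 0;
    private int steps = 0;

    void accumulate(double microBatchGrad) {
        sum += microBatchGrad;
        steps++;
    }

    /** Applies the averaged gradient once, then resets the buffer. */
    double applyUpdate(double w, double lr) {
        if (steps == 0) return w;
        double updated = w - lr * (sum / steps);
        sum = 0;
        steps = 0;
        return updated;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4}, y = {2, 4, 6, 8};
        GradAccumSketch acc = new GradAccumSketch();
        acc.accumulate(grad(0.5, new double[]{1, 2}, new double[]{2, 4}));   // micro-batch 1
        acc.accumulate(grad(0.5, new double[]{3, 4}, new double[]{6, 8}));   // micro-batch 2
        System.out.println("full-batch gradient: " + grad(0.5, x, y));      // -22.5
        System.out.println("updated w: " + acc.applyUpdate(0.5, 0.01));
    }
}
```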

### 6.4 CheckpointManager

```java
import org.deeplearning4j.training.CheckpointManager;

CheckpointManager ckpt = CheckpointManager.builder()
    .checkpointDir(Paths.get("/checkpoints"))
    .saveEveryNSteps(500)
    .keepLast(3)                         // delete older checkpoints automatically
    .saveOnBestMetric("eval_loss", true) // also save when validation loss improves
    .build();

// Register with a pipeline or call manually
ckpt.maybeCheckpoint(sd, step, metrics);
```

### 6.5 ContinuedPretrainingWorkflow

For domain-specific continued pretraining, `ContinuedPretrainingWorkflow` mixes a domain dataset with a general-purpose dataset at a configurable ratio:

```java
import org.deeplearning4j.training.ContinuedPretrainingWorkflow;

ContinuedPretrainingWorkflow workflow = ContinuedPretrainingWorkflow.builder()
    .domainIterator(domainIter)
    .generalIterator(generalIter)
    .domainMixRatio(0.7)              // 70% domain, 30% general
    .peftConfig(loraConfig)
    .sftConfig(sftConfig)
    .build();

workflow.run();
```
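One simple way to realize `domainMixRatio` is to draw the source of each batch independently. This per-batch Bernoulli choice is an assumption for illustration; the workflow may instead mix per example or on a fixed schedule, and the names are hypothetical.

```java
import java.util.Random;

/** Illustrative per-batch source selection for a domain/general mix; not DL4J's workflow. */
final class DomainMixSketch {

    /** True: take the next batch from the domain iterator; false: from the general one. */
    static boolean pickDomain(Random rnd, double domainMixRatio) {
        return rnd.nextDouble() < domainMixRatio;
    }

    /** Fraction of domain batches over `draws` selections. */
    static double domainFraction(long seed, double domainMixRatio, int draws) {
        Random rnd = new Random(seed);
        int domain = 0;
        for (int i = 0; i < draws; i++)
            if (pickDomain(rnd, domainMixRatio)) domain++;
        return (double) domain / draws;
    }

    public static void main(String[] args) {
        System.out.println(domainFraction(7L, 0.7, 100_000));   // close to 0.7
    }
}
```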

***

## 7. Mixed Precision Training

### 7.1 FP8ScaleManager

`FP8ScaleManager` tracks per-tensor absolute maximums and updates dynamic scales for FP8 forward and backward passes:

| Tensor direction    | FP8 format | Max representable value |
| ------------------- | ---------- | ----------------------- |
| Forward activations | E4M3       | 448.0                   |
| Backward gradients  | E5M2       | 57344.0                 |

```java
import org.deeplearning4j.training.precision.FP8ScaleManager;

FP8ScaleManager scaleManager = new FP8ScaleManager(rollingWindowSize /* 16 */);

// Called before each forward pass
INDArray activationScale = scaleManager.getForwardScale("layer_name", activation);

// Called before each backward pass
INDArray gradScale = scaleManager.getBackwardScale("layer_name", gradient);

// Update rolling absmax after each step
scaleManager.update("layer_name", absmax);
```
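
A common way to derive FP8 scales from a rolling absmax window is "delayed scaling": the scale is the format maximum divided by the largest |x| seen over the last N steps, so the largest recent value lands at 448 (E4M3) or 57344 (E5M2). The sketch below assumes `FP8ScaleManager` follows this recipe, which this page does not confirm; `Fp8DelayedScaleSketch` is a hypothetical class working on plain arrays.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** Hypothetical delayed FP8 scaling: scale = formatMax / max(|x|) over a rolling window of steps. */
final class Fp8DelayedScaleSketch {
    static final double E4M3_MAX = 448.0;     // forward activations
    static final double E5M2_MAX = 57344.0;   // backward gradients

    private final int windowSize;
    private final Deque<Double> history = new ArrayDeque<>();

    Fp8DelayedScaleSketch(int windowSize) {
        this.windowSize = windowSize;
    }

    /** Records this step's absmax, evicting the oldest entry once the window is full. */
    void update(double absmax) {
        history.addLast(absmax);
        if (history.size() > windowSize) history.removeFirst();
    }

    /** Multiply values by this before the FP8 cast and divide afterwards; 1.0 when no history exists. */
    double scale(double formatMax) {
        double amax = 0.0;
        for (double v : history) amax = Math.max(amax, v);
        return amax > 0.0 ? formatMax / amax : 1.0;
    }

    /** Absolute maximum of a flattened tensor: the statistic fed to update(). */
    static double absmax(double[] x) {
        double m = 0.0;
        for (double v : x) m = Math.max(m, Math.abs(v));
        return m;
    }
}
```

Using the window maximum rather than the current step's value means one outlier keeps the scale conservative for `rollingWindowSize` steps before it ages out.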

### 7.2 LossScaler — Dynamic Loss Scaling

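`LossScaler` prevents gradient underflow in reduced-precision training by scaling the loss up before the backward pass and scaling the gradients back down before the optimizer step. It matters most for FP16; BF16 has the same exponent range as FP32, so its gradients underflow far less often: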

```java
import org.deeplearning4j.training.precision.LossScaler;

LossScaler scaler = new LossScaler(
    initialScale   /* 65536.0 */,
    scaleWindow    /* 2000 */,    // double scale every 2000 steps without overflow
    scaleFactor    /* 2.0 */      // multiply by this to grow; divide by it on NaN/Inf gradients
);

// Scale loss before backward
SDVariable scaledLoss = scaler.scale(loss);
sd.execBackwards(scaledLoss);

// Unscale and check for NaN/Inf before optimizer step
boolean validStep = scaler.unscaleAndCheck(sd.getGradients());
if (validStep) {
    optimizer.step();
}
scaler.update(validStep);
```
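
The grow/back-off logic behind dynamic loss scaling fits in a few lines. This hypothetical `DynamicLossScaleSketch` uses the parameter names from the `LossScaleConfig` reference table (`growthFactor`, `backoffFactor`, `growthInterval`, `minScale`, `maxScale`); `LossScaler`'s `scaleWindow` and `scaleFactor` roughly correspond to `growthInterval` and a growth factor whose inverse is the back-off. It is a sketch, not the DL4J implementation.

```java
/** Hypothetical dynamic loss scaler using the LossScaleConfig parameter names. */
final class DynamicLossScaleSketch {
    private final double growthFactor;
    private final double backoffFactor;
    private final int growthInterval;
    private final double minScale;
    private final double maxScale;
    private double scale;
    private int finiteSteps = 0; // consecutive steps without NaN/Inf gradients

    DynamicLossScaleSketch(double initialScale, double growthFactor, double backoffFactor,
                           int growthInterval, double minScale, double maxScale) {
        this.scale = initialScale;
        this.growthFactor = growthFactor;
        this.backoffFactor = backoffFactor;
        this.growthInterval = growthInterval;
        this.minScale = minScale;
        this.maxScale = maxScale;
    }

    double scale() {
        return scale;
    }

    /** Divides gradients by the current scale in place; returns false if any result is NaN or infinite. */
    boolean unscaleAndCheck(double[] grads) {
        boolean finite = true;
        for (int i = 0; i < grads.length; i++) {
            grads[i] /= scale;
            if (!Double.isFinite(grads[i])) finite = false;
        }
        return finite;
    }

    /** Backs off immediately on overflow; grows after growthInterval consecutive finite steps. */
    void update(boolean finiteStep) {
        if (!finiteStep) {
            scale = Math.max(minScale, scale * backoffFactor);
            finiteSteps = 0;
        } else if (++finiteSteps >= growthInterval) {
            scale = Math.min(maxScale, scale * growthFactor);
            finiteSteps = 0;
        }
    }
}
```

The caller multiplies the loss by `scale()` before the backward pass, passes the raw gradients to `unscaleAndCheck`, skips the optimizer step when it returns `false`, and calls `update` in both cases.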

### 7.3 Gradient Checkpointing

Gradient checkpointing trades compute for memory by not storing all intermediate activations during the forward pass. Activations are recomputed during the backward pass as needed:

```java
// Enable on a SameDiff model before training
sd.enableGradientCheckpointing(true);

// Fine-grained control: checkpoint every N layers
sd.setGradientCheckpointingInterval(4);
```
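
To see why a checkpoint interval helps, the back-of-the-envelope model below counts how many per-layer activations are alive at peak when every k-th layer output is kept and one k-layer segment is recomputed at a time: roughly N/k + k, which is smallest near k = √N. This simplified model assumes equally sized layers; `CheckpointMemorySketch` is hypothetical and says nothing about how `setGradientCheckpointingInterval` is implemented internally.

```java
/** Hypothetical back-of-the-envelope model of activation memory under segment checkpointing. */
final class CheckpointMemorySketch {
    /**
     * Peak number of per-layer activations held at once when every `interval`-th layer output
     * is stored and one segment of `interval` layers is recomputed during the backward pass.
     */
    static int peakStoredActivations(int numLayers, int interval) {
        int checkpoints = (numLayers + interval - 1) / interval; // stored segment boundaries
        return checkpoints + interval;                            // plus one recomputed segment
    }
}
```

For a 32-layer model with an interval of 4, the model gives 8 + 4 = 12 layer activations instead of 32, at the cost of roughly one extra forward pass.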

### 7.4 LossScaleConfig — TrainingConfig Integration (ADR 0057)

The core of the mixed precision system is `LossScaleConfig`, which is embedded directly in `TrainingConfig`. Three modes are available:

| Mode      | Behaviour                                                                           |
| --------- | ----------------------------------------------------------------------------------- |
| `NONE`    | No loss scaling (default; pure FP32 training)                                       |
| `STATIC`  | Fixed scale factor; use when you already know a safe scale for your model           |
| `DYNAMIC` | Adaptive scaling that automatically grows on stable steps and backs off on overflow |

`TrainingConfig` carries two additional data-type fields:

* `computeDataType` — the dtype used for forward and backward passes (e.g., `FLOAT16` or `BFLOAT16`)
* `masterWeightDataType` — the dtype used to store the authoritative weight copy (almost always `FLOAT`)

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.config.LossScaleConfig;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.learning.config.Adam;

// Dynamic mixed-precision training (recommended starting point)
TrainingConfig config = TrainingConfig.builder()
    .updater(new Adam(1e-3))
    .computeDataType(DataType.FLOAT16)          // FP16 forward/backward
    .masterWeightDataType(DataType.FLOAT)        // FP32 master weights
    .lossScaling(LossScaleConfig.builder()
        .mode(LossScaleConfig.Mode.DYNAMIC)
        .initialScale(65536.0)                   // 2^16 starting scale
        .growthFactor(2.0)                       // double scale on stable run
        .backoffFactor(0.5)                      // halve scale on overflow
        .growthInterval(2000)                    // steps before scale increase
        .minScale(1.0)
        .maxScale(65536.0)
        .build())
    .dataSetFeatureMapping("input")
    .dataSetLabelMapping("label")
    .build();

sd.setTrainingConfig(config);
sd.fit(trainData, numEpochs);   // SameDiff handles scale/unscale internally
```

For a known-stable scale (e.g., after a dry-run sweep), use static mode:

```java
TrainingConfig config = TrainingConfig.builder()
    .updater(new Adam(1e-3))
    .computeDataType(DataType.FLOAT16)
    .lossScaling(LossScaleConfig.builder()
        .mode(LossScaleConfig.Mode.STATIC)
        .initialScale(1024.0)
        .build())
    .build();
```

A convenience factory method covers the most common case:

```java
.lossScaling(LossScaleConfig.dynamicDefault())
// equivalent to DYNAMIC mode, initialScale=65536, growthInterval=2000
```

### 7.5 Gradient Accumulation via TrainingConfig (ADR 0057)

Gradient accumulation is configured directly on `TrainingConfig` via `gradientAccumulationSteps`. Gradients are always accumulated in FP32 regardless of `computeDataType`, which prevents precision loss when summing many small gradients:

```java
// Effective batch size = micro-batch size * gradientAccumulationSteps
// e.g., 32 micro-batch * 4 steps = effective batch of 128
TrainingConfig config = TrainingConfig.builder()
    .updater(new Adam(1e-3))
    .computeDataType(DataType.BFLOAT16)
    .masterWeightDataType(DataType.FLOAT)
    .lossScaling(LossScaleConfig.builder()
        .mode(LossScaleConfig.Mode.DYNAMIC)
        .initialScale(32768.0)
        .build())
    .gradientAccumulationSteps(4)
    .dataSetFeatureMapping("input")
    .dataSetLabelMapping("label")
    .build();
```

Internally, `SFTTrainingPipeline` uses `GradientAccumulator` (section 6.3) whenever `gradientAccumulationSteps > 1`. It stores accumulated gradients per named variable and averages them before applying the optimizer update; you can also use it directly in a custom training loop.

**Data type memory tradeoffs** (for reference; W is the FP32 size of the weights and A the FP32 size of the activations):

| Component              | FP32-only | Mixed Precision (FP16 compute)      |
| ---------------------- | --------- | ----------------------------------- |
| Weights                | W bytes   | W/2 (FP16) + W (master FP32) = 1.5W |
| Gradients              | W bytes   | W/2 bytes                           |
| Optimizer state (Adam) | 2W bytes  | 2W bytes                            |
| Activations            | A bytes   | A/2 bytes                           |

For models where activations dominate GPU memory (most large transformers), mixed precision provides a significant net reduction despite the FP32 master weight copy.
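
Summing the table's columns makes the trade explicit: the extra 0.5W for the FP32 master copy is exactly offset by storing gradients in FP16, so the net saving equals A/2. The small helper below (hypothetical, using the table's symbols) does that arithmetic; it ignores the FP32 gradient buffer that gradient accumulation adds.

```java
/** Worked totals for the memory table above; w = FP32 weight bytes, a = FP32 activation bytes. */
final class MixedPrecisionMemorySketch {
    /** FP32-only: weights + gradients + Adam m and v + activations = 4W + A. */
    static double fp32Total(double w, double a) {
        return w + w + 2 * w + a;
    }

    /** Mixed: FP16 weights + FP32 master copy, FP16 gradients, FP32 Adam state, FP16 activations = 4W + A/2. */
    static double mixedTotal(double w, double a) {
        return (0.5 * w + w) + 0.5 * w + 2 * w + 0.5 * a;
    }
}
```

For example, with W = 28 GB (7B parameters in FP32) and an assumed A = 40 GB, the totals are 152 GB and 132 GB.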

***

## 8. 8-bit Adam Optimizer

`Adam8bit` stores the first and second moment vectors (`m` and `v`) in INT8 with per-block absolute maximum quantization, reducing optimizer state memory by approximately 4x:

| Optimizer state storage    | Memory for a 7B-parameter model |
| -------------------------- | ------------------------------- |
| Adam, FP32 `m` and `v`     | \~56 GB (2 × 4 bytes × 7B)      |
| Adam8bit, INT8 `m` and `v` | \~14 GB (2 × 1 byte × 7B)       |

Default block size is 2048 elements per quantization block (`DEFAULT_BLOCK_SIZE = 2048`).

```java
import org.nd4j.linalg.learning.config.Adam8bit;

// Use Adam8bit as a drop-in replacement for Adam in TrainingConfig
TrainingConfig config = TrainingConfig.builder()
    .updater(new Adam8bit(
        2e-4,      // learning rate
        0.9,       // beta1
        0.999,     // beta2
        1e-8,      // epsilon
        2048       // block size (default)
    ))
    .dataSetFeatureMapping("input_ids")
    .dataSetLabelMapping("labels")
    .lossVariables("lm_loss")
    .build();
```

Internally, `Adam8bitUpdater` performs per-block updates:

1. Dequantize INT8 `m` and `v` blocks to float using stored absmax values
2. Apply the standard Adam moment update in float
3. Compute the parameter update
4. Requantize updated `m` and `v` back to INT8

`BlockQuantizationUtils` provides the low-level INT8 quantize/dequantize operations used by the updater.
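
The per-block steps above can be illustrated with a linear absmax scheme: each block stores one float absmax, and each element becomes round(x / absmax × 127). This is a sketch under that assumption; `BlockQuantizationUtils` may use a different (for example non-linear) code map, and `BlockAbsmaxInt8Sketch` is hypothetical.

```java
/** Hypothetical per-block absmax INT8 quantization with a linear, symmetric code map. */
final class BlockAbsmaxInt8Sketch {
    /** Quantizes x into q (same length) and returns one absmax per block of blockSize elements. */
    static float[] quantize(float[] x, byte[] q, int blockSize) {
        int numBlocks = (x.length + blockSize - 1) / blockSize;
        float[] absmax = new float[numBlocks];
        for (int b = 0; b < numBlocks; b++) {
            int start = b * blockSize;
            int end = Math.min(x.length, start + blockSize);
            float m = 0f;
            for (int i = start; i < end; i++) m = Math.max(m, Math.abs(x[i]));
            absmax[b] = m;
            for (int i = start; i < end; i++) {
                // Map [-absmax, absmax] linearly onto [-127, 127]; all-zero blocks stay zero
                q[i] = m == 0f ? 0 : (byte) Math.round(x[i] / m * 127f);
            }
        }
        return absmax;
    }

    /** Reconstructs approximate float values from INT8 codes and the per-block absmax. */
    static float[] dequantize(byte[] q, float[] absmax, int blockSize) {
        float[] x = new float[q.length];
        for (int i = 0; i < q.length; i++) {
            x[i] = q[i] / 127f * absmax[i / blockSize];
        }
        return x;
    }
}
```

With the default block size of 2048, the per-block absmax adds 4 bytes per 2048 elements, so storage stays close to 1 byte per element. Smaller blocks track local magnitude more closely (lower error) at the cost of more absmax values.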

***

## 9. Knowledge Distillation

`DistillationTrainer` trains a smaller student model to mimic a larger teacher model. Three loss components are available and can be combined:

| Loss Class                  | What It Minimizes                                                               |
| --------------------------- | ------------------------------------------------------------------------------- |
| `DistillationKLLoss`        | KL divergence between teacher and student output distributions at temperature T |
| `AttentionDistillationLoss` | MSE between teacher and student attention weight matrices                       |
| `FeatureDistillationLoss`   | MSE between intermediate hidden state representations                           |

```java
import org.deeplearning4j.training.DistillationTrainer;
import org.deeplearning4j.training.config.DistillationConfig;

DistillationConfig distConfig = DistillationConfig.builder()
    .temperature(4.0)                           // soften teacher probabilities
    .klLossWeight(0.7)                          // weight for output KL loss
    .attentionLossWeight(0.2)                   // weight for attention map loss
    .featureLossWeight(0.1)                     // weight for hidden state loss
    .teacherLayersForFeature(List.of(8, 16, 24))   // teacher layers to distill from
    .studentLayersForFeature(List.of(2,  4,  6))   // corresponding student layers
    .hardLabelWeight(0.1)                       // weight for ground-truth cross-entropy
    .build();

DistillationTrainer distTrainer = new DistillationTrainer(
    teacherModel,    // large frozen teacher SameDiff
    studentModel,    // smaller student SameDiff (may be a PeftModel)
    distConfig
);

distTrainer.train(trainIterator, numEpochs);
```

The `temperature` parameter divides the logits by T before the softmax. Values of T above 1 produce softer probability distributions that expose more information about inter-class similarities, that is, how the teacher ranks the incorrect classes.

Alternatively, use the `SameDiff` extension:

```java
studentSd.distillFrom(teacherSd, distConfig);
studentSd.fit(trainIter, numEpochs);
```
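
The plain-Java sketch below (the class `TemperatureSoftmaxSketch` is hypothetical, not a DL4J op) computes softmax(z / T) to show the softening effect: the ranking of classes is preserved while probability mass moves from the top class to the others as T grows. Because the gradients of the softened KL term shrink roughly as 1/T², the `distillation_kl_loss` formula in section 9.1 multiplies that term by T².

```java
/** Illustrative temperature-scaled softmax, softmax(z / T); not the DL4J distillation op. */
final class TemperatureSoftmaxSketch {
    static double[] softmax(double[] logits, double temperature) {
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) max = Math.max(max, z / temperature);
        double[] p = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] / temperature - max); // subtract the max for numerical stability
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }
}
```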

### 9.1 Native Distillation Loss Ops (ADR 0077)

Each of the three Java loss classes is backed by a C++ op (CPU and CUDA) registered in libnd4j. You can invoke them directly via `SameDiff` if you want to compose a custom distillation loss:

| Op Name                       | SameDiff call                                                    | Formula                                                                        |
| ----------------------------- | ---------------------------------------------------------------- | ------------------------------------------------------------------------------ |
| `distillation_kl_loss`        | `sd.math.distillationKLLoss(student, teacher, labels, T, alpha)` | `alpha * T^2 * KL(softmax(s/T) \|\| softmax(t/T)) + (1-alpha) * CE(s, labels)` |
| `feature_distillation_loss`   | `sd.math.featureDistillationLoss(studentHidden, teacherHidden)`  | `MSE(projection(studentHidden), teacherHidden)`                                |
| `attention_distillation_loss` | `sd.math.attentionDistillationLoss(studentAttn, teacherAttn)`    | `MSE(studentAttn, teacherAttn)` with head-count alignment                      |

Each op ships with a gradient (backward) implementation so they participate automatically in `sd.execBackwards()`.

Example — custom combined loss in a SameDiff graph:

```java
SDVariable klLoss = sd.math.distillationKLLoss(
    studentLogits,   // student model output logits
    teacherLogits,   // teacher model output logits (detached / no grad)
    groundTruth,     // integer label array
    4.0,             // temperature T
    0.7              // alpha: weight of the soft KL term
);

SDVariable featLoss = sd.math.featureDistillationLoss(
    studentHiddenLayer6,
    teacherHiddenLayer24
);

SDVariable totalLoss = klLoss.mul(0.8).add("total_loss", featLoss.mul(0.2));   // named so it can be set as the loss below
sd.setLossVariables("total_loss");
```

### 9.2 Self-Distillation with EMA Teacher

`DistillationTrainer` also supports self-distillation, where the teacher is an Exponential Moving Average (EMA) of the student's own weights. This stabilises training when no external large teacher is available:

```java
DistillationConfig selfDistConfig = DistillationConfig.builder()
    .distillationType(DistillationType.SELF_DISTILLATION)
    .emaDecay(0.999)                // EMA coefficient for teacher refresh
    .temperature(2.0)
    .klLossWeight(1.0)
    .build();

DistillationTrainer trainer = new DistillationTrainer(
    null,            // no external teacher; EMA of student is used
    studentModel,
    selfDistConfig
);
trainer.train(trainIterator, numEpochs);
```
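
The EMA refresh itself is one line per parameter. Below is a hypothetical sketch of the update applied after each student step, with `decay` playing the role of `emaDecay`; the DL4J trainer's internal implementation is not shown on this page.

```java
/** Hypothetical EMA teacher update: teacher <- decay * teacher + (1 - decay) * student, element-wise. */
final class EmaTeacherSketch {
    static void update(double[] teacher, double[] student, double decay) {
        if (teacher.length != student.length) {
            throw new IllegalArgumentException("teacher and student parameters must have the same length");
        }
        for (int i = 0; i < teacher.length; i++) {
            teacher[i] = decay * teacher[i] + (1.0 - decay) * student[i];
        }
    }
}
```

With `emaDecay(0.999)` the teacher effectively averages over roughly the last 1 / (1 - 0.999) = 1000 student steps, which is what makes it a slowly moving, more stable target.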

***

## 10. Dataset Curation Toolkit

The dataset curation package provides 20 utilities for preparing high-quality training data.

### 10.1 Deduplication

`TextDeduplicator` uses MinHash Locality-Sensitive Hashing to detect and remove near-duplicate documents efficiently:

```java
import org.nd4j.linalg.dataset.curation.dedup.TextDeduplicator;
import org.nd4j.linalg.dataset.curation.dedup.MinHasher;

MinHasher hasher = new MinHasher(shingleSize /* 5 */, numHashes /* 128 */);
TextDeduplicator deduplicator = new TextDeduplicator(hasher, similarityThreshold /* 0.85 */);

List<String> deduplicated = deduplicator.deduplicate(documents);
System.out.printf("Removed %d duplicates%n",
    documents.size() - deduplicated.size());
```
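
For intuition on what a MinHasher computes, the sketch below builds character-shingle sets, hashes each shingle with `numHashes` random linear functions modulo the prime 2^31 - 1, and keeps the minimum per function; the fraction of matching minima estimates the Jaccard similarity of the shingle sets. `MinHashSketch` is hypothetical, omits the LSH banding step a deduplicator uses to avoid comparing every pair, and the library's hashing details may differ.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

/** Hypothetical MinHash: character shingles, random linear hashes mod a prime, matching minima ~ Jaccard. */
final class MinHashSketch {
    private static final long PRIME = 2_147_483_647L; // 2^31 - 1
    private final int shingleSize;
    private final long[] a;
    private final long[] b;

    MinHashSketch(int shingleSize, int numHashes, long seed) {
        this.shingleSize = shingleSize;
        Random rng = new Random(seed);
        a = new long[numHashes];
        b = new long[numHashes];
        for (int i = 0; i < numHashes; i++) {
            a[i] = 1 + rng.nextInt(Integer.MAX_VALUE - 1); // in [1, p - 1]
            b[i] = rng.nextInt(Integer.MAX_VALUE);         // in [0, p - 1]
        }
    }

    /** Hash codes of all overlapping character substrings of length shingleSize. */
    Set<Integer> shingles(String text) {
        Set<Integer> out = new HashSet<>();
        for (int i = 0; i + shingleSize <= text.length(); i++) {
            out.add(text.substring(i, i + shingleSize).hashCode());
        }
        return out;
    }

    /** One minimum per hash function over all shingles of the text. */
    long[] signature(String text) {
        long[] sig = new long[a.length];
        Arrays.fill(sig, Long.MAX_VALUE);
        for (int s : shingles(text)) {
            long x = Integer.toUnsignedLong(s) % PRIME;
            for (int i = 0; i < a.length; i++) {
                long h = (a[i] * x + b[i]) % PRIME; // a, x < 2^31, so the product fits in a long
                if (h < sig[i]) sig[i] = h;
            }
        }
        return sig;
    }

    /** Fraction of positions where two signatures agree: an estimate of Jaccard similarity. */
    static double estimatedJaccard(long[] s1, long[] s2) {
        int same = 0;
        for (int i = 0; i < s1.length; i++) {
            if (s1[i] == s2[i]) same++;
        }
        return (double) same / s1.length;
    }
}
```

Two documents would be treated as near-duplicates when the estimate exceeds the similarity threshold (0.85 in the example above).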

### 10.2 Benchmark Contamination Removal

`NGramDecontaminator` removes documents containing n-gram overlaps with benchmark evaluation sets:

```java
import org.nd4j.linalg.dataset.curation.decontamination.NGramDecontaminator;

// Defaults: 13-gram, word-level matching
NGramDecontaminator decon = new NGramDecontaminator();
decon.indexBenchmark(benchmarkDocuments);   // index the evaluation sets once

List<String> clean = decon.decontaminate(trainingDocuments);
```
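
A minimal sketch of word-level n-gram decontamination follows, assuming a training document is dropped if it shares any single n-gram with the indexed benchmarks (variants that require a fraction of overlapping n-grams are also common). `NGramDeconSketch` is hypothetical, and its tokenization (lower-cased alphanumeric runs) is an assumption.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;

/** Hypothetical word-level n-gram decontamination: drop any document sharing an n-gram with a benchmark. */
final class NGramDeconSketch {
    private final int n;
    private final Set<String> benchmarkNGrams = new HashSet<>();

    NGramDeconSketch(int n) {
        this.n = n;
    }

    /** Lower-cases, splits on non-alphanumeric characters, and returns every run of n consecutive words. */
    List<String> ngrams(String text) {
        List<String> words = new ArrayList<>();
        for (String w : text.toLowerCase(Locale.ROOT).split("[^\\p{Alnum}]+")) {
            if (!w.isEmpty()) words.add(w);
        }
        List<String> grams = new ArrayList<>();
        for (int i = 0; i + n <= words.size(); i++) {
            grams.add(String.join(" ", words.subList(i, i + n)));
        }
        return grams;
    }

    void indexBenchmark(List<String> benchmarkTexts) {
        for (String t : benchmarkTexts) benchmarkNGrams.addAll(ngrams(t));
    }

    /** Keeps only documents with no n-gram in common with the indexed benchmarks. */
    List<String> decontaminate(List<String> trainingTexts) {
        List<String> clean = new ArrayList<>();
        for (String t : trainingTexts) {
            boolean contaminated = false;
            for (String g : ngrams(t)) {
                if (benchmarkNGrams.contains(g)) {
                    contaminated = true;
                    break;
                }
            }
            if (!contaminated) clean.add(t);
        }
        return clean;
    }
}
```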

### 10.3 Quality Filtering

`TextQualityFilter` applies heuristics to remove low-quality text:

```java
import org.nd4j.linalg.dataset.curation.filtering.TextQualityFilter;

TextQualityFilter filter = TextQualityFilter.builder()
    .minTokens(50)
    .maxTokens(8192)
    .maxSymbolToWordRatio(0.1)
    .maxLineRepetitionFraction(0.3)
    .minMeanWordLength(3.0)
    .removeHtml(true)
    .build();

List<String> quality = filter.filter(documents);
```
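
The `minTokens(50)`, `maxSymbolToWordRatio(0.1)` and `minMeanWordLength(3.0)` values resemble the Gopher quality rules, where the symbol count covers `#` characters and ellipses. The sketch below implements those three heuristics under that assumption; `TextQualityFilter`'s exact definitions may differ, and `QualityHeuristicsSketch` is hypothetical.

```java
/** Hypothetical Gopher-style quality heuristics on whitespace-separated words. */
final class QualityHeuristicsSketch {
    static String[] words(String text) {
        String trimmed = text.trim();
        return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
    }

    /** Average number of characters per word (0 for empty text). */
    static double meanWordLength(String text) {
        String[] w = words(text);
        if (w.length == 0) return 0.0;
        int chars = 0;
        for (String s : w) chars += s.length();
        return (double) chars / w.length;
    }

    /** Number of '#' characters plus "..." ellipses, divided by the word count. */
    static double symbolToWordRatio(String text) {
        String[] w = words(text);
        if (w.length == 0) return 0.0;
        int symbols = 0;
        for (int i = 0; i < text.length(); i++) {
            if (text.charAt(i) == '#') symbols++;
        }
        for (int idx = text.indexOf("..."); idx >= 0; idx = text.indexOf("...", idx + 3)) {
            symbols++;
        }
        return (double) symbols / w.length;
    }

    static boolean keep(String text, int minTokens, double maxSymbolToWordRatio, double minMeanWordLength) {
        return words(text).length >= minTokens
            && symbolToWordRatio(text) <= maxSymbolToWordRatio
            && meanWordLength(text) >= minMeanWordLength;
    }
}
```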

### 10.4 Curriculum Learning

`CurriculumIterator` takes your training examples plus a difficulty scorer and yields the examples ordered by difficulty; with `Strategy.EASY_TO_HARD` it starts with easier examples and progresses to harder ones:

```java
import org.nd4j.linalg.dataset.curation.curriculum.CurriculumIterator;

// examples: the training items to order; lengthScorer() is the built-in scorer that ranks items by string length
CurriculumIterator<String> curriculumIter = new CurriculumIterator<>(
    examples,
    CurriculumIterator.lengthScorer(),
    CurriculumIterator.Strategy.EASY_TO_HARD,   // EASY_TO_HARD | HARD_TO_EASY | MIXED
    CurriculumIterator.Pacing.LINEAR,           // LINEAR | EXPONENTIAL | STEP
    42L                                         // seed
);
```
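
The pacing function decides how much of the difficulty-sorted data is available at each step. The sketch below assumes a linear schedule that unlocks a growing prefix of the easiest examples, from `startFraction` up to all of the data at `warmupSteps`; these parameter names and the formula are hypothetical, not the definition of `Pacing.LINEAR`.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.ToDoubleFunction;

/** Hypothetical linear-pacing curriculum: a growing prefix of the easiest examples is available. */
final class LinearCurriculumSketch {
    /** Fraction unlocked at `step`, rising linearly from startFraction to 1.0 at warmupSteps. */
    static double unlockedFraction(int step, int warmupSteps, double startFraction) {
        if (step >= warmupSteps) return 1.0;
        return startFraction + (1.0 - startFraction) * step / warmupSteps;
    }

    /** The easiest ceil(fraction * n) examples after sorting by difficulty (at least one if data is non-empty). */
    static <T> List<T> available(List<T> data, ToDoubleFunction<T> difficulty,
                                 int step, int warmupSteps, double startFraction) {
        List<T> sorted = new ArrayList<>(data);
        sorted.sort(Comparator.comparingDouble(difficulty));
        if (sorted.isEmpty()) return sorted;
        int k = (int) Math.ceil(unlockedFraction(step, warmupSteps, startFraction) * sorted.size());
        return sorted.subList(0, Math.max(1, Math.min(k, sorted.size())));
    }
}
```

A training loop would sample its next batch from `available(...)` at the current step, so hard examples appear only once the schedule has unlocked them.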

### 10.5 Sequence Packing

`SequencePacker` bin-packs variable-length sequences into fixed-size windows, minimizing padding tokens and maximizing GPU utilization:

```java
import org.nd4j.linalg.dataset.curation.packing.SequencePacker;
import org.nd4j.linalg.dataset.curation.packing.PackedSequence;

SequencePacker packer = new SequencePacker(/* separatorToken */ 2, /* paddingToken */ 0);
List<PackedSequence> packed = packer.pack(sequences, /* maxLength */ 2048);

// Each PackedSequence holds the concatenated token ids, per-position segment ids, and a
// block-diagonal attention mask so that packed documents cannot attend to one another
```

### 10.6 Domain Mixing

`WeightedDataMixer` samples from multiple domain iterators with configurable weights:

```java
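import java.util.LinkedHashMap;
import java.util.Map;
import org.nd4j.linalg.dataset.curation.mixing.WeightedDataMixer;

Map<String, WeightedDataMixer.WeightedSource<String>> sources = new LinkedHashMap<>();
sources.put("code",    new WeightedDataMixer.WeightedSource<>(codeIterator,    0.30));
sources.put("math",    new WeightedDataMixer.WeightedSource<>(mathIterator,    0.20));
sources.put("general", new WeightedDataMixer.WeightedSource<>(generalIterator, 0.50));

WeightedDataMixer<String> mixer = new WeightedDataMixer<>(sources, /* temperature */ 1.0, /* seed */ 42L);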
```

### 10.7 Padding-Free Batching

`LengthBucketingIterator` groups sequences into length buckets before batching, reducing within-batch padding. `PaddingFreeBatch` represents a batch where all padding has been eliminated by packing:

```java
import org.nd4j.linalg.dataset.curation.batching.LengthBucketingIterator;

LengthBucketingIterator bucketIter = new LengthBucketingIterator(
    baseIterator,
    bucketBoundaries /* new int[]{128, 256, 512, 1024, 2048} */,
    batchSize /* 8 */
);
```
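
Bucketing assigns each sequence to the smallest boundary that fits it, so a batch drawn from one bucket is padded at most to that boundary. Below is a hypothetical sketch of the assignment step; this version skips sequences longer than the last boundary, whereas the real iterator may truncate or reject them instead.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical length bucketing: group sequence indices by the smallest boundary that fits them. */
final class LengthBucketSketch {
    /** Index of the smallest boundary >= length (boundaries sorted ascending), or -1 if too long. */
    static int bucketFor(int length, int[] boundaries) {
        for (int i = 0; i < boundaries.length; i++) {
            if (length <= boundaries[i]) return i;
        }
        return -1;
    }

    /** Groups sequence indices by bucket; sequences longer than the last boundary are skipped. */
    static Map<Integer, List<Integer>> group(int[] lengths, int[] boundaries) {
        Map<Integer, List<Integer>> buckets = new TreeMap<>();
        for (int i = 0; i < lengths.length; i++) {
            int b = bucketFor(lengths[i], boundaries);
            if (b >= 0) buckets.computeIfAbsent(b, k -> new ArrayList<>()).add(i);
        }
        return buckets;
    }
}
```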

### 10.8 Instruction Formatting and Chat Templates

`InstructionDataFormatter` and `ChatTemplate` apply model-specific chat templates to raw instruction datasets:

```java
import org.nd4j.linalg.dataset.curation.format.InstructionDataFormatter;
import org.nd4j.linalg.dataset.curation.format.ChatTemplate;

ChatTemplate llamaTemplate = ChatTemplate.forModel("llama-3");

InstructionDataFormatter formatter = new InstructionDataFormatter(llamaTemplate);
List<String> formatted = formatter.format(rawInstructions);
// Applies system prompt, user/assistant turn markers, and EOS tokens
```

### 10.9 Stratified Splitting

`StratifiedSplitter` creates train/validation/test splits that preserve category distributions:

```java
import org.nd4j.linalg.dataset.curation.splitting.SplitResult;
import org.nd4j.linalg.dataset.curation.splitting.StratifiedSplitter;

StratifiedSplitter splitter = new StratifiedSplitter(
    trainFraction /* 0.80 */,
    valFraction   /* 0.10 */,
    testFraction  /* 0.10 */,
    seed          /* 42 */
);

SplitResult split = splitter.split(labeledDocuments, labels);
// split.train(), split.validation(), split.test()
```
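
Stratification means the fractions are applied within each label rather than to the dataset as a whole. A hypothetical sketch of that idea on index lists (the rounding rule and the "remainder goes to test" choice are assumptions, not `StratifiedSplitter`'s documented behaviour):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical stratified split: each label's examples are shuffled and cut by the same fractions. */
final class StratifiedSplitSketch {
    /** Returns [train, validation, test] index lists; whatever is left after train and validation goes to test. */
    static List<List<Integer>> split(List<String> labels, double trainFrac, double valFrac, long seed) {
        Map<String, List<Integer>> byLabel = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            byLabel.computeIfAbsent(labels.get(i), k -> new ArrayList<>()).add(i);
        }
        Random rng = new Random(seed);
        List<Integer> train = new ArrayList<>();
        List<Integer> val = new ArrayList<>();
        List<Integer> test = new ArrayList<>();
        for (List<Integer> group : byLabel.values()) {
            Collections.shuffle(group, rng);
            int nTrain = (int) Math.round(trainFrac * group.size());
            int nVal = Math.min((int) Math.round(valFrac * group.size()), group.size() - nTrain);
            train.addAll(group.subList(0, nTrain));
            val.addAll(group.subList(nTrain, nTrain + nVal));
            test.addAll(group.subList(nTrain + nVal, group.size()));
        }
        return List.of(train, val, test);
    }
}
```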

### 10.10 Package Reference and Full Class List

All curation classes live under `org.nd4j.linalg.dataset.curation` in the `nd4j-api` artifact, not `org.deeplearning4j.data`. Use these imports in your code:

| Sub-package        | Classes                                                                                                     |
| ------------------ | ----------------------------------------------------------------------------------------------------------- |
| `.dedup`           | `MinHasher`, `TextDeduplicator`                                                                             |
| `.decontamination` | `NGramDecontaminator`, `DecontaminationResult`                                                              |
| `.curriculum`      | `CurriculumIterator`, `DifficultyScorer`                                                                    |
| `.packing`         | `SequencePacker`, `PackedSequence`                                                                          |
| `.mixing`          | `WeightedDataMixer`                                                                                         |
| `.batching`        | `LengthBucketingIterator`, `PaddingFreeBatch`, `PaddingFreeBatchCollator`                                   |
| `.filtering`       | `TextQualityFilter`, `FilterResult`                                                                         |
| `.format`          | `ChatTemplate`, `InstructionDataFormatter`, `ConversationTurn`, `ConversationExtension`, `FormattedExample` |
| `.splitting`       | `StratifiedSplitter`, `SplitResult`                                                                         |
| (top-level)        | `RawTextDatasetIterator`, `TokenizedTextDataIterator`                                                       |

The correct imports for the examples in sections 10.1–10.9 follow this pattern:

```java
import org.nd4j.linalg.dataset.curation.dedup.MinHasher;
import org.nd4j.linalg.dataset.curation.dedup.TextDeduplicator;
import org.nd4j.linalg.dataset.curation.decontamination.NGramDecontaminator;
import org.nd4j.linalg.dataset.curation.curriculum.CurriculumIterator;
import org.nd4j.linalg.dataset.curation.curriculum.DifficultyScorer;
import org.nd4j.linalg.dataset.curation.packing.SequencePacker;
import org.nd4j.linalg.dataset.curation.packing.PackedSequence;
import org.nd4j.linalg.dataset.curation.mixing.WeightedDataMixer;
import org.nd4j.linalg.dataset.curation.batching.LengthBucketingIterator;
import org.nd4j.linalg.dataset.curation.filtering.TextQualityFilter;
import org.nd4j.linalg.dataset.curation.format.ChatTemplate;
import org.nd4j.linalg.dataset.curation.format.InstructionDataFormatter;
import org.nd4j.linalg.dataset.curation.splitting.SplitResult;
import org.nd4j.linalg.dataset.curation.splitting.StratifiedSplitter;
```

**MinHasher constructor signature** (exact from source):

```java
// shingleSize: character n-gram size for shingling
// numHashes: number of hash functions (bands * rowsPerBand)
// seed: optional random seed (default 42)
new MinHasher(int shingleSize, int numHashes)
new MinHasher(int shingleSize, int numHashes, long seed)
```

**NGramDecontaminator** works in two steps — index benchmarks first, then decontaminate:

```java
import org.nd4j.linalg.dataset.curation.decontamination.NGramDecontaminator;

// Default: 13-gram word-level (matching GPT-3 decontamination standard)
NGramDecontaminator decon = new NGramDecontaminator();  // ngramSize=13, word-level

// Step 1: index benchmark texts once
decon.indexBenchmark(benchmarkTexts);

// Step 2: remove contaminated training examples
List<String> clean = decon.decontaminate(trainingTexts);
// Or, get a detailed result with contaminated indices
DecontaminationResult result = decon.check(trainingTexts);
System.out.println("Contaminated: " + result.getContaminatedIndices().size());
```

**CurriculumIterator** API (exact from source):

```java
import org.nd4j.linalg.dataset.curation.curriculum.CurriculumIterator;
import org.nd4j.linalg.dataset.curation.curriculum.CurriculumIterator.Strategy;
import org.nd4j.linalg.dataset.curation.curriculum.CurriculumIterator.Pacing;

// Strategy: EASY_TO_HARD | HARD_TO_EASY | MIXED
// Pacing:   LINEAR | EXPONENTIAL | STEP
CurriculumIterator<String> iter = new CurriculumIterator<>(
    data,
    CurriculumIterator.lengthScorer(),   // built-in scorer: score by string length
    Strategy.EASY_TO_HARD,
    Pacing.LINEAR,
    42L   // seed
);
```

**SequencePacker** uses First-Fit Decreasing bin packing and produces block-diagonal attention masks (one block per packed document) to prevent cross-document attention leakage:

```java
import org.nd4j.linalg.dataset.curation.packing.SequencePacker;
import org.nd4j.linalg.dataset.curation.packing.PackedSequence;

// separatorToken: token ID inserted between documents in a bin
// paddingToken:   token ID used for trailing padding (default 0)
SequencePacker packer = new SequencePacker(/* separatorToken */ 2, /* paddingToken */ 0);
List<PackedSequence> packed = packer.pack(tokenizedSequences, /* maxLength */ 2048);

// PackedSequence fields:
//   int[]     tokens       — concatenated token ids, padded to maxLength
//   int[]     segmentIds   — which document each position belongs to (-1 for sep/pad)
//   INDArray  attnMask     — [maxLength, maxLength] block-diagonal attention mask
```
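
First-Fit Decreasing sorts sequences longest first and places each one into the first bin with room, opening a new bin only when none fits. The hypothetical sketch below packs sequence lengths into bins of `maxLength` tokens, charging one separator token for each document after the first in a bin (following the `separatorToken` description above); it returns bin contents as sequence indices and leaves token concatenation and mask construction out.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical First-Fit Decreasing packing of sequence indices into bins of at most maxLength tokens. */
final class FfdPackingSketch {
    static List<List<Integer>> pack(int[] lengths, int maxLength) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < lengths.length; i++) order.add(i);
        order.sort(Comparator.comparingInt((Integer k) -> lengths[k]).reversed()); // longest first

        List<List<Integer>> bins = new ArrayList<>();
        List<Integer> used = new ArrayList<>(); // tokens used per bin, including separators
        for (int idx : order) {
            int len = lengths[idx];
            if (len > maxLength) throw new IllegalArgumentException("sequence " + idx + " exceeds maxLength");
            boolean placed = false;
            for (int b = 0; b < bins.size() && !placed; b++) {
                int cost = len + 1; // a separator precedes every document after the first in a bin
                if (used.get(b) + cost <= maxLength) {
                    bins.get(b).add(idx);
                    used.set(b, used.get(b) + cost);
                    placed = true;
                }
            }
            if (!placed) {
                List<Integer> bin = new ArrayList<>();
                bin.add(idx);
                bins.add(bin);
                used.add(len);
            }
        }
        return bins;
    }
}
```

Sorting longest first is the design choice that matters: large sequences claim bins early, and short ones fill the remaining gaps, which keeps the trailing padding per window small.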

**WeightedDataMixer** is a generic `Iterator<T>` that accepts any `Iterator<T>` sources. It supports optional temperature scaling (values < 1 sharpen the distribution toward the highest-weight source; values > 1 flatten it):

```java
import org.nd4j.linalg.dataset.curation.mixing.WeightedDataMixer;

Map<String, WeightedDataMixer.WeightedSource<String>> sources = new LinkedHashMap<>();
sources.put("code",    new WeightedDataMixer.WeightedSource<>(codeIter,    0.30));
sources.put("math",    new WeightedDataMixer.WeightedSource<>(mathIter,    0.20));
sources.put("general", new WeightedDataMixer.WeightedSource<>(generalIter, 0.50));

WeightedDataMixer<String> mixer = new WeightedDataMixer<>(sources, /* temperature */ 1.0, /* seed */ 42L);

while (mixer.hasNext()) {
    String example = mixer.next();
    // ...
}

// Inspect how many examples were consumed from each source
Map<String, Long> stats = mixer.getStats();
```
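
One common way to implement such a temperature is to raise each weight to the power 1/T and renormalize, so T < 1 sharpens and T > 1 flattens the distribution. This formula is an assumption that this page does not confirm for `WeightedDataMixer`; `MixTemperatureSketch` is hypothetical.

```java
/** Hypothetical temperature-adjusted sampling probabilities: p_i proportional to w_i^(1/T). */
final class MixTemperatureSketch {
    static double[] probabilities(double[] weights, double temperature) {
        double[] p = new double[weights.length];
        double sum = 0.0;
        for (int i = 0; i < weights.length; i++) {
            p[i] = Math.pow(weights[i], 1.0 / temperature);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }
}
```

With the weights above (0.30, 0.20, 0.50), T = 1 leaves the mix unchanged, T = 0.5 pushes the `general` share above 50%, and T = 2 pulls it below 50%.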

***

## 11. Transfer Learning API

The classical transfer learning API (`TransferLearning` / `TransferLearningHelper`) applies to `MultiLayerNetwork` and `ComputationGraph` models. For large generative models, prefer the `PeftModel` API described in sections 3–4.

### 11.1 Freeze Layers by Name Pattern

```java
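import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.transferlearning.TransferLearning;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

ComputationGraph model = new TransferLearning.GraphBuilder(pretrainedGraph)
    .fineTuneConfiguration(new FineTuneConfiguration.Builder()
        .updater(new Adam(1e-4))
        .build())
    .setFeatureExtractor("encoder_layer_11")   // freeze this layer and every layer before it
    .removeVertexKeepConnections("output_head")
    .addLayer("output_head",
        new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)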
            .nIn(768).nOut(numClasses)
            .activation(Activation.SOFTMAX)
            .build(),
        "encoder_layer_11")
    .build();
```

### 11.2 SameDiff-Based Freezing

For `SameDiff` models, freeze parameters by converting them to constants:

```java
// Freeze the embedding layer
sd.convertToConstant(sd.getVariable("token_embeddings"));
sd.convertToConstant(sd.getVariable("position_embeddings"));

// All remaining VARIABLE-type parameters will be trained
sd.fit(trainIter, numEpochs);

// Unfreeze later for a second fine-tuning phase
sd.convertToVariable(sd.getVariable("token_embeddings"));
```

### 11.3 TransferLearningHelper — Pre-computed Feature Cache

For large frozen feature extractors, computing the activations at the freeze boundary once means the frozen layers are not re-run on every epoch; only the unfrozen head is trained, on the cached features:

```java
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.transferlearning.TransferLearningHelper;
import org.nd4j.linalg.dataset.DataSet;

// Featurize dataset once (runs forward through frozen layers)
TransferLearningHelper helper = new TransferLearningHelper(pretrainedNet, "block5_pool");

List<DataSet> featurized = new ArrayList<>();
while (trainIter.hasNext()) {
    featurized.add(helper.featurize(trainIter.next()));
}

// Train only the head on cached features
for (int epoch = 0; epoch < numEpochs; epoch++) {
    for (DataSet batch : featurized) {
        helper.fitFeaturized(batch);
    }
}
```

***

## 12. Configuration Reference

### LoraConfig

| Parameter         | Type           | Default             | Description                                                            |
| ----------------- | -------------- | ------------------- | ---------------------------------------------------------------------- |
| `r`               | `int`          | 8                   | Adapter rank                                                           |
| `loraAlpha`       | `int`          | 16                  | Scaling numerator; effective\_scale = loraAlpha / r                    |
| `loraDropout`     | `double`       | 0.0                 | Dropout on adapter input                                               |
| `targetModules`   | `List<String>` | `[]`                | Layer name substrings to target                                        |
| `bias`            | `String`       | `"none"`            | Which bias vectors to train: `"none"`, `"all"`, `"lora_only"`          |
| `initLoraWeights` | `String`       | `"kaiming_uniform"` | Init strategy: `"kaiming_uniform"`, `"gaussian"`, `"pissa"`, `"olora"` |
| `loraLrRatioB`    | `double`       | 1.0                 | LR multiplier for B matrix relative to A matrix (LoRA+ feature)        |
| `useRsLora`       | `boolean`      | `false`             | Use rank-stabilized scaling α/√r instead of α/r                        |

### VeraConfig

| Parameter        | Type           | Default | Description                                   |
| ---------------- | -------------- | ------- | --------------------------------------------- |
| `projectionRank` | `int`          | 64      | Rank of the shared random projection matrices |
| `targetModules`  | `List<String>` | `[]`    | Layer name substrings to target               |
| `loraDropout`    | `double`       | 0.0     | Dropout applied to the adapter path           |

### DyLoraConfig

| Parameter       | Type           | Default | Description                                 |
| --------------- | -------------- | ------- | ------------------------------------------- |
| `r`             | `int`          | 16      | Maximum rank (upper bound of dynamic range) |
| `minRank`       | `int`          | 1       | Minimum rank sampled during training        |
| `loraAlpha`     | `int`          | 16      | Scaling factor                              |
| `targetModules` | `List<String>` | `[]`    | Layer name substrings to target             |

### LossScaleConfig

| Parameter        | Type     | Default | Description                                          |
| ---------------- | -------- | ------- | ---------------------------------------------------- |
| `mode`           | `Mode`   | `NONE`  | `NONE`, `STATIC`, or `DYNAMIC`                       |
| `initialScale`   | `double` | 65536.0 | Starting loss scale (2^16)                           |
| `minScale`       | `double` | 1.0     | Floor for dynamic scaling                            |
| `maxScale`       | `double` | 65536.0 | Ceiling for dynamic scaling                          |
| `growthFactor`   | `double` | 2.0     | Multiplier applied to scale after stable run         |
| `backoffFactor`  | `double` | 0.5     | Multiplier applied to scale on overflow              |
| `growthInterval` | `int`    | 2000    | Consecutive finite-gradient steps before scale grows |

### GRPOConfig

| Parameter      | Type     | Default | Description                             |
| -------------- | -------- | ------- | --------------------------------------- |
| `groupSize`    | `int`    | 8       | Completions generated per prompt        |
| `clipEpsilon`  | `double` | 0.2     | Policy ratio clipping threshold         |
| `klCoeff`      | `double` | 0.04    | KL divergence penalty weight            |
| `maxNewTokens` | `int`    | 256     | Maximum tokens generated per completion |
| `temperature`  | `double` | 1.0     | Sampling temperature                    |

### SFTConfig

| Parameter                   | Type     | Default | Description                   |
| --------------------------- | -------- | ------- | ----------------------------- |
| `maxSeqLength`              | `int`    | 2048    | Maximum input sequence length |
| `batchSize`                 | `int`    | 8       | Micro-batch size per step     |
| `gradientAccumulationSteps` | `int`    | 1       | Steps before optimizer update |
| `numEpochs`                 | `int`    | 3       | Training epochs               |
| `learningRate`              | `double` | 2e-4    | Peak learning rate            |
| `warmupSteps`               | `int`    | 0       | Linear warmup steps           |
| `saveEveryNSteps`           | `int`    | 500     | Checkpoint frequency          |

### FP8ScaleManager

| Parameter           | Type        | Default | Description                        |
| ------------------- | ----------- | ------- | ---------------------------------- |
| `rollingWindowSize` | `int`       | 16      | Steps in the rolling absmax window |
| `forwardFormat`     | `FP8Format` | `E4M3`  | FP8 format for forward activations |
| `backwardFormat`    | `FP8Format` | `E5M2`  | FP8 format for backward gradients  |
| `forwardMaxValue`   | `double`    | 448.0   | Saturation value for E4M3          |
| `backwardMaxValue`  | `double`    | 57344.0 | Saturation value for E5M2          |

### LossScaler

| Parameter      | Type     | Default | Description                                  |
| -------------- | -------- | ------- | -------------------------------------------- |
| `initialScale` | `double` | 65536.0 | Starting loss scale                          |
| `scaleWindow`  | `int`    | 2000    | Steps before doubling the scale              |
| `scaleFactor`  | `double` | 2.0     | Growth multiplier; the scale is divided by it on overflow |

### Adam8bit

| Parameter   | Type     | Default | Description                     |
| ----------- | -------- | ------- | ------------------------------- |
| `lr`        | `double` | 1e-3    | Learning rate                   |
| `beta1`     | `double` | 0.9     | First moment decay              |
| `beta2`     | `double` | 0.999   | Second moment decay             |
| `epsilon`   | `double` | 1e-8    | Numerical stability constant    |
| `blockSize` | `int`    | 2048    | Elements per quantization block |

### DistillationConfig

| Parameter             | Type     | Default | Description                             |
| --------------------- | -------- | ------- | --------------------------------------- |
| `temperature`         | `double` | 4.0     | Softmax temperature for teacher logits  |
| `klLossWeight`        | `double` | 1.0     | Weight for KL divergence loss component |
| `attentionLossWeight` | `double` | 0.0     | Weight for attention map distillation   |
| `featureLossWeight`   | `double` | 0.0     | Weight for hidden state distillation    |
| `hardLabelWeight`     | `double` | 0.0     | Weight for ground-truth cross-entropy   |

***

## See Also

* [SameDiff Training](/en-1.0.0-rewrite/nd4j/overview-2/training.md) — base training loop and `TrainingConfig`
* [Transfer Learning](/en-1.0.0-rewrite/deeplearning4j/multilayernetwork/transfer-learning.md) — `TransferLearning.Builder` for `MultiLayerNetwork` / `ComputationGraph`
* [OmniHub Model Hub](/en-1.0.0-rewrite/omnihub/overview.md) — downloading pretrained base models to fine-tune
* [SameDiff Serialization](/en-1.0.0-rewrite/nd4j/overview-2/serialization.md) — saving and loading models and adapter weights


