nshtrainer
Configuration-driven wrapper around PyTorch Lightning. Every component has a paired Config class using nshconfig.Config (Pydantic-based).
Import Convention
```python
import nshtrainer as nt

# Core:   nt.Trainer, nt.TrainerConfig
# Model:  nt.LightningModuleBase
# Data:   nt.LightningDataModuleBase
# Metric: nt.MetricConfig
```
Core Pattern
```python
from __future__ import annotations

import nshconfig as C
import nshtrainer as nt
from typing_extensions import override


# 1. Config class for hyperparameters
class MyModelConfig(C.Config):
    hidden_size: int = 64
    lr: float = 1e-3


# 2. Model subclass parameterized by config
class MyModel(nt.LightningModuleBase[MyModelConfig]):
    @override
    @classmethod
    def hparams_cls(cls) -> type[MyModelConfig]:
        return MyModelConfig

    def __init__(self, hparams: MyModelConfig) -> None:
        super().__init__(hparams)
        # Access config via self.hparams (type-safe), e.g. self.hparams.hidden_size


# 3. Configure trainer
config = nt.TrainerConfig(
    max_epochs=10,
    accelerator="gpu",
    primary_metric=nt.MetricConfig(name="val_loss", mode="min"),
).with_project_root("./outputs")

# 4. Instantiate the model from its config, then train
model = MyModel(MyModelConfig())
trainer = nt.Trainer(config)
trainer.fit(model, train_dataloaders=..., val_dataloaders=...)
```
TrainerConfig
Root config that composes all sub-configs. Builder methods follow a naming convention: `with_*()` returns a modified copy, while `*_()` (trailing underscore) mutates the config in place.
Key fields: `max_epochs`, `accelerator`, `strategy`, `primary_metric`, `callbacks` (dict of callback configs), `loggers`, `checkpoint`, `precision`, `gradient_clip_val`.
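The copy-vs-mutate convention can be illustrated with a stdlib-only toy. `ToyTrainerConfig` and its `project_root_()` method are hypothetical stand-ins, not nshtrainer's classes. The sketch only shows the contract: `with_*()` leaves the original untouched, while `*_()` changes the instance it is called on.

```python
from __future__ import annotations

import copy
from dataclasses import dataclass, field


@dataclass
class ToyTrainerConfig:
    """Hypothetical stand-in for TrainerConfig, used only to show the naming convention."""

    max_epochs: int = 10
    project_root: str | None = None
    tags: list[str] = field(default_factory=list)

    def with_project_root(self, root: str) -> ToyTrainerConfig:
        # with_*(): deep-copy first so the original config is left untouched.
        new = copy.deepcopy(self)
        new.project_root = root
        return new

    def project_root_(self, root: str) -> ToyTrainerConfig:
        # *_(): mutate this instance in place; returning self allows chaining.
        self.project_root = root
        return self


base = ToyTrainerConfig()
derived = base.with_project_root("./outputs")  # base is unchanged
base.project_root_("./runs")                   # base is now modified
```

The copying style is convenient for deriving several experiment variants from one base config without them interfering with each other.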
Registries
Components are registered extensibly through `nshconfig.Registry` combined with discriminated unions:
| Registry | Purpose | Example |
|---|---|---|
| `callback_registry` | Custom callbacks | Subclass `CallbackConfigBase` |
| `optimizer_registry` | Optimizers | Subclass `OptimizerConfigBase` |
| `accelerator_registry` | Accelerators | Subclass config |
| `plugin_registry` | Plugins | Subclass config |
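The idea behind a registry plus a discriminated union can be sketched with the standard library alone. Each config class carries a discriminator field (here a `name` field, chosen for illustration). The registry maps that value to the class, so a plain dict loaded from YAML or JSON can be turned into the right config subclass. `ToyRegistry`, `register`, `parse`, and both config classes are hypothetical names and are not nshconfig's API; nshconfig builds the same behavior on Pydantic instead.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, TypeVar

T = TypeVar("T")


class ToyRegistry:
    """Hypothetical registry: maps a discriminator value to a config class."""

    def __init__(self, discriminator: str = "name") -> None:
        self.discriminator = discriminator
        self._classes: dict[str, type[Any]] = {}

    def register(self, cls: type[T]) -> type[T]:
        # The discriminator field's class-level default becomes the lookup key.
        key = getattr(cls, self.discriminator)
        if key in self._classes:
            raise ValueError(f"duplicate registration for {key!r}")
        self._classes[key] = cls
        return cls

    def parse(self, data: dict[str, Any]) -> Any:
        # Dispatch on the discriminator value, then construct the matching class.
        key = data.get(self.discriminator)
        if key not in self._classes:
            raise ValueError(f"unknown {self.discriminator}: {key!r}")
        return self._classes[key](**data)


toy_callback_registry = ToyRegistry()


@toy_callback_registry.register
@dataclass
class ToyEarlyStoppingConfig:
    name: str = "early_stopping"
    patience: int = 3


@toy_callback_registry.register
@dataclass
class ToyLRMonitorConfig:
    name: str = "lr_monitor"
    logging_interval: str = "epoch"


cfg = toy_callback_registry.parse({"name": "early_stopping", "patience": 5})
```

With this design, adding a new component takes only one decorated class, and no central union type has to be edited by hand.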
Built-in Callbacks
EMA, early stopping, model checkpointing, gradient skipping, norm logging, learning rate monitoring, and more. Configure them through the `TrainerConfig.callbacks` dict.
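Early stopping and best-checkpoint selection both depend on comparing a monitored metric against its best value so far. A `mode` such as the `mode="min"` passed to `nt.MetricConfig` decides which direction counts as better. The following is a generic, patience-based sketch of that logic. `ToyEarlyStopping`, `patience`, and `min_delta` are hypothetical names, and nshtrainer's own early-stopping callback may behave differently.

```python
from __future__ import annotations

import math
from typing import Literal


class ToyEarlyStopping:
    """Hypothetical patience-based early stopping on a single metric."""

    def __init__(
        self,
        mode: Literal["min", "max"] = "min",
        patience: int = 3,
        min_delta: float = 0.0,
    ) -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        # Start from the worst possible value so the first epoch always improves.
        self.best = math.inf if mode == "min" else -math.inf
        self.bad_epochs = 0

    def _improved(self, value: float) -> bool:
        # "min": lower is better; "max": higher is better.
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record one epoch's metric; return True once training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = ToyEarlyStopping(mode="min", patience=2)
val_losses = [0.9, 0.7, 0.72, 0.71, 0.5]
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        break  # two epochs in a row without beating 0.7
```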
Code Style Rules
- `from __future__ import annotations` in every file
- Type hints on all parameters (modern syntax: `X | None`, `list[int]`)
- `ruff format` before committing, `basedpyright` for type checking
- `logging` module only, never `print()`
- Google-style docstrings
- Composition over inheritance
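Here is a small, hypothetical helper module that follows the rules above. It uses the future import, modern type hints on every parameter, a module-level `logging` logger instead of `print()`, and a Google-style docstring with `Args:` and `Returns:` sections. The function `mean_or_none` is illustrative only and is not part of nshtrainer.

```python
from __future__ import annotations

import logging

logger = logging.getLogger(__name__)


def mean_or_none(values: list[float]) -> float | None:
    """Compute the arithmetic mean of a list of values.

    Args:
        values: Numbers to average.

    Returns:
        The mean of ``values``, or ``None`` if the list is empty.
    """
    if not values:
        logger.warning("mean_or_none called with an empty list")
        return None
    result = sum(values) / len(values)
    logger.debug("mean of %d values is %s", len(values), result)
    return result
```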
Detailed Documentation
For in-depth reference on specific topics, see: