
using-nshtrainer

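A config-driven PyTorch Lightning wrapper with type-safe configs and registries. Use it when building training pipelines with nshtrainer, configuring TrainerConfig or callbacks, creating LightningModuleBase subclasses, or setting up optimizers/schedulers/loggers through registry configs.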

SKILL.md
---
name: using-nshtrainer
description: Config-driven PyTorch Lightning wrapper with type-safe configs and registries. Use when building training pipelines with nshtrainer, configuring TrainerConfig or callbacks, creating LightningModuleBase subclasses, or setting up optimizers/schedulers/loggers via registry configs.
---

nshtrainer

Configuration-driven wrapper around PyTorch Lightning. Every component has a paired Config class using nshconfig.Config (Pydantic-based).

Import Convention

```python
import nshtrainer as nt
# Core: nt.Trainer, nt.TrainerConfig
# Model: nt.LightningModuleBase
# Data: nt.LightningDataModuleBase
# Metric: nt.MetricConfig
```

Core Pattern

```python
import nshconfig as C
import nshtrainer as nt
from typing_extensions import override

# 1. Config class for hyperparameters
class MyModelConfig(C.Config):
    hidden_size: int = 64
    lr: float = 1e-3

# 2. Model subclass parameterized by config
class MyModel(nt.LightningModuleBase[MyModelConfig]):
    @override
    @classmethod
    def hparams_cls(cls):
        return MyModelConfig

    def __init__(self, hparams: MyModelConfig):
        super().__init__(hparams)
        # Access config via self.hparams (type-safe), e.g. self.hparams.hidden_size

    # Define training_step, validation_step, configure_optimizers, etc.
    # as in any LightningModule.

# 3. Configure trainer
config = nt.TrainerConfig(
    max_epochs=10,
    accelerator="gpu",
    primary_metric=nt.MetricConfig(name="val_loss", mode="min"),
).with_project_root("./outputs")

# 4. Instantiate the model and train (replace ... with your DataLoaders)
model = MyModel(MyModelConfig())
trainer = nt.Trainer(config)
trainer.fit(model, train_dataloaders=..., val_dataloaders=...)
```
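To show why a config-parameterized base class pairs well with a `hparams_cls()` classmethod, here is a stdlib-only sketch of the general pattern. `ModuleBase`, `from_dict`, and the dataclass config are hypothetical stand-ins, not nshtrainer's or nshconfig's API. The idea: each subclass declares its config type, so the base class can rebuild a typed config from plain data (for example, hyperparameters reloaded from disk) without knowing the concrete class.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic, TypeVar

THparams = TypeVar("THparams")


class ModuleBase(Generic[THparams]):
    """Hypothetical stand-in for a config-parameterized module base class."""

    @classmethod
    def hparams_cls(cls) -> type[THparams]:
        # Subclasses override this to declare their config type.
        raise NotImplementedError

    def __init__(self, hparams: THparams) -> None:
        self.hparams = hparams

    @classmethod
    def from_dict(cls, data: dict[str, Any]):
        # Build the subclass's declared config type from plain data,
        # then construct the module with the resulting typed config.
        return cls(cls.hparams_cls()(**data))


@dataclass
class MyModelConfig:
    hidden_size: int = 64
    lr: float = 1e-3


class MyModel(ModuleBase[MyModelConfig]):
    @classmethod
    def hparams_cls(cls) -> type[MyModelConfig]:
        return MyModelConfig


model = MyModel.from_dict({"hidden_size": 128})
print(model.hparams.hidden_size, model.hparams.lr)  # 128 0.001
```

In the real library, `nshconfig.Config` is Pydantic-based, so it would also validate field types where this dataclass sketch does not.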

TrainerConfig

Root config that composes all sub-configs. Builder methods come in two forms: with_*() returns a modified copy, while methods ending in an underscore (*_()) mutate the config in place.

Key fields: max_epochs, accelerator, strategy, primary_metric, callbacks (dict of callback configs), loggers, checkpoint, precision, gradient_clip_val.
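The copy-versus-in-place convention can be sketched with a plain dataclass. `ToyTrainerConfig`, `with_project_root`, and `with_project_root_` are hypothetical names chosen for illustration; the sketch only demonstrates the convention and says nothing about how TrainerConfig implements it.

```python
from __future__ import annotations

import copy
from dataclasses import dataclass


@dataclass
class ToyTrainerConfig:
    """Hypothetical stand-in for a config with paired builder methods."""

    max_epochs: int = 10
    project_root: str | None = None

    def with_project_root_(self, root: str) -> ToyTrainerConfig:
        # In-place variant (trailing underscore): mutates self, returns it for chaining.
        self.project_root = root
        return self

    def with_project_root(self, root: str) -> ToyTrainerConfig:
        # Copying variant: deep-copy first so the original config is untouched.
        return copy.deepcopy(self).with_project_root_(root)


base = ToyTrainerConfig()
derived = base.with_project_root("./outputs")
print(base.project_root, derived.project_root)  # None ./outputs

base.with_project_root_("./runs")
print(base.project_root)  # ./runs
```

Implementing the copying variant on top of the in-place one keeps the two behaviors from drifting apart.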

Registries

Components are made extensible through nshconfig.Registry combined with discriminated unions:

| Registry | Purpose | How to extend |
|---|---|---|
| callback_registry | Custom callbacks | Subclass CallbackConfigBase |
| optimizer_registry | Optimizers | Subclass OptimizerConfigBase |
| accelerator_registry | Accelerators | Subclass config |
| plugin_registry | Plugins | Subclass config |
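The dispatch idea behind a registry plus a discriminated union can be sketched with the standard library: each config class is recorded under a discriminator value, and a plain dict is turned into the matching class by reading that value. `OPTIMIZER_REGISTRY`, `register`, `parse_optimizer_config`, and the `"name"` key are all hypothetical. nshconfig.Registry and Pydantic handle this (plus validation) in the real library, and this sketch does not reproduce their API.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

# Hypothetical module-level registry: discriminator value -> config class.
OPTIMIZER_REGISTRY: dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Return a class decorator that records the class under ``name``."""

    def decorator(cls: type) -> type:
        OPTIMIZER_REGISTRY[name] = cls
        return cls

    return decorator


@register("adamw")
@dataclass
class AdamWConfig:
    lr: float = 1e-3
    weight_decay: float = 0.01


@register("sgd")
@dataclass
class SGDConfig:
    lr: float = 0.1
    momentum: float = 0.9


def parse_optimizer_config(data: dict[str, Any]) -> object:
    """Build the config class selected by the "name" discriminator."""
    payload = dict(data)  # copy so the caller's dict is not modified
    kind = payload.pop("name")
    try:
        cls = OPTIMIZER_REGISTRY[kind]
    except KeyError:
        known = sorted(OPTIMIZER_REGISTRY)
        raise ValueError(f"unknown optimizer {kind!r}; known: {known}") from None
    return cls(**payload)


cfg = parse_optimizer_config({"name": "sgd", "lr": 0.05})
print(cfg)  # SGDConfig(lr=0.05, momentum=0.9)
```

Adding a new optimizer then means defining and registering one more class, with no change to the parsing code.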

Built-in Callbacks

Built-in callbacks include exponential moving average (EMA) of weights, early stopping, model checkpointing, gradient skipping, norm logging, and learning-rate monitoring, among others. Configure them through the TrainerConfig.callbacks dict.
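As an illustration of what one of these callbacks does, here is a generic, stdlib-only sketch of patience-based early stopping. Its `mode` argument uses the same "min"/"max" meaning as `MetricConfig(mode="min")` in the core example. `EarlyStopper` and its parameters are hypothetical; this is the textbook technique, not nshtrainer's callback implementation.

```python
from __future__ import annotations

import math


class EarlyStopper:
    """Generic patience-based early stopping on a monitored metric.

    Args:
        mode: "min" if lower values are better, "max" if higher are better.
        patience: Number of consecutive non-improving checks before stopping.
        min_delta: Minimum change that counts as an improvement.
    """

    def __init__(self, mode: str = "min", patience: int = 3, min_delta: float = 0.0) -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf if mode == "min" else -math.inf
        self.bad_checks = 0

    def _improved(self, value: float) -> bool:
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value: float) -> bool:
        """Record a new metric value; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopper(mode="min", patience=2)
stopped_at = None
for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.5]):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 3 0.7
```

Two consecutive epochs without improving on 0.7 exhaust the patience, so the loop stops at epoch 3 before it sees 0.5.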

Code Style Rules

  • Add from __future__ import annotations to every file
  • Type-annotate all parameters using modern syntax (X | None, list[int])
  • Format with ruff format before committing; type-check with basedpyright
  • Use the logging module, never print()
  • Write Google-style docstrings
  • Prefer composition over inheritance
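A small function following these rules could look like the sketch below. It shows the future import, modern type hints, logging instead of print(), and a Google-style docstring. `smooth_losses` is a hypothetical helper written only to illustrate the style and is not part of nshtrainer.

```python
from __future__ import annotations

import logging

log = logging.getLogger(__name__)


def smooth_losses(
    losses: list[float],
    window: int = 3,
    logger: logging.Logger | None = None,
) -> list[float]:
    """Compute a trailing moving average of loss values.

    Args:
        losses: Raw per-step loss values.
        window: Number of trailing values to average; must be >= 1.
        logger: Optional logger to use instead of the module logger.

    Returns:
        A list the same length as ``losses`` where element ``i`` is the mean
        of ``losses[max(0, i - window + 1) : i + 1]``.

    Raises:
        ValueError: If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")
    smoothed: list[float] = []
    for i in range(len(losses)):
        chunk = losses[max(0, i - window + 1) : i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    # Lazy %-formatting: the message is only built if DEBUG is enabled.
    (logger or log).debug("Smoothed %d loss values with window=%d", len(losses), window)
    return smoothed
```

For example, `smooth_losses([1.0, 2.0, 3.0, 4.0], window=2)` returns `[1.0, 1.5, 2.5, 3.5]`.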

Detailed Documentation

For in-depth reference on specific topics, see: