brain

class nett.Brain(policy: Any | str, algorithm: str | OnPolicyAlgorithm | OffPolicyAlgorithm, encoder: Any | str | None = None, embedding_dim: int | None = None, reward: str = 'supervised', batch_size: int = 512, buffer_size: int = 2048, train_encoder: bool = False, seed: int = 12, custom_encoder_args: dict[str, str] = {})

Represents the brain of an agent.

The brain is made up of an encoder, a policy, an algorithm, a reward function, and the hyperparameters for these components, such as the batch and buffer sizes. It produces a trained model from the environment data and the inputs the brain receives through the body.

Parameters:
  • policy (Any | str) – The policy network, which defines the value and action networks.

  • algorithm (str | OnPolicyAlgorithm | OffPolicyAlgorithm) – The optimization algorithm used for training the model.

  • encoder (Any | str, optional) – The network used to extract features from the observations. Defaults to None.

  • embedding_dim (int, optional) – The dimension of the embedding space of the encoder. Defaults to None.

  • reward (str, optional) – The type of reward used for training the brain. Defaults to “supervised”.

  • batch_size (int, optional) – The batch size used for training. Defaults to 512.

  • buffer_size (int, optional) – The buffer size used for training. Defaults to 2048.

  • train_encoder (bool, optional) – Whether to train the encoder. Defaults to False.

  • seed (int, optional) – The random seed used for training. Defaults to 12.

  • custom_encoder_args (dict[str, str], optional) – Custom arguments for the encoder. Defaults to {}.
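Several parameters (policy, algorithm, encoder) accept either a name or an object. The sketch below shows one way such a dual-typed argument can be resolved for algorithm. It assumes a hypothetical registry ALGORITHMS and stand-in classes, and it treats the non-string case as an algorithm class; nett's own lookup may work differently.

```python
from typing import Union


# Hypothetical stand-ins for algorithm classes (e.g. stable-baselines3's PPO).
class PPO: ...
class SAC: ...


# Hypothetical name-to-class registry; not nett's actual table.
ALGORITHMS: dict[str, type] = {"PPO": PPO, "SAC": SAC}


def resolve_algorithm(algorithm: Union[str, type]) -> type:
    """Return an algorithm class given either its name or the class itself."""
    if isinstance(algorithm, str):
        try:
            return ALGORITHMS[algorithm]
        except KeyError:
            raise ValueError(
                f"Unknown algorithm {algorithm!r}; choose from {sorted(ALGORITHMS)}"
            ) from None
    # Anything that is not a string is passed through unchanged.
    return algorithm


print(resolve_algorithm("PPO") is PPO)  # True
```

Accepting both forms lets common cases stay short (a string) while still allowing a custom class to be passed directly.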

Example

>>> from nett import Brain
>>> brain = Brain(policy='CnnPolicy', algorithm='PPO')

load(model_path: str | Path) → OnPolicyAlgorithm | OffPolicyAlgorithm

Load a trained model.

Parameters:

model_path (str | Path) – The path to the trained model.

Returns:

The loaded model.

Return type:

OnPolicyAlgorithm | OffPolicyAlgorithm

plot_results(iterations: int, model_log_dir: Path, plots_dir: Path, name: str) → None

Plot the training results.

Parameters:
  • iterations (int) – The number of training iterations.

  • model_log_dir (Path) – The directory containing the model logs.

  • plots_dir (Path) – The directory to save the plots.

  • name (str) – The name of the plot.
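A common step before plotting training curves such as per-episode rewards is smoothing them with a moving average. The dependency-free sketch below assumes the rewards have already been read from the logs in model_log_dir; how plot_results itself reads, smooths, and draws the data is not specified here, and moving_average is a hypothetical helper.

```python
def moving_average(values: list[float], window: int) -> list[float]:
    """Trailing moving average; the first window-1 points average fewer values."""
    if window < 1:
        raise ValueError("window must be >= 1")
    smoothed: list[float] = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        # Drop the value that has just left the window.
        if i >= window:
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed


# Illustrative values only, standing in for rewards read from a log.
print(moving_average([0.0, 1.0, 2.0, 3.0, 4.0], window=2))  # [0.0, 0.5, 1.5, 2.5, 3.5]
```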

save(path: str) → None

Save the trained model.

Parameters:

path (str) – The path to save the model.

save_encoder_policy_network(path: Path)

Saves the policy and feature extractor of the agent’s model.

This method saves the policy and feature extractor of the agent's model under the given path. If no model is loaded, it prints an error message and returns without saving anything. Otherwise, it saves the policy as a pickle file and the feature extractor as a PyTorch state dictionary.

Returns:

None
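The control flow described above (check that a model is loaded, report an error and return early if not, otherwise write two files) can be sketched as follows. This is a hypothetical, dependency-free stand-in: treating path as a directory, the file names policy.pkl and feature_extractor.pth, the Model and FeatureExtractor classes, and the use of pickle in place of torch.save for the state dictionary are all assumptions, not nett's implementation.

```python
import pickle
from pathlib import Path
from typing import Any, Optional


class FeatureExtractor:
    """Hypothetical stand-in for a torch.nn.Module feature extractor."""

    def __init__(self) -> None:
        self.weights = [0.1, 0.2]

    def state_dict(self) -> dict[str, Any]:
        return {"weights": list(self.weights)}


class Model:
    """Hypothetical model holding a policy and a feature extractor."""

    def __init__(self) -> None:
        self.policy = {"kind": "stand-in policy"}
        self.feature_extractor = FeatureExtractor()


def save_encoder_policy_network(model: Optional[Model], path: Path) -> None:
    """Pickle the policy and save the extractor's state dict under `path`."""
    if model is None:
        # Mirrors the documented behavior: report the problem and return early.
        print("Error: no model is loaded; nothing was saved.")
        return
    path.mkdir(parents=True, exist_ok=True)
    with open(path / "policy.pkl", "wb") as f:
        pickle.dump(model.policy, f)
    # With PyTorch this would be torch.save(extractor.state_dict(), ...);
    # pickle keeps this sketch free of third-party dependencies.
    with open(path / "feature_extractor.pth", "wb") as f:
        pickle.dump(model.feature_extractor.state_dict(), f)
```

Saving the feature extractor as a state dictionary, rather than pickling the whole module, keeps the weights loadable into a freshly constructed encoder.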

test(env, iterations, model_path: str, rec_path: str, index: int)

Test the brain.

Parameters:
  • env (gym.Env) – The environment used for testing.

  • iterations (int) – The number of testing iterations.

  • model_path (str) – The path to the trained model.

  • rec_path (str) – The path to save the test video.

  • index (int) – The index of the model to test, needed for the tracking bar.
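A testing run steps a trained policy through the environment. The sketch below assumes that iterations counts environment steps (the documentation does not say whether it counts steps or episodes) and uses a hypothetical stub environment that follows the Gymnasium reset/step API. The video recording to rec_path and the progress bar are omitted, and run_test and CountingEnv are illustrative names, not part of nett.

```python
from typing import Any, Callable, Optional


class CountingEnv:
    """Hypothetical stub following the Gymnasium reset/step API."""

    def __init__(self, episode_length: int = 3) -> None:
        self.episode_length = episode_length
        self.t = 0

    def reset(self, seed: Optional[int] = None) -> tuple[int, dict[str, Any]]:
        self.t = 0
        return self.t, {}

    def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]:
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        terminated = self.t >= self.episode_length
        # Returns (observation, reward, terminated, truncated, info).
        return self.t, reward, terminated, False, {}


def run_test(env: CountingEnv, policy: Callable[[int], int], iterations: int) -> list[float]:
    """Run `iterations` steps and return the total reward of each finished episode."""
    returns: list[float] = []
    obs, _ = env.reset()
    total = 0.0
    for _ in range(iterations):
        obs, reward, terminated, truncated, _ = env.step(policy(obs))
        total += reward
        if terminated or truncated:
            returns.append(total)
            total = 0.0
            obs, _ = env.reset()
    return returns


# Seven steps with three-step episodes: two complete episodes, one partial.
print(run_test(CountingEnv(), policy=lambda obs: 1, iterations=7))  # [3.0, 3.0]
```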

train(env: nett.Body, iterations: int, device_type: str, device: int, index: int, paths: dict[str, Path], save_checkpoints: bool, checkpoint_freq: int)

Train the brain.

Parameters:
  • env (nett.Body) – The environment used for training.

  • iterations (int) – The number of training iterations.

  • device_type (str) – The type of device used for training.

  • device (int) – The device index used for training.

  • index (int) – The index of the model to train, needed for the tracking bar.

  • paths (dict[str, Path]) – The paths for saving logs, models, and plots.

  • save_checkpoints (bool) – Whether to save checkpoints.

  • checkpoint_freq (int) – The frequency of saving checkpoints.

Raises:

ValueError – If the environment fails the validation check.
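The device_type/device and save_checkpoints/checkpoint_freq parameters of train are typically combined as in the sketch below. It assumes the device is addressed PyTorch-style as "<device_type>:<device>" and that a checkpoint is written every checkpoint_freq iterations; both are assumptions about nett's behavior, and device_string and checkpoint_steps are hypothetical helpers.

```python
def device_string(device_type: str, device: int) -> str:
    """Combine a device type and index into a PyTorch-style device string."""
    # CPU is usually addressed without an index.
    return "cpu" if device_type == "cpu" else f"{device_type}:{device}"


def checkpoint_steps(iterations: int, save_checkpoints: bool, checkpoint_freq: int) -> list[int]:
    """Iterations at which a checkpoint would be written, one every checkpoint_freq."""
    if not save_checkpoints:
        return []
    if checkpoint_freq <= 0:
        raise ValueError("checkpoint_freq must be positive")
    return list(range(checkpoint_freq, iterations + 1, checkpoint_freq))


print(device_string("cuda", 0))               # cuda:0
print(checkpoint_steps(10_000, True, 2_500))  # [2500, 5000, 7500, 10000]
```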

Initializes the brain module.

nett.brain.get_encoder_dict() → dict[str, str]

Returns a dictionary mapping encoder names to encoder class names.

Returns:

Dictionary mapping encoder names to encoder class names.

Return type:

dict[str, str]

nett.brain.list_algorithms() → list[str]

Returns a list of all available policy algorithms.

Returns:

List of algorithm names.

Return type:

list[str]

nett.brain.list_encoders() → list[str]

Returns a list of all available encoders.

Returns:

List of encoder names.

Return type:

list[str]

nett.brain.list_policies() → list[str]

Returns a list of all available policy models.

Returns:

List of policy names.

Return type:

list[str]
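The list_* helpers make it possible to validate a name before constructing a Brain. The sketch below shows such a check. validate_choice is a hypothetical helper, and the list passed in is an illustrative placeholder rather than nett's actual catalog; with nett installed, one would pass list_algorithms(), list_policies(), or list_encoders() instead.

```python
import difflib


def validate_choice(kind: str, name: str, available: list[str]) -> str:
    """Return `name` if it is available, otherwise raise with a close-match hint."""
    if name in available:
        return name
    close = difflib.get_close_matches(name, available, n=1)
    hint = f" Did you mean {close[0]!r}?" if close else ""
    raise ValueError(f"Unknown {kind} {name!r}.{hint} Available: {sorted(available)}")


# Placeholder list; in practice: validate_choice("algorithm", "PPO", list_algorithms())
algorithms = ["A2C", "PPO", "SAC"]
print(validate_choice("algorithm", "PPO", algorithms))  # PPO
```

Failing early with a suggestion gives a clearer error than a failure deep inside model construction.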