Source code for torchkit.utils.config

import os
from typing import Optional, Union

import yaml
from ml_collections import config_dict

ConfigDict = config_dict.ConfigDict
FrozenConfigDict = config_dict.FrozenConfigDict


# Reference: https://github.com/deepmind/jaxline/blob/master/jaxline/base_config.py
[docs]def validate_config( config: ConfigDict, base_config: ConfigDict, base_filename: str, ) -> None: """Ensures a config inherits from a base config. Args: config (ConfigDict): The child config. base_config (ConfigDict): The base config. base_filename (str): Path to python file containing base config definition. Raises: ValueError: if the base config contains keys that are not present in config. """ for key in base_config.keys(): if key not in config: raise ValueError( f"Key {key} missing from config. This config is required to have " f"keys: {list(base_config.keys())}. See {base_filename} for more " "details." ) if isinstance(base_config[key], ConfigDict) and config[key] is not None: validate_config(config[key], base_config[key], base_filename)
[docs]def dump_config(exp_dir: str, config: Union[ConfigDict, FrozenConfigDict]) -> None: """Dump a config to disk. Args: exp_dir (str): Path to the experiment directory. config (Union[ConfigDict, FrozenConfigDict]): The config to dump. """ if not os.path.exists(exp_dir): os.makedirs(exp_dir) # Note: No need to explicitly delete the previous config file as "w" will overwrite # the file if it already exists. with open(os.path.join(exp_dir, "config.yaml"), "w") as fp: yaml.dump(config.to_dict(), fp)
[docs]def load_config( exp_dir: str, config: Optional[ConfigDict] = None, freeze: bool = False, ) -> Optional[Union[ConfigDict, FrozenConfigDict]]: """Load a config from an experiment directory. Args: exp_dir (str): Path to the experiment directory. config (Optional[ConfigDict], optional): An optional config object to inplace update. If one isn't provided, a new config object is returned. Defaults to None. freeze (bool, optional): Whether to freeze the config. Defaults to False. Returns: Optional[Union[ConfigDict, FrozenConfigDict]]: The config file that was stored in the experiment directory. """ with open(os.path.join(exp_dir, "config.yaml"), "r") as fp: cfg = yaml.load(fp, Loader=yaml.FullLoader) # Inplace update the config if one is provided. if config is not None: config.update(cfg) return if freeze: return FrozenConfigDict(cfg) return ConfigDict(cfg)
[docs]def copy_config_and_replace( config: ConfigDict, update_dict: Optional[ConfigDict] = None, freeze: bool = False, ) -> Union[ConfigDict, FrozenConfigDict]: """Makes a copy of a config and optionally updates its values. Args: config (ConfigDict): The config to copy. update_dict (Optional[ConfigDict], optional): A config that will optionally update the copy. Defaults to None. freeze (bool, optional): Whether to freeze the config after the copy. Defaults to False. Returns: Union[ConfigDict, FrozenConfigDict]: A copy of the config. """ # Using the ConfigDict constructor leaves the `FieldReferences` untouched unlike # `ConfigDict.copy_and_resolve_references`. new_config = ConfigDict(config) if update_dict is not None: new_config.update(update_dict) if freeze: return FrozenConfigDict(new_config) return new_config