根据baselines库修改的运行输入参数的解析代码

 

如题:

def arg_parser(argv=None):
    """
    Register the run options and parse the known ones.

    Despite its name, this function does not return the parser object: it
    builds an ``argparse.ArgumentParser``, registers the options below and
    returns the result of ``parse_known_args``.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``.

    Returns:
        A ``(namespace, unknown_args)`` tuple: the namespace holding the
        registered options, and the list of argument strings that no
        registered option consumed.
    """
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
    parser.add_argument('--seed', help='RNG seed', type=int, default=None)
    parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
    # Parsed as float so that scientific notation such as 1e6 is accepted.
    parser.add_argument('--num_timesteps', type=float, default=1e6)
    parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
    parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
    parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
    parser.add_argument('--reward_scale', help='Reward scale factor. Default: 1.0', default=1.0, type=float)
    parser.add_argument('--save_path', help='Path to save trained model to', default=None, type=str)
    parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
    parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
    parser.add_argument('--log_path', help='Directory to save learning curve data.', default=None, type=str)
    parser.add_argument('--play', default=False, action='store_true')

    # parse_known_args (not parse_args) so unrecognised options are returned
    # to the caller instead of aborting with a usage error.
    return parser.parse_known_args(argv)

def parse_unknown_args(args):
    """
    Parse arguments not consumed by the arg parser into a dictionary.

    Two spellings are understood: ``--key=value`` as a single token, and
    ``--key value`` as two consecutive tokens. A ``--key`` that is never
    followed by a value token is dropped, and a value token that is not
    directly preceded by a bare ``--key`` is ignored.

    Args:
        args: Iterable of argument strings.

    Returns:
        Dict mapping each key (without the leading ``--``) to its value,
        always as a string.
    """
    retval = {}
    # Name of a bare ``--key`` that is still waiting for its value token.
    pending_key = None
    for arg in args:
        if arg.startswith('--'):
            # Split on the first '=' only, so values may contain '='.
            name, sep, value = arg[2:].partition('=')
            if sep:
                retval[name] = value
                # This token carried its own value; a following positional
                # token must not be attributed to it (or to an earlier key).
                pending_key = None
            else:
                pending_key = name
        elif pending_key is not None:
            retval[pending_key] = arg
            pending_key = None

    return retval

def parse_cmdline_kwargs(args, unknown_args):
    """
    Merge unrecognised command-line options into the parsed namespace.

    The leftover arguments are turned into a dictionary by
    ``parse_unknown_args``; each value is then evaluated as a Python
    expression when possible (so ``'11.11'`` becomes a float and ``'1+99'``
    becomes ``100``) and kept as the raw string otherwise.

    Args:
        args: Namespace to update in place.
        unknown_args: List of argument strings the parser did not consume.

    Returns:
        The same ``args`` object, with the extra options set as attributes.
    """
    def parse(v):
        # Values come from parse_unknown_args and are always strings.
        assert isinstance(v, str)
        try:
            # SECURITY: eval() executes arbitrary Python taken from the
            # command line. Only use this with trusted invocations; if
            # plain literals are enough, prefer ast.literal_eval.
            return eval(v)
        except Exception:
            # Anything that does not evaluate cleanly (an undefined name,
            # invalid syntax, '1/0', 'None.x', ...) is kept as plain text
            # rather than aborting argument parsing.
            return v

    vars(args).update({k: parse(v) for k, v in parse_unknown_args(unknown_args).items()})
    return args


# Parse the options registered in arg_parser(); argument strings it does not
# recognise come back separately in unknown_args.
args, unknown_args = arg_parser()

# Show the namespace before the unrecognised options are merged in.
print(args)
# Fold the leftover options into the namespace as extra attributes.
args = parse_cmdline_kwargs(args, unknown_args)
# Show the namespace again, now including the extra options.
print(args)

 

运行:

python test.py --aaa=me --xxx=11.11  --abc=True   --cde=1+99

 

解析结果:

Namespace(alg='ppo2', env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None)
Namespace(aaa='me', abc=True, alg='ppo2', cde=100, env='Reacher-v2', env_type=None, gamestate=None, log_path=None, network=None, num_env=None, num_timesteps=1000000.0, play=False, reward_scale=1.0, save_path=None, save_video_interval=0, save_video_length=200, seed=None, xxx=11.11)

 

=======================================

 

这是一段比较规范的命令行参数解析代码:已注册的参数和未注册的额外参数最终都合并到同一个 Namespace 中,方便后续代码统一调用。

 

posted on 2022-04-03 16:00  Angry_Panda  阅读(58)  评论(0)    收藏  举报

导航