Understanding the FilterObservation(ObservationWrapper) class in the gym library

Notes on the class FilterObservation(ObservationWrapper) in gym's filter_observation.py module.

Code:

import copy

from gym import spaces
from gym import ObservationWrapper


class FilterObservation(ObservationWrapper):
    """Filter dictionary observations by their keys.
    
    Args:
        env: The environment to wrap.
        filter_keys: List of keys to be included in the observations.

    Raises:
        ValueError: If `filter_keys` is neither None nor an
            iterable.
        ValueError: If any of the `filter_keys` are not included in
            the original `env`'s observation space
    
    """
    def __init__(self, env, filter_keys=None):
        super(FilterObservation, self).__init__(env)

        wrapped_observation_space = env.observation_space
        assert isinstance(wrapped_observation_space, spaces.Dict), (
            "FilterObservationWrapper is only usable with dict observations.")

        observation_keys = wrapped_observation_space.spaces.keys()

        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        missing_keys = set(
            key for key in filter_keys if key not in observation_keys)

        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the "
                "original observation space.\n"
                "Filter keys: {filter_keys}\n"
                "Observation keys: {observation_keys}\n"
                "Missing keys: {missing_keys}".format(
                    filter_keys=filter_keys,
                    observation_keys=observation_keys,
                    missing_keys=missing_keys,
                ))

        self.observation_space = type(wrapped_observation_space)([
            (name, copy.deepcopy(space))
            for name, space in wrapped_observation_space.spaces.items()
            if name in filter_keys
        ])

        self._env = env
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        observation = type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])
        return observation


A prerequisite of this class is that the wrapped env's observation space must be a spaces.Dict. This is enforced by an assertion:

        assert isinstance(wrapped_observation_space, spaces.Dict), (
            "FilterObservationWrapper is only usable with dict observations.")

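The class keeps only the keys listed in filter_keys from the wrapped env's spaces.Dict observation space and drops all the others. If filter_keys is None, every key is kept.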

The full set of keys of the wrapped env comes from env.observation_space.spaces.keys() (stored as observation_keys).
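The key selection described above can be sketched without gym. In this sketch a plain OrderedDict of placeholder strings stands in for `spaces.Dict.spaces`. The function name `filtered_space_items` and the example keys are hypothetical. The sketch mirrors two details of the constructor: `filter_keys=None` keeps every key, and the kept keys follow the order of the wrapped space, not the order of `filter_keys`.

```python
import copy
from collections import OrderedDict


def filtered_space_items(wrapped_spaces, filter_keys=None):
    """Mimic how FilterObservation rebuilds its observation space.

    wrapped_spaces stands in for env.observation_space.spaces; the values
    here are placeholder strings, not real gym spaces.
    """
    if filter_keys is None:
        # Same default as the wrapper: keep every key.
        filter_keys = tuple(wrapped_spaces.keys())
    # Iterate over the wrapped space, so the result keeps its key order.
    return OrderedDict(
        (name, copy.deepcopy(space))
        for name, space in wrapped_spaces.items()
        if name in filter_keys
    )


# Hypothetical wrapped space with three keys.
wrapped = OrderedDict([
    ("achieved_goal", "Box(3,)"),
    ("desired_goal", "Box(3,)"),
    ("observation", "Box(10,)"),
])

print(list(filtered_space_items(wrapped, ["observation", "desired_goal"])))
# ['desired_goal', 'observation']
print(list(filtered_space_items(wrapped)))
# ['achieved_goal', 'desired_goal', 'observation']
```

The real wrapper passes the same kind of (name, space) list to `type(wrapped_observation_space)`, so its result is again a spaces.Dict.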


Any key in filter_keys that does not exist in the wrapped env's observation space is collected into missing_keys:

        missing_keys = set(
            key for key in filter_keys if key not in observation_keys)

and if any are found, a ValueError is raised:

        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the "
                "original observation space.\n"
                "Filter keys: {filter_keys}\n"
                "Observation keys: {observation_keys}\n"
                "Missing keys: {missing_keys}".format(
                    filter_keys=filter_keys,
                    observation_keys=observation_keys,
                    missing_keys=missing_keys,
                ))
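The validation step can be tried in isolation. The following sketch reuses the same set comprehension and error message as the constructor. The helper names `find_missing_keys` and `validate_filter_keys` are hypothetical. A plain dict's `.keys()` view stands in for `wrapped_observation_space.spaces.keys()`.

```python
def find_missing_keys(observation_keys, filter_keys):
    """Return the filter keys that are absent from observation_keys.

    Same logic as the wrapper's missing_keys computation.
    """
    return set(key for key in filter_keys if key not in observation_keys)


def validate_filter_keys(observation_keys, filter_keys):
    """Raise ValueError like FilterObservation.__init__ when keys are missing."""
    missing_keys = find_missing_keys(observation_keys, filter_keys)
    if missing_keys:
        raise ValueError(
            "All the filter_keys must be included in the "
            "original observation space.\n"
            "Filter keys: {}\n"
            "Observation keys: {}\n"
            "Missing keys: {}".format(filter_keys, observation_keys, missing_keys))


# Hypothetical observation keys of a wrapped env.
observation_keys = {"position": None, "velocity": None}.keys()

print(find_missing_keys(observation_keys, ("position", "image")))  # {'image'}

try:
    validate_filter_keys(observation_keys, ("position", "image"))
except ValueError as err:
    # The last line of the message lists the offending keys.
    print(err.args[0].splitlines()[-1])  # Missing keys: {'image'}
```

Because the check runs in the constructor, a typo in filter_keys fails immediately when the env is wrapped rather than later during reset() or step().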


Every observation received from the wrapped env is filtered by the kept keys, and the filtered observation is passed up to the caller:

    def observation(self, observation):
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        observation = type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])
        return observation
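FilterObservation itself only defines observation(). The ObservationWrapper base class is what calls this method on the observations returned by reset() and step(). The gym-free sketch below imitates that routing under these assumptions: `MiniObservationWrapper`, `MiniFilterObservation` and `ToyDictEnv` are hypothetical names, and step() follows the 4-tuple (observation, reward, done, info) convention of older gym releases rather than the newer 5-tuple API.

```python
from collections import OrderedDict


class ToyDictEnv:
    """Hypothetical stand-in env whose observations are OrderedDicts."""

    def reset(self, **kwargs):
        return OrderedDict([("position", 0.0), ("velocity", 1.0), ("rgb", "pixels")])

    def step(self, action):
        obs = OrderedDict([("position", action), ("velocity", 1.0), ("rgb", "pixels")])
        # Older-gym 4-tuple: observation, reward, done, info.
        return obs, 0.0, False, {}


class MiniObservationWrapper:
    """Simplified imitation of ObservationWrapper: route observations through observation()."""

    def __init__(self, env):
        self.env = env

    def reset(self, **kwargs):
        return self.observation(self.env.reset(**kwargs))

    def step(self, action):
        observation, reward, done, info = self.env.step(action)
        return self.observation(observation), reward, done, info

    def observation(self, observation):
        raise NotImplementedError


class MiniFilterObservation(MiniObservationWrapper):
    """Keeps only filter_keys, using the same comprehension as _filter_observation."""

    def __init__(self, env, filter_keys):
        super().__init__(env)
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        # type(observation) rebuilds the same mapping type (OrderedDict here).
        return type(observation)([
            (name, value)
            for name, value in observation.items()
            if name in self._filter_keys
        ])


env = MiniFilterObservation(ToyDictEnv(), filter_keys=["position", "velocity"])
print(dict(env.reset()))  # {'position': 0.0, 'velocity': 1.0}
obs, reward, done, info = env.step(5.0)
print(dict(obs))  # {'position': 5.0, 'velocity': 1.0}
print(type(obs).__name__)  # OrderedDict
```

Because the filtered observation is built with type(observation), an OrderedDict observation stays an OrderedDict and a plain dict stays a plain dict.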


=========================================


posted on 2022-03-21 17:29  Angry_Panda