https://cloud.tencent.com/developer/article/2183273?from_column=20421&from=20421

 

  • 再者,我们需要放置文件了,文件放置有很多种方法
  • H:\Anaconda3-2020.02\envs\tf2\Lib\site-packages\gym\envs\classic_control  可以直接放到文件夹里和别的py程序在一起
  • H:\Anaconda3-2020.02\envs\tf2\Lib\site-packages\gym\envs\classic_control\myenv也可以单独创建一个文件夹放置【推荐这样。不容易以后混淆】

这里文件放置的目录会影响到之后gym注册情况代码添加

注意:这里不推荐把文件放到robotics、mujoco文件夹里,因为这里是gym机器人环境的编辑文件,我们放进去后在运行调试会出错{mujoco_py、mujoco提示未安装,搞搞这个就会挺麻烦的,不符合我们简单教学,之后会在补充这块创建}

3.注册自己的模拟器

再次确认我们的文件放置位置:H:\Anaconda3-2020.02\envs\tf2\Lib\site-packages\gym\envs\classic_control\myenv

  • 注册环境第一步

打开__init__.py文件 添加from gym.envs.classic_control.myenv.myenv import MyEnv

代码语言:javascript
代码运行次数:0
运行
AI代码解释
 
from gym.envs.classic_control.cartpole import CartPoleEnv
from gym.envs.classic_control.mountain_car import MountainCarEnv
from gym.envs.classic_control.continuous_mountain_car import Continuous_MountainCarEnv
from gym.envs.classic_control.pendulum import PendulumEnv
from gym.envs.classic_control.acrobot import AcrobotEnv
#下面一句是我们自己添加的
from gym.envs.classic_control.myenv.myenv import MyEnv

这里解释一下为什么这么添加:第一个:myenv是文件夹名字  第二个:myenv是py文件的文件名   第三个:MyEnv是在文件中定义的环境类名字

{再举个例子,如果你添加方式是H:\Anaconda3-2020.02\envs\tf2\Lib\site-packages\gym\envs\classic_control ,那么你在__init__.py文件添加如下

from gym.envs.classic_control.myenv import MyEnv}

  • 注册环境第二步

返回gym/envs目录,在该目录的__init__.py中注册环境:

添加自己环境,只需要把类命改成自己的即可,放置位置任意,建议放在# Classic下面,方面以后查找不混淆。如果需要调整参数也可以调整像我开头说的那样

代码语言:javascript
代码运行次数:0
运行
AI代码解释
 
#自己创建的环境
register(
    id='MyEnv-v0',
    entry_point='gym.envs.classic_control:MyEnv',
    max_episode_steps=200,
    reward_threshold=195.0,
)
 
register(
    id='CartPole-v0',
    entry_point='gym.envs.classic_control:CartPoleEnv',
    max_episode_steps=200,
    reward_threshold=195.0,
)

注意:MyEnv-v0中v0代表环境类的版本号,在定义类的的时候名字里可以不加,但是在id注册的时候要加,后面import的时候要加。

至此,就完成了环境的注册,就可以使用自定义的环境了!

4.测试环境

新建一个py文件,简单测试一下

代码语言:javascript
代码运行次数:0
运行
AI代码解释
 
import gym
 
env = gym.make('MyEnv-v0')
env.reset()
for _ in range(1000):
    env.render()
    env.step(env.action_space.sample()) # take a random action

结果如下:

平衡小车环境成功。

5.注意事项

注意:MyEnv-v0中v0代表环境类的版本号,在定义类的的时候名字里可以不加,但是在id注册的时候要加,后面import的时候要加。 注意:MyEnv-v0中v0代表环境类的版本号,在定义类的的时候名字里可以不加,但是在id注册的时候要加,后面import的时候要加。 注意:MyEnv-v0中v0代表环境类的版本号,在定义类的的时候名字里可以不加,但是在id注册的时候要加,后面import的时候要加。 重要事情说三遍!!!

 
 
 
 

image

 

image

 

在classic_control文件夹下的__init__.py文件中的内容如下:

image

 

 

 
 
 
------------
以下是经测试gym>=0.26版本的 自建环境文件,并按以上方法进行注册可用的。
 

 

 

import gym
# env = gym.make("GridWorld-v0",render_mode='rgb_array')
# env.reset()
# for i in range(200):
# action = env.action_space.sample()
# next_state, reward, done,info,_ = env.step(action)
# env.render()

# env = GridEnv(render_mode="human")
env = gym.make('GridWorld-v0', render_mode="human")
obs, info = env.reset()

for _ in range(20):
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
env.render()
if terminated:
obs, info = env.reset()

env.close()

#以下是动画的平滑移动版!!

# import gym
# from gym import spaces
# import numpy as np
# import random
# import matplotlib.pyplot as plt
# import matplotlib.patches as patches
# from matplotlib import animation

# class GridEnv(gym.Env):
# metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

# def __init__(self, render_mode=None):
# super().__init__()
# self.render_mode = render_mode

# # 状态空间
# self.states = [1,2,3,4,5,6,7,8]
# self.state = None

# # 动作空间
# self.actions = ['n','e','s','w']

# # 状态位置(用于渲染)
# self.x = [140,220,300,380,460,140,300,460]
# self.y = [250,250,250,250,250,150,150,150]

# # 终止状态
# self.terminate_states = {6:1, 7:1, 8:1}

# # 奖励
# self.rewards = {'1_s': -1.0, '3_s':1.0, '5_s':-1.0}

# # 状态转移
# self.t = {
# '1_s':6, '1_e':2,
# '2_w':1, '2_e':3,
# '3_s':7, '3_w':2, '3_e':4,
# '4_w':3, '4_e':5,
# '5_s':8, '5_w':4
# }

# self.gamma = 0.8

# # Gym spaces
# self.action_space = spaces.Discrete(len(self.actions))
# self.observation_space = spaces.Discrete(len(self.states))

# # 渲染相关
# self.fig = None
# self.ax = None
# self.frames = [] # 用于 rgb_array 模式收集帧
# self.robot_pos = (0,0)

# def reset(self, *, seed=None, options=None):
# super().reset(seed=seed)
# self.state = random.choice(self.states)
# self.frames = []
# self.robot_pos = (self.x[self.state-1], self.y[self.state-1])
# return self.state, {}

# def step(self, action):
# act = self.actions[action]
# key = f"{self.state}_{act}"
# next_state = self.t.get(key, self.state)
# reward = self.rewards.get(key, 0.0)
# terminated = next_state in self.terminate_states
# truncated = False

# # 平滑移动:插值生成中间帧
# start_x, start_y = self.robot_pos
# end_x, end_y = self.x[next_state-1], self.y[next_state-1]
# steps = 5 # 每格移动拆分成5帧
# for i in range(1, steps+1):
# interp_x = start_x + (end_x - start_x) * i / steps
# interp_y = start_y + (end_y - start_y) * i / steps
# self.robot_pos = (interp_x, interp_y)
# if self.render_mode is not None:
# self._render_frame()

# self.state = next_state
# self.robot_pos = (end_x, end_y)
# return next_state, reward, terminated, truncated, {}

# def _render_frame(self):
# if self.fig is None:
# self.fig, self.ax = plt.subplots(figsize=(6,4))
# if self.render_mode == "human":
# plt.ion()

# self.ax.clear()
# self.ax.set_xlim(0, 600)
# self.ax.set_ylim(0, 400)
# self.ax.set_aspect('equal')
# self.ax.axis('off')

# # 绘制网格
# for x in [100,180,260,340,420,500]:
# self.ax.plot([x,x],[100,300], color='black')
# for y in [100,200,300]:
# self.ax.plot([100,500],[y,y], color='black')

# # 绘制元素
# self.ax.add_patch(patches.Circle((300,150),20,color='gold')) # 金条
# self.ax.add_patch(patches.Circle((140,150),20,color='black')) # 骷髅1
# self.ax.add_patch(patches.Circle((460,150),20,color='black')) # 骷髅2
# self.ax.add_patch(patches.Circle(self.robot_pos,15,color='brown')) # 机器人

# if self.render_mode == "human":
# plt.pause(0.05)
# plt.show(block=False)
# elif self.render_mode == "rgb_array":
# self.fig.canvas.draw()
# image = np.frombuffer(self.fig.canvas.tostring_rgb(), dtype='uint8')
# image = image.reshape(self.fig.canvas.get_width_height()[::-1] + (3,))
# self.frames.append(image)

# def render(self):
# if self.robot_pos is None:
# return None
# self._render_frame()
# return None

# def save_video(self, filename="grid_env.mp4"):
# if len(self.frames) == 0:
# print("No frames to save.")
# return
# fig = plt.figure()
# plt.axis('off')
# im = plt.imshow(self.frames[0])
# def update(i):
# im.set_data(self.frames[i])
# return [im]
# ani = animation.FuncAnimation(fig, update, frames=len(self.frames), blit=True)
# ani.save(filename, fps=self.metadata["render_fps"])
# plt.close(fig)
# print(f"Video saved to {filename}")

# def close(self):
# if self.fig:
# plt.close(self.fig)
# self.fig = None
# self.ax = None