"""Evaluate a trained DQN policy on CartPole-v1 (TensorFlow 1.x graph mode).

Rebuilds the 3-layer fully-connected Q-network used during training,
restores its weights from SAVE_DIR, and runs 100 greedy (argmax-Q)
episodes, printing per-episode totals and the running average reward.
"""
import tensorflow as tf
import gym
import numpy as np
import random as rand  # NOTE(review): unused training leftover; kept to avoid namespace changes

SAVE_DIR = './cartpole/cartpole_model.ckpt'

tf.reset_default_graph()

env = gym.make('CartPole-v1')
input_size = env.observation_space.shape[0]  # CartPole observation: 4 floats
output_size = env.action_space.n             # CartPole actions: 2 (left/right)
node_num = 400                               # hidden-layer width

# --- Q-network graph. Variable names AND shapes must match the training
# --- graph exactly, or Saver.restore will fail to line up the checkpoint.
X = tf.placeholder(dtype=tf.float32, shape=(1, input_size))
w1 = tf.get_variable('w1', shape=[input_size, node_num],
                     initializer=tf.contrib.layers.xavier_initializer())
w2 = tf.get_variable('w2', shape=[node_num, node_num],
                     initializer=tf.contrib.layers.xavier_initializer())
w3 = tf.get_variable('w3', shape=[node_num, output_size],
                     initializer=tf.contrib.layers.xavier_initializer())
# NOTE(review): shape-[1] biases broadcast across the layer width; odd, but
# deliberately kept so the restored checkpoint's variable shapes still match.
b1 = tf.Variable(tf.zeros([1], dtype=tf.float32))
b2 = tf.Variable(tf.zeros([1], dtype=tf.float32))

L1 = tf.nn.relu(tf.add(tf.matmul(X, w1), b1))
L2 = tf.nn.relu(tf.add(tf.matmul(L1, w2), b2))
Q_model = tf.matmul(L2, w3)  # raw Q-values, one per action

# NOTE(review): unused training leftover (target placeholder); kept for
# namespace compatibility.
Y = tf.placeholder(dtype=tf.float32, shape=(1, output_size))

sess = tf.Session()
# BUG FIX: this initializer previously ran before any variables existed and
# therefore initialized nothing. It must run after graph construction;
# Saver.restore then overwrites every variable with the checkpoint values.
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, SAVE_DIR)

reward_list = []
for episode in range(100):
    state = env.reset()
    reward_tot = 0
    done = False
    count = 0
    while not done:
        env.render()
        count += 1
        # Greedy policy: feed the current state, take the argmax-Q action.
        state_t = np.reshape(state, [1, input_size])
        Q = sess.run(Q_model, feed_dict={X: state_t})
        action = np.argmax(Q)
        state, reward, done, _ = env.step(action)
        reward_tot += reward
    reward_list.append(reward_tot)
    # (misspelling "reword" in the progress message fixed to "reward")
    print("episode:{}, count:{}, reward tot:{}, reward avg:{}"
          .format(episode, count, reward_tot, np.mean(reward_list)))