import numpy as np
import matplotlib.pyplot as plt
class World:
def __init__(self):
self.nRows = 3
self.nCols = 4
self.stateObstacles = [5]
self.stateTerminals = [10, 11]
self.nStates = 12
self.nActions = 4
def _plot_world(self):
nStates = self.nStates
nRows = self.nRows
nCols = self.nCols
stateObstacles = self.stateObstacles
stateTerminals = self.stateTerminals
coord = [[0, 0], [nCols, 0], [nCols, nRows], [0, nRows], [0, 0]]
xs, ys = zip(*coord)
plt.plot(xs, ys, "black")
for i in stateObstacles:
(I, J) = np.unravel_index(i, shape=(nRows, nCols), order='F')
coord = [[J, nRows - I],
[J + 1, nRows - I],
[J + 1, nRows - I + 1],
[J, nRows - I + 1],
[J, nRows - I]]
xs, ys = zip(*coord)
plt.fill(xs, ys, "0.5")
plt.plot(xs, ys, "black")
for ind, i in enumerate(stateTerminals):
(I, J) = np.unravel_index(i, shape=(nRows, nCols), order='F')
coord = [[J, nRows - I],
[J + 1, nRows - I],
[J + 1, nRows - I + 1],
[J, nRows - I + 1],
[J, nRows - I]]
xs, ys = zip(*coord)
plt.fill(xs, ys, "0.8")
plt.plot(xs, ys, "black")
plt.plot(xs, ys, "black")
X, Y = np.meshgrid(range(nCols + 1), range(nRows + 1))
plt.plot(X, Y, 'k-')
plt.plot(X.transpose(), Y.transpose(), 'k-')
@staticmethod
def _truncate(n, decimals=0):
multiplier = 10 ** decimals
return int(n * multiplier) / multiplier
def plot(self):
nStates = self.nStates
nRows = self.nRows
nCols = self.nCols
self._plot_world()
states = range(1, nStates + 1)
k = 0
for i in range(nCols):
for j in range(nRows, 0, -1):
plt.text(i + 0.5, j - 0.5, str(states[k]), fontsize=26, horizontalalignment='center', verticalalignment='center')
k += 1
plt.title('MDP gridworld', size=16)
plt.axis("equal")
plt.axis("off")
plt.show()
def plot_value(self, valueFunction):
nRows = self.nRows
nCols = self.nCols
stateObstacles = self.stateObstacles
self._plot_world()
k = 0
for i in range(nCols):
for j in range(nRows, 0, -1):
if k + 1 not in stateObstacles:
plt.text(i + 0.5, j - 0.5, str(self._truncate(valueFunction[k], 3)), fontsize=26, horizontalalignment='center', verticalalignment='center')
k += 1
plt.title('MDP gridworld', size=16)
plt.axis("equal")
plt.axis("off")
plt.show()
def plot_policy(self, policy):
nStates = self.nStates
nActions = self.nActions
nRows = self.nRows
nCols = self.nCols
stateObstacles = self.stateObstacles
stateTerminals = self.stateTerminals
policy = policy.reshape(nRows, nCols, order="F").reshape(-1, 1)
X, Y = np.meshgrid(range(nCols + 1), range(nRows + 1))
X1 = X[:-1, :-1]
Y1 = Y[:-1, :-1]
X2 = X1.reshape(-1, 1) + 0.5
Y2 = np.flip(Y1.reshape(-1, 1)) + 0.5
X2 = np.kron(np.ones((1, nActions)), X2)
Y2 = np.kron(np.ones((1, nActions)), Y2)
mat = np.cumsum(np.ones((nStates, nActions)), axis=1).astype("int64")
if policy.shape[1] == 1:
policy = (np.kron(np.ones((1, nActions)), policy) == mat)
index_no_policy = stateObstacles + stateTerminals
index_policy = [item - 1 for item in range(1, nStates + 1) if item not in index_no_policy]
mask = policy.astype("int64") * mat
mask = mask.reshape(nRows, nCols, nCols)
X3 = X2.reshape(nRows, nCols, nActions)
Y3 = Y2.reshape(nRows, nCols, nActions)
alpha = np.pi - np.pi / 2 * mask
self._plot_world()
for ii in index_policy:
ax = plt.gca()
j = int(ii / nRows)
i = (ii + 1 - j * nRows) % nCols - 1
index = np.where(mask[i, j] > 0)[0]
h = ax.quiver(X3[i, j, index], Y3[i, j, index], np.cos(alpha[i, j, index]), np.sin(alpha[i, j, index]), 0.3)
states = range(1, nStates + 1)
k = 0
for i in range(nCols):
for j in range(nRows, 0, -1):
plt.text(i + 0.25, j - 0.25, str(states[k]), fontsize=16, horizontalalignment='right', verticalalignment='bottom')
k += 1
plt.axis("equal")
plt.axis("off")
plt.show()