import numpy as np
import matplotlib.pyplot as plt
# Explicit finite-difference solution of the 1-D heat equation
# u_t = b * u_xx on 0 <= x <= L, 0 <= t <= T, starting from sin(pi * x)
# with both ends of the interval held at zero.
L = 1  # length of the spatial interval
T = 1  # final time
m = 5  # number of space intervals (m + 1 grid points)
n = 5  # number of time steps (n + 1 time levels)
b = 0.05  # coefficient multiplying u_xx


def ic1(x):
    """Initial condition: sin(pi * x)."""
    return np.sin(np.pi * x)


def b1(t):
    """Left boundary condition (value at x = 0 for time t)."""
    return 0


def b2(t):
    """Right boundary condition (value at x = length for time t)."""
    return 0


def solve_heat_explicit(length, t_end, nx, nt, diffusivity,
                        initial, left_bc, right_bc):
    """Advance the explicit scheme and return the full space-time grid.

    Args:
        length: Right end of the spatial interval [0, length].
        t_end: Final time of the interval [0, t_end].
        nx: Number of space intervals (must be >= 1).
        nt: Number of time steps (must be >= 1).
        diffusivity: Coefficient multiplying the second space derivative.
        initial: Callable giving the initial value at a position x.
        left_bc: Callable giving the value at x = 0 for a time t.
        right_bc: Callable giving the value at x = length for a time t.

    Returns:
        Tuple ``(x, t, v)`` where ``x`` has ``nx + 1`` points, ``t`` has
        ``nt + 1`` points and ``v`` has shape ``(nt + 1, nx + 1)`` with
        ``v[i, j]`` the approximation at time ``t[i]`` and position ``x[j]``.

    Raises:
        ValueError: If ``nx`` or ``nt`` is smaller than 1.
    """
    if nx < 1 or nt < 1:
        raise ValueError("nx and nt must both be at least 1")

    h = length / nx  # space step
    k = t_end / nt  # time step
    mu = k / h**2
    c = diffusivity * mu
    if c <= 0 or c >= 0.5:
        print('Scheme is unstable')

    x = np.linspace(0, length, nx + 1)
    t = np.linspace(0, t_end, nt + 1)

    # Rows are time levels and columns are space points, so the array is
    # (nt + 1, nx + 1); this keeps the solver valid when nx != nt.
    v = np.zeros((nt + 1, nx + 1))
    v[0, :] = [initial(xj) for xj in x]
    # Boundary values are written after the initial row, so they take
    # precedence at the two corner points of the first time level.
    v[:, 0] = [left_bc(ti) for ti in t]
    v[:, -1] = [right_bc(ti) for ti in t]

    # Each new interior value depends only on the previous time level, so a
    # whole row can be updated at once.
    for i in range(nt):
        v[i + 1, 1:-1] = ((1 - 2 * c) * v[i, 1:-1]
                          + c * v[i, 2:]
                          + c * v[i, :-2])
    return x, t, v


def plot_solution(x, t, v, title='Python for Heat'):
    """Show v as a 3-D surface over the (x, t) grid."""
    # Local grid names avoid overwriting the module-level final time T.
    x_grid, t_grid = np.meshgrid(x, t)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x_grid, t_grid, v, cmap='viridis')
    ax.set_xlabel('Space X')
    ax.set_ylabel('Time T')
    ax.set_zlabel('V')
    plt.title(title)
    plt.show()


if __name__ == "__main__":
    x, t, v = solve_heat_explicit(L, T, m, n, b, ic1, b1, b2)
    plot_solution(x, t, v)
接上例,初始条件改为:
对于0≤x≤L
$$u(0,x)=\begin{cases}2x, & x<0.5\\ 2(1-x), & \text{否则}\end{cases}$$
代码数值解:
import numpy as np
import matplotlib.pyplot as plt
# Explicit finite-difference solution of the 1-D heat equation
# u_t = b * u_xx on 0 <= x <= L, 0 <= t <= T, starting from a piecewise
# linear profile (2x left of x = 0.5, 2(1 - x) otherwise) with both ends
# of the interval held at zero.
L = 1  # length of the spatial interval
T = 1  # final time
m = 5  # number of space intervals (m + 1 grid points)
n = 5  # number of time steps (n + 1 time levels)
b = 0.05  # coefficient multiplying u_xx


def ic1(x):
    """Initial profile used where x < 0.5."""
    return 2 * x


def ic2(x):
    """Initial profile used where x >= 0.5."""
    return 2 * (1 - x)


def ic(x):
    """Piecewise initial condition: ic1 for x < 0.5, ic2 otherwise."""
    return ic1(x) if x < 0.5 else ic2(x)


def b1(t):
    """Left boundary condition (value at x = 0 for time t)."""
    return 0


def b2(t):
    """Right boundary condition (value at x = length for time t)."""
    return 0


def solve_heat_explicit(length, t_end, nx, nt, diffusivity,
                        initial, left_bc, right_bc):
    """Advance the explicit scheme and return the full space-time grid.

    Args:
        length: Right end of the spatial interval [0, length].
        t_end: Final time of the interval [0, t_end].
        nx: Number of space intervals (must be >= 1).
        nt: Number of time steps (must be >= 1).
        diffusivity: Coefficient multiplying the second space derivative.
        initial: Callable giving the initial value at a position x.
        left_bc: Callable giving the value at x = 0 for a time t.
        right_bc: Callable giving the value at x = length for a time t.

    Returns:
        Tuple ``(x, t, v)`` where ``x`` has ``nx + 1`` points, ``t`` has
        ``nt + 1`` points and ``v`` has shape ``(nt + 1, nx + 1)`` with
        ``v[i, j]`` the approximation at time ``t[i]`` and position ``x[j]``.

    Raises:
        ValueError: If ``nx`` or ``nt`` is smaller than 1.
    """
    if nx < 1 or nt < 1:
        raise ValueError("nx and nt must both be at least 1")

    h = length / nx  # space step
    k = t_end / nt  # time step
    mu = k / h ** 2
    c = diffusivity * mu
    if c <= 0 or c >= 0.5:
        print('Scheme is unstable')

    x = np.linspace(0, length, nx + 1)
    t = np.linspace(0, t_end, nt + 1)

    # Rows are time levels and columns are space points, so the array is
    # (nt + 1, nx + 1); this keeps the solver valid when nx != nt.
    v = np.zeros((nt + 1, nx + 1))
    v[0, :] = [initial(xj) for xj in x]
    # Boundary values are written after the initial row, so they take
    # precedence at the two corner points of the first time level.
    v[:, 0] = [left_bc(ti) for ti in t]
    v[:, -1] = [right_bc(ti) for ti in t]

    # Each new interior value depends only on the previous time level, so a
    # whole row can be updated at once.
    for i in range(nt):
        v[i + 1, 1:-1] = ((1 - 2 * c) * v[i, 1:-1]
                          + c * v[i, 2:]
                          + c * v[i, :-2])
    return x, t, v


def plot_solution(x, t, v, title='Python for Heat '):
    """Show v as a 3-D surface over the (x, t) grid."""
    # Local grid names avoid overwriting the module-level final time T.
    x_grid, t_grid = np.meshgrid(x, t)
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(x_grid, t_grid, v, cmap='viridis')
    ax.set_xlabel('Space X')
    ax.set_ylabel('Time T')
    ax.set_zlabel('V')
    plt.title(title)
    plt.show()


if __name__ == "__main__":
    x, t, v = solve_heat_explicit(L, T, m, n, b, ic, b1, b2)
    plot_solution(x, t, v)