# algorithm-in-python/math/numericalAnalysis/solve-linear-by-iteration.py
''' mbinary
#########################################################################
# File : solve-linear-by-iteration.py
# Author: mbinary
# Mail: zhuheqin1@gmail.com
# Blog: https://mbinary.xyz
# Github: https://github.com/mbinary
# Created Time: 2018-10-02 21:14
# Description: Solve the linear system Ax = b by Jacobi and Gauss-Seidel iteration.
#########################################################################
'''
'''
#########################################################################
# File : solve-linear-by-iteration.py
# Author: mbinary
# Mail: zhuheqin1@gmail.com
# Blog: https://mbinary.xyz
# Github: https://github.com/mbinary
# Created Time: 2018-05-04 07:42
# Description: Solve the linear system Ax = b by Jacobi and Gauss-Seidel iteration.
#########################################################################
'''
import numpy as np
from operator import le, lt
def jacob(A, b, x, accuracy=None, times=6):
    ''' Solve Ax=b with the Jacobi iteration x_{k+1} = R*x_k + g.

    With D the diagonal part of A:  R = I - D^-1 * A,  g = D^-1 * b.

    A        : n x n coefficient matrix (anything np.matrix accepts).
    b        : right-hand side; must be n x 1 (e.g. [[1], [2], [3]]).
    x        : initial value of the iteration (n x 1).
    accuracy : if not None, iterate until the largest absolute change between
               two successive iterates is below `accuracy`; `times` is then
               ignored.
    times    : number of iterations to run when `accuracy` is None.

    Returns the last iterate as an n x 1 np.matrix.  R, g and every iterate
    are printed along the way.

    Raises Exception when A is not square or b is not n x 1;
    np.linalg.LinAlgError when A has a zero on its diagonal (D is singular);
    ArithmeticError when, in accuracy mode, the iterates stop being finite
    (the iteration diverged).
    '''
    A, b, x = np.matrix(A), np.matrix(b), np.matrix(x)
    n, m = A.shape
    if n != m:
        raise Exception("Not square matrix: {A}".format(A=A))
    if b.shape != (n, 1):
        raise Exception(
            'Error: {b} must be {n} x1 in dimension'.format(b=b, n=n))
    diag = np.asarray(A).diagonal()
    if (diag == 0).any():
        # D^-1 does not exist, so the Jacobi iteration is undefined.
        raise np.linalg.LinAlgError(
            'Zero on the diagonal of {A}: Jacobi iteration is undefined'.format(A=A))
    DI = np.matrix(np.diag(1 / diag))  # D^-1
    R = np.eye(n) - DI * A
    g = DI * b
    print('R =\n{}'.format(R))
    print('g =\n{}'.format(g))
    if accuracy is not None:
        ct = 0
        while True:
            ct += 1
            # Always take a step before testing convergence, so an initial
            # guess near zero is not mistaken for a converged solution.
            last, x = x, R * x + g
            print('x{ct} =\n{x}'.format(ct=ct, x=x))
            if not np.isfinite(x).all():
                raise ArithmeticError(
                    'Jacobi iteration diverged after {ct} steps'.format(ct=ct))
            if np.abs(x - last).max() < accuracy:
                return x
    for i in range(times):
        x = R * x + g
        print('x{ct} = \n{x}'.format(ct=i + 1, x=x))
    print('isLimitd: {}'.format(isLimited(A)))
    return x
def gauss_seidel(A, b, x, accuracy=None, times=6):
    ''' Solve Ax=b with the Gauss-Seidel iteration x_{k+1} = S*x_k + f.

    Split A = (D+L) + U, where D+L is the lower triangle of A (diagonal
    included) and U is the strictly upper triangle; then
        S = -(D+L)^-1 * U,   f = (D+L)^-1 * b

    A        : n x n coefficient matrix (anything np.matrix accepts).
    b        : right-hand side; must be n x 1 (e.g. [[1], [2], [3]]).
    x        : initial value of the iteration (n x 1).
    accuracy : if not None, iterate until the largest absolute change between
               two successive iterates is below `accuracy`; `times` is then
               ignored.
    times    : number of iterations to run when `accuracy` is None.

    Returns the last iterate as an n x 1 np.matrix.  S, f and every iterate
    are printed along the way.

    Raises Exception when A is not square or b is not n x 1;
    np.linalg.LinAlgError (from the matrix inverse) when D+L is singular,
    i.e. A has a zero on its diagonal; ArithmeticError when, in accuracy
    mode, the iterates stop being finite (the iteration diverged).
    '''
    A, b, x = np.matrix(A), np.matrix(b), np.matrix(x)
    n, m = A.shape
    if n != m:
        raise Exception("Not square matrix: {A}".format(A=A))
    if b.shape != (n, 1):
        raise Exception(
            'Error: {b} must be {n} x1 in dimension'.format(b=b, n=n))
    DL = np.matrix(np.tril(A))     # D + L
    U = np.matrix(np.triu(A, 1))   # strictly upper triangle
    DLI = DL.I
    S = -DLI * U
    f = DLI * b
    print('S =\n{}'.format(S))
    print('f =\n{}'.format(f))
    if accuracy is not None:
        ct = 0
        while True:
            ct += 1
            # Always take a step before testing convergence, so an initial
            # guess near zero is not mistaken for a converged solution.
            last, x = x, S * x + f
            print('x{ct} =\n{x}'.format(ct=ct, x=x))
            if not np.isfinite(x).all():
                raise ArithmeticError(
                    'Gauss-Seidel iteration diverged after {ct} steps'.format(ct=ct))
            if np.abs(x - last).max() < accuracy:
                return x
    for i in range(times):
        x = S * x + f
        print('x{ct} = \n{x}'.format(ct=i + 1, x=x))
    print('isLimitd: {}'.format(isLimited(A)))
    return x
def isLimited(A, strict=False):
    '''Check whether A is [strictly] diagonally dominant by rows or by columns.

    Row dominance means |a_ii| >= sum_{j != i} |a_ij| for every row i (with >
    instead of >= when `strict` is True); column dominance is the same
    condition applied to the columns.  Strict diagonal dominance is a
    sufficient condition for the Jacobi and Gauss-Seidel iterations to
    converge.

    Returns True if A is dominant by rows or by columns, otherwise False.
    '''
    M = np.abs(np.asarray(A, dtype=float))
    diag = M.diagonal()
    off = M - np.diag(diag)  # zero the diagonal exactly, keep off-diagonals
    op = lt if strict else le
    rows_ok = op(off.sum(axis=1), diag).all()
    cols_ok = op(off.sum(axis=0), diag).all()
    return bool(rows_ok or cols_ok)
# Registered test cases.  Each entry is [func, A, b, x], optionally followed
# by an int (number of iterations) or another value (target accuracy).
testcase = []


def test():
    '''Run every registered test case and return their results in order.

    Each entry of `testcase` is [func, A, b, x, *args].  When an extra
    argument is given, an int is passed to `func` as `times` and anything
    else as `accuracy`; otherwise the defaults accuracy=None, times=6 apply.
    `func` is called as func(A, b, x, accuracy, times).
    '''
    results = []
    for func, A, b, x, *args in testcase:
        acc, times = None, 6
        if args:
            if isinstance(args[0], int):
                times = args[0]
            else:
                acc = args[0]
        results.append(func(A, b, x, acc, times))
    return results
if __name__ == '__main__':
    # Example systems A x = b together with an initial guess x.  Only the
    # last example is registered; uncomment one of the other `append` lines
    # to run that example as well.
    A = [[2, -1, -1], [1, 5, -1], [1, 1, 10]]
    b = [[-5], [8], [11]]
    x = [[1], [1], [1]]
    # testcase.append([gauss_seidel, A, b, x])

    A = [[2, -1, 1], [3, 3, 9], [3, 3, 5]]
    b = [[-1], [0], [4]]
    x = [[0], [0], [0]]
    # testcase.append([jacob, A, b, x])

    A = [[5, -1, -1], [3, 6, 2], [1, -1, 2]]
    b = [[16], [11], [-2]]
    x = [[1], [1], [-1]]
    # A float extra argument is used as the stopping accuracy.
    testcase.append([gauss_seidel, A, b, x, 0.001])

    test()