mirror of
https://github.com/heqin-zhu/algorithm.git
synced 2024-03-22 13:30:46 +08:00
152 lines
4.0 KiB
Python
152 lines
4.0 KiB
Python
''' mbinary
|
|
#########################################################################
|
|
# File : solve-linear-by-iteration.py
|
|
# Author: mbinary
|
|
# Mail: zhuheqin1@gmail.com
|
|
# Blog: https://mbinary.xyz
|
|
# Github: https://github.com/mbinary
|
|
# Created Time: 2018-10-02 21:14
|
|
# Description:
|
|
#########################################################################
|
|
'''
|
|
|
|
'''
|
|
#########################################################################
|
|
# File : solve-linear-by-iteration.py
|
|
# Author: mbinary
|
|
# Mail: zhuheqin1@gmail.com
|
|
# Blog: https://mbinary.xyz
|
|
# Github: https://github.com/mbinary
|
|
# Created Time: 2018-05-04 07:42
|
|
# Description:
|
|
#########################################################################
|
|
'''
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
from operator import le, lt
|
|
def jacob(A, b, x, accuracy=None, times=6):
    '''Solve the linear system Ax = b with the Jacobi iteration.

    The system is rewritten as x = R*x + g, where R = I - D^-1 * A,
    g = D^-1 * b and D is the diagonal part of A. That map is applied
    repeatedly, starting from the initial guess. R, g and every iterate
    are printed.

    Args:
        A: square coefficient matrix (anything np.matrix accepts).
        b: right-hand side, must convert to an n x 1 column.
        x: initial guess; it has to be conformable with R (n x 1), which
            is not validated here.
        accuracy: when given, iterate until the largest absolute change
            between two successive iterates is below this value.
        times: number of iterations performed when accuracy is None.

    Returns:
        The last iterate as an n x 1 np.matrix.

    Raises:
        Exception: A is not square, or b is not n x 1.
        ZeroDivisionError: A has a zero on its diagonal, so D^-1 does
            not exist.
        ArithmeticError: in accuracy mode, the iterates became
            non-finite (the iteration diverged).
    '''
    A, b, x = np.matrix(A), np.matrix(b), np.matrix(x)
    n, m = A.shape
    if n != m:
        raise Exception("Not square matrix: {A}".format(A=A))
    if b.shape != (n, 1):
        raise Exception(
            'Error: {b} must be {n} x1 in dimension'.format(b=b, n=n))

    diagonal = np.diag(np.asarray(A))
    if (diagonal == 0).any():
        # numpy would only warn and fill R with inf/nan here.
        raise ZeroDivisionError(
            'Jacobi iteration needs a non-zero diagonal: {A}'.format(A=A))

    # Keep everything as np.matrix so that `*` is the matrix product.
    DI = np.matrix(np.diag(1 / diagonal))
    R = np.eye(n) - DI * A
    g = DI * b
    print('R =\n{}'.format(R))
    print('g =\n{}'.format(g))

    if accuracy is not None:
        ct = 0
        while True:
            ct += 1
            nxt = R * x + g
            print('x{ct} =\n{x}'.format(ct=ct, x=nxt))
            # Compare against the previous iterate only after a real step
            # has been taken; otherwise a zero initial guess would look
            # "converged" before the first iteration.
            change = np.abs(nxt - x).max()
            x = nxt
            if not np.isfinite(change):
                # inf/nan can never drop below `accuracy`; stop instead
                # of looping forever.
                raise ArithmeticError(
                    'Jacobi iteration diverged after {} steps'.format(ct))
            if change < accuracy:
                return x

    for i in range(times):
        x = R * x + g
        print('x{ct} = \n{x}'.format(ct=i+1, x=x))
    print('isLimitd: {}'.format(isLimited(A)))
    return x
|
|
|
|
|
|
def gauss_seidel(A, b, x, accuracy=None, times=6):
    '''Solve the linear system Ax = b with the Gauss-Seidel iteration.

    With A split into its lower triangle including the diagonal (D + L)
    and its strictly upper triangle U, the iteration is x = S*x + f,
    where S = -(D + L)^-1 * U and f = (D + L)^-1 * b. S, f and every
    iterate are printed.

    Args:
        A: square coefficient matrix (anything np.matrix accepts).
        b: right-hand side, must convert to an n x 1 column.
        x: initial guess; it has to be conformable with S (n x 1), which
            is not validated here.
        accuracy: when given, iterate until the largest absolute change
            between two successive iterates is below this value.
        times: number of iterations performed when accuracy is None.

    Returns:
        The last iterate as an n x 1 np.matrix.

    Raises:
        Exception: A is not square, or b is not n x 1.
        ArithmeticError: in accuracy mode, the iterates became
            non-finite (the iteration diverged).

    Inverting D + L is delegated to np.matrix.I, which raises numpy's
    LinAlgError when that triangle is singular.
    '''
    A, b, x = np.matrix(A), np.matrix(b), np.matrix(x)
    n, m = A.shape
    if n != m:
        raise Exception("Not square matrix: {A}".format(A=A))
    if b.shape != (n, 1):
        raise Exception(
            'Error: {b} must be {n} x1 in dimension'.format(b=b, n=n))

    # D + L is exactly the lower triangle of A, U the part above the
    # diagonal; np.matrix keeps `*` as the matrix product.
    lower = np.matrix(np.tril(A))
    U = np.matrix(np.triu(A, 1))
    DLI = lower.I
    S = -DLI * U
    f = DLI * b
    print('S =\n{}'.format(S))
    print('f =\n{}'.format(f))

    if accuracy is not None:
        ct = 0
        while True:
            ct += 1
            nxt = S * x + f
            print('x{ct} =\n{x}'.format(ct=ct, x=nxt))
            # Measure the change only after a real step; otherwise a zero
            # initial guess would be returned without iterating at all.
            change = np.abs(nxt - x).max()
            x = nxt
            if not np.isfinite(change):
                # inf/nan can never drop below `accuracy`; stop instead
                # of looping forever.
                raise ArithmeticError(
                    'Gauss-Seidel iteration diverged after {} steps'.format(
                        ct))
            if change < accuracy:
                return x

    for i in range(times):
        x = S * x + f
        print('x{ct} = \n{x}'.format(ct=i+1, x=x))
    print('isLimitd: {}'.format(isLimited(A)))
    return x
|
|
|
|
|
|
def isLimited(A, strict=False):
    '''Report whether A is [strictly] diagonally dominant by rows or columns.

    A is diagonally dominant when, for every row (or for every column),
    the magnitude of the diagonal entry is at least the sum of the
    magnitudes of the other entries in that row (column); with
    strict=True it has to be strictly larger. Strict dominance is a
    sufficient, not a necessary, condition for the Jacobi and
    Gauss-Seidel iterations to converge; the non-strict result is only
    a hint.

    Args:
        A: matrix to inspect (nested list, ndarray or np.matrix).
        strict: require strict instead of weak dominance.

    Returns:
        True if A is dominant by rows or by columns, otherwise False.
        Input that is not a square 2-D matrix yields False.
    '''
    # Work on a plain ndarray so axis reductions are 1-D even when the
    # caller passes an np.matrix.
    magnitude = np.abs(np.asarray(A))
    if magnitude.ndim != 2 or magnitude.shape[0] != magnitude.shape[1]:
        return False
    diag = np.diag(magnitude)
    off_diagonal = magnitude - np.diag(diag)
    op = lt if strict else le
    by_rows = op(off_diagonal.sum(axis=1), diag).all()
    by_columns = op(off_diagonal.sum(axis=0), diag).all()
    return bool(by_rows or by_columns)
|
|
|
|
|
|
# Registered cases. Each entry is [solver, A, b, x0], optionally followed by
# one extra item that test() interprets as described in its docstring.
testcase = []


def test():
    '''Run a case registered in `testcase` and return the solver's result.

    An entry unpacks as (solver, A, b, x0, *extra). If an extra item is
    present, an int is used as the iteration count and any other value
    as the accuracy; the defaults are accuracy None and 6 iterations.
    The solver is called as solver(A, b, x0, accuracy, times).
    Returns None when no case is registered.
    '''
    for solver, A, b, x0, *extra in testcase:
        accuracy, rounds = None, 6
        if extra:
            option = extra[0]
            if isinstance(option, int):
                rounds = option
            else:
                accuracy = option
        # NOTE(review): returning from inside the loop means only the first
        # registered case is executed -- confirm this is intended.
        return solver(A, b, x0, accuracy, rounds)
|
|
|
|
|
|
if __name__ == '__main__':
    # Sample system 1 -- kept for manual experiments, not registered.
    A = [[2, -1, -1], [1, 5, -1], [1, 1, 10]]
    b = [[-5], [8], [11]]
    x = [[1], [1], [1]]
    # To try it: testcase.append([gauss_seidel, A, b, x])

    # Sample system 2 -- kept for manual experiments, not registered.
    A = [
        [2, -1, 1],
        [3, 3, 9],
        [3, 3, 5],
    ]
    b = [[-1], [0], [4]]
    x = [[0], [0], [0]]
    # To try it: testcase.append([jacob, A, b, x])

    # Sample system 3 -- the case that is actually run: Gauss-Seidel until
    # successive iterates differ by less than 0.001.
    A = [[5, -1, -1], [3, 6, 2], [1, -1, 2]]
    b = [[16], [11], [-2]]
    x = [[1], [1], [-1]]
    testcase.append([gauss_seidel, A, b, x, 0.001])
    test()
|