mirror of
https://github.com/donnemartin/data-science-ipython-notebooks.git
synced 2024-03-22 13:30:56 +08:00
38 lines
735 B
Python
38 lines
735 B
Python
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
from sklearn.linear_model import LinearRegression
|
||
|
|
||
|
|
||
|
def plot_linear_regression(a=0.5, b=1.0, n_points=20, x_max=30):
    """Plot noisy samples of a straight line together with a fitted line.

    Draws ``n_points`` x-values from ``x_max * np.random.random(...)``,
    i.e. uniformly over ``[0, x_max)``, builds ``y = a * x + b`` plus
    standard-normal noise, fits a ``LinearRegression`` model to the
    samples, and draws both the sample scatter and the model's prediction
    on an axes obtained from ``plt.axes()``.

    With the default arguments the behaviour is the same as the original
    zero-argument version of this function.

    Parameters
    ----------
    a : float, default 0.5
        Slope of the underlying line.
    b : float, default 1.0
        Intercept of the underlying line.
    n_points : int, default 20
        Number of random samples to generate.
    x_max : float, default 30
        Upper bound of the sampled x-range. The same value is used as the
        right edge of the grid on which the fitted line is evaluated, so
        the two can no longer drift apart.

    Returns
    -------
    None
        The function only draws; it does not call ``plt.show()``.

    Notes
    -----
    Random numbers come from NumPy's global random state, so the figure
    differs between calls unless the caller seeds it beforehand.
    """
    # x uniformly distributed over [0, x_max)
    x = x_max * np.random.random(n_points)

    # y = a*x + b with standard-normal noise
    y = a * x + b + np.random.normal(size=x.shape)

    # create and fit a linear regression model; x[:, None] turns the 1-D
    # sample vector into the 2-D column that fit() is given here
    model = LinearRegression()
    model.fit(x[:, None], y)

    # predict y on an evenly spaced grid covering the sampled range
    x_new = np.linspace(0, x_max, 100)
    y_new = model.predict(x_new[:, None])

    # plot the results: raw samples as points, the fit as a line
    ax = plt.axes()
    ax.scatter(x, y)
    ax.plot(x_new, y_new)

    ax.set_xlabel('x')
    ax.set_ylabel('y')

    ax.axis('tight')
|
||
|
|
||
|
|
||
|
def _main():
    """Draw the linear regression demo figure and display it."""
    plot_linear_regression()
    # The plotting function only draws; showing the figure happens here.
    plt.show()


if __name__ == '__main__':
    _main()
|