mirror of
https://github.com/donnemartin/data-science-ipython-notebooks.git
synced 2024-03-22 13:30:56 +08:00
717 lines
18 KiB
Python
717 lines
18 KiB
Python
"""This file contains code for use with "Think Stats",
|
|
by Allen B. Downey, available from greenteapress.com
|
|
|
|
Copyright 2014 Allen B. Downey
|
|
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
|
|
"""
|
|
|
|
from __future__ import print_function
|
|
|
|
import math
|
|
import matplotlib
|
|
import matplotlib.pyplot as pyplot
|
|
import numpy as np
|
|
import pandas
|
|
|
|
import warnings
|
|
|
|
# customize some matplotlib attributes
|
|
#matplotlib.rc('figure', figsize=(4, 3))
|
|
|
|
#matplotlib.rc('font', size=14.0)
|
|
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
|
|
#matplotlib.rc('legend', fontsize=20.0)
|
|
|
|
#matplotlib.rc('xtick.major', size=6.0)
|
|
#matplotlib.rc('xtick.minor', size=3.0)
|
|
|
|
#matplotlib.rc('ytick.major', size=6.0)
|
|
#matplotlib.rc('ytick.minor', size=3.0)
|
|
|
|
|
|
class _Brewer(object):
|
|
"""Encapsulates a nice sequence of colors.
|
|
|
|
Shades of blue that look good in color and can be distinguished
|
|
in grayscale (up to a point).
|
|
|
|
Borrowed from http://colorbrewer2.org/
|
|
"""
|
|
color_iter = None
|
|
|
|
colors = ['#081D58',
|
|
'#253494',
|
|
'#225EA8',
|
|
'#1D91C0',
|
|
'#41B6C4',
|
|
'#7FCDBB',
|
|
'#C7E9B4',
|
|
'#EDF8B1',
|
|
'#FFFFD9']
|
|
|
|
# lists that indicate which colors to use depending on how many are used
|
|
which_colors = [[],
|
|
[1],
|
|
[1, 3],
|
|
[0, 2, 4],
|
|
[0, 2, 4, 6],
|
|
[0, 2, 3, 5, 6],
|
|
[0, 2, 3, 4, 5, 6],
|
|
[0, 1, 2, 3, 4, 5, 6],
|
|
]
|
|
|
|
@classmethod
|
|
def Colors(cls):
|
|
"""Returns the list of colors.
|
|
"""
|
|
return cls.colors
|
|
|
|
@classmethod
|
|
def ColorGenerator(cls, n):
|
|
"""Returns an iterator of color strings.
|
|
|
|
n: how many colors will be used
|
|
"""
|
|
for i in cls.which_colors[n]:
|
|
yield cls.colors[i]
|
|
raise StopIteration('Ran out of colors in _Brewer.ColorGenerator')
|
|
|
|
@classmethod
|
|
def InitializeIter(cls, num):
|
|
"""Initializes the color iterator with the given number of colors."""
|
|
cls.color_iter = cls.ColorGenerator(num)
|
|
|
|
@classmethod
|
|
def ClearIter(cls):
|
|
"""Sets the color iterator to None."""
|
|
cls.color_iter = None
|
|
|
|
@classmethod
|
|
def GetIter(cls):
|
|
"""Gets the color iterator."""
|
|
if cls.color_iter is None:
|
|
cls.InitializeIter(7)
|
|
|
|
return cls.color_iter
|
|
|
|
|
|
def PrePlot(num=None, rows=None, cols=None):
    """Takes hints about what's coming.

    If num is given, the color iterator is initialized for that many
    lines.  If rows and/or cols is given, the figure is resized for
    known grid shapes and, for grids larger than 1x1, the first subplot
    is selected and the grid shape is stored in the module globals
    SUBPLOT_ROWS and SUBPLOT_COLS.

    num: number of lines that will be plotted
    rows: number of rows of subplots
    cols: number of columns of subplots
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        _Brewer.InitializeIter(num)

    # without any layout hint there is nothing more to do
    if rows is None and cols is None:
        return

    # at least one dimension was given; the other defaults to 1
    if rows is None:
        rows = 1
    if cols is None:
        cols = 1

    # figure size in inches for the grid shapes we know about
    size_map = {
        (1, 1): (8, 6),
        (1, 2): (14, 6),
        (1, 3): (14, 6),
        (2, 2): (10, 10),
        (2, 3): (16, 10),
        (3, 1): (8, 10),
    }

    shape = (rows, cols)
    if shape in size_map:
        width, height = size_map[shape]
        pyplot.gcf().set_size_inches(width, height)

    # create the first subplot and remember the grid for SubPlot
    if rows > 1 or cols > 1:
        pyplot.subplot(rows, cols, 1)
        SUBPLOT_ROWS = rows
        SUBPLOT_COLS = cols
|
|
|
|
|
|
def SubPlot(plot_number, rows=None, cols=None):
    """Configures the number of subplots and changes the current plot.

    A missing (or falsy) rows/cols falls back to the module globals
    SUBPLOT_ROWS/SUBPLOT_COLS, which PrePlot assigns when it sets up a
    grid larger than 1x1.

    rows: int
    cols: int
    plot_number: int
    """
    if not rows:
        rows = SUBPLOT_ROWS
    if not cols:
        cols = SUBPLOT_COLS
    pyplot.subplot(rows, cols, plot_number)
|
|
|
|
|
|
def _Underride(d, **options):
|
|
"""Add key-value pairs to d only if key is not in d.
|
|
|
|
If d is None, create a new dictionary.
|
|
|
|
d: dictionary
|
|
options: keyword args to add to d
|
|
"""
|
|
if d is None:
|
|
d = {}
|
|
|
|
for key, val in options.items():
|
|
d.setdefault(key, val)
|
|
|
|
return d
|
|
|
|
|
|
def Clf():
    """Clears the figure and any hints that have been set.

    Resets the color iterator and the remembered legend location (LOC),
    clears the current figure, and sets its size to 8x6 inches.
    """
    global LOC
    _Brewer.ClearIter()
    LOC = None
    pyplot.clf()
    pyplot.gcf().set_size_inches(8, 6)
|
|
|
|
|
|
def Figure(**options):
    """Calls pyplot.figure with the given options.

    options: keyword args passed to pyplot.figure; figsize defaults
             to (6, 8) when not provided
    """
    settings = _Underride(options, figsize=(6, 8))
    pyplot.figure(**settings)
|
|
|
|
|
|
def _UnderrideColor(options):
|
|
if 'color' in options:
|
|
return options
|
|
|
|
color_iter = _Brewer.GetIter()
|
|
|
|
if color_iter:
|
|
try:
|
|
options['color'] = next(color_iter)
|
|
except StopIteration:
|
|
# TODO: reconsider whether this should warn
|
|
# warnings.warn('Warning: Brewer ran out of colors.')
|
|
_Brewer.ClearIter()
|
|
return options
|
|
|
|
|
|
def Plot(obj, ys=None, style='', **options):
    """Plots a line.

    When ys is omitted, the data are taken from obj.Render() if obj has
    that method, and from the index/values if obj is a pandas Series;
    otherwise obj alone is handed to pyplot.plot.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=3, alpha=0.8,
                         label=getattr(obj, 'label', '_nolegend_'))

    xs = obj
    if ys is None:
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pandas.Series):
            xs, ys = obj.index, obj.values

    # ys may still be None here: plot obj on its own
    if ys is None:
        plot_args = (xs, style)
    else:
        plot_args = (xs, ys, style)
    pyplot.plot(*plot_args, **options)
|
|
|
|
|
|
def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the region between two curves.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values, or None to fill down to y=0
      where: sequence of boolean
      options: keyword args passed to pyplot.fill_between
    """
    # pyplot.fill_between documents 0 as the default for y2; forwarding
    # None would hand it a non-numeric bound
    if y2 is None:
        y2 = 0

    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=0, alpha=0.5)
    pyplot.fill_between(xs, y1, y2, where, **options)
|
|
|
|
|
|
def Bar(xs, ys, **options):
    """Makes a bar plot.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      options: keyword args passed to pyplot.bar
    """
    bar_options = _Underride(_UnderrideColor(options),
                             linewidth=0, alpha=0.6)
    pyplot.bar(xs, ys, **bar_options)
|
|
|
|
|
|
def Scatter(xs, ys=None, **options):
    """Makes a scatter plot.

    If ys is omitted and xs is a pandas Series, the Series index is
    used for x and its values for y.

    xs: x values
    ys: y values
    options: options passed to pyplot.scatter
    """
    options = _Underride(options, color='blue', alpha=0.2,
                         s=30, edgecolors='none')

    series_only = ys is None and isinstance(xs, pandas.Series)
    if series_only:
        xs, ys = xs.index, xs.values

    pyplot.scatter(xs, ys, **options)
|
|
|
|
|
|
def HexBin(xs, ys, **options):
    """Makes a hexagonal binning plot.

    xs: x values
    ys: y values
    options: options passed to pyplot.hexbin; cmap defaults to Blues
    """
    hexbin_options = _Underride(options, cmap=matplotlib.cm.Blues)
    pyplot.hexbin(xs, ys, **hexbin_options)
|
|
|
|
|
|
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    The options 'low', 'high' and 'n' (default 101) are removed from
    options and forwarded to pdf.Render; the rest go to Plot.

    Args:
      pdf: Pdf, Pmf, or Hist object
      options: keyword args passed to pyplot.plot
    """
    low = options.pop('low', None)
    high = options.pop('high', None)
    n = options.pop('n', 101)

    xs, ps = pdf.Render(low=low, high=high, n=n)
    Plot(xs, ps, **_Underride(options, label=pdf.label))
|
|
|
|
|
|
def Pdfs(pdfs, **options):
    """Plots each PDF in a sequence with the same options.

    Options are passed along for all PDFs.  If you want different
    options for each pdf, make multiple calls to Pdf.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to pyplot.plot
    """
    for one_pdf in pdfs:
        Pdf(one_pdf, **options)
|
|
|
|
|
|
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist.  If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    The align option accepts 'center' (the default), 'left' or 'right';
    'left' and 'right' are translated to pyplot.bar's align='edge',
    with a negated width for 'right'.

    Args:
      hist: Hist or Pmf object
      options: keyword args passed to pyplot.bar
    """
    xs, ys = hist.Render()

    if 'width' not in options:
        try:
            # find the minimum distance between adjacent values
            gaps = np.diff(xs)
            # with fewer than two values there is no gap to measure;
            # leave width unset so pyplot.bar applies its own default
            if gaps.size:
                options['width'] = 0.9 * gaps.min()
        except TypeError:
            warnings.warn("Hist: Can't compute bar width automatically. "
                          "Check for non-numeric types in Hist. "
                          "Or try providing width option.")

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        options['align'] = 'edge'
        # a negative width makes align='edge' anchor the right edge;
        # 0.8 is pyplot.bar's documented default width, used when no
        # width could be computed above
        options['width'] = -options.get('width', 0.8)

    Bar(xs, ys, **options)
|
|
|
|
|
|
def Hists(hists, **options):
    """Plots each histogram in a sequence as a bar plot.

    Options are passed along for all histograms.  If you want different
    options for each one, make multiple calls to Hist.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed along to Hist
    """
    for one_hist in hists:
        Hist(one_hist, **options)
|
|
|
|
|
|
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a line.

    The values are drawn as a step function: each value gets a
    horizontal segment of the given width, and the line returns to
    zero wherever consecutive values are further apart than width.

    The options 'width' and 'align' ('center' by default, or 'left' /
    'right') are consumed here; the rest are passed to Plot.

    Args:
      pmf: Hist or Pmf object
      options: keyword args passed to pyplot.plot
    """
    xs, ys = pmf.Render()

    width = options.pop('width', None)
    if width is None:
        try:
            gaps = np.diff(xs)
            # a single value has no neighbor to measure against; fall
            # back to a unit-wide step instead of failing on min()
            width = gaps.min() if gaps.size else 1
        except TypeError:
            # width is left as None in this case, so the arithmetic
            # below will likely fail too unless the values are numeric
            warnings.warn("Pmf: Can't compute bar width automatically. "
                          "Check for non-numeric types in Pmf. "
                          "Or try providing width option.")

    points = []

    # nan makes the gap test below False for the first value
    lastx = np.nan
    lasty = 0
    for x, y in zip(xs, ys):
        # gap since the previous step: drop to zero across it
        if (x - lastx) > 1e-5:
            points.append((lastx, 0))
            points.append((x, 0))

        points.append((x, lasty))
        points.append((x, y))
        points.append((x+width, y))

        lastx = x + width
        lasty = y
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    # steps are built left-aligned; shift them for the other alignments
    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width/2.0
    if align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.label)
    Plot(pxs, pys, **options)
|
|
|
|
|
|
def Pmfs(pmfs, **options):
    """Plots each PMF in a sequence with the same options.

    Options are passed along for all PMFs.  If you want different
    options for each pmf, make multiple calls to Pmf.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to pyplot.plot
    """
    for one_pmf in pmfs:
        Pmf(one_pmf, **options)
|
|
|
|
|
|
def Diff(t):
    """Compute the differences between adjacent elements in a sequence.

    Args:
        t: sequence (or any iterable) of numbers

    Returns:
        list of differences (length one less than t; empty if t has
        fewer than two elements)
    """
    # materialize once so iterators and generators work as well
    values = list(t)
    return [after - before for before, after in zip(values, values[1:])]
|
|
|
|
|
|
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    The options 'xscale' and 'yscale' are removed from options and
    used as the starting values of the returned scale dictionary; the
    chosen transform may then override them.

    Args:
      cdf: Cdf object
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to pyplot.plot

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    # exponential and pareto are plotted as complementary CDFs
    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0-p for p in ps]

    if transform == 'weibull':
        # drop the last point before taking log(1-p); presumably its
        # probability is 1, which would make the log undefined
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0-p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # drop the first point before taking -log(p)
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale
|
|
|
|
|
|
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots each CDF in a sequence with the same settings.

    The scale dictionaries returned by the individual Cdf calls are
    not collected; this function returns None.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to pyplot.plot
    """
    for one_cdf in cdfs:
        Cdf(one_cdf, complement, transform, **options)
|
|
|
|
|
|
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot.

    The sorted distinct x and y values of the keys form a grid; grid
    points missing from the map are given the value 0.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolor and/or pyplot.contour
    """
    try:
        d = obj.GetDict()
    except AttributeError:
        # no GetDict: treat obj itself as the map
        d = obj

    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    key_xs, key_ys = zip(*d.keys())
    xs = sorted(set(key_xs))
    ys = sorted(set(key_ys))

    def lookup(x, y):
        return d.get((x, y), 0)

    X, Y = np.meshgrid(xs, ys)
    Z = np.vectorize(lookup)(X, Y)

    # show plain tick values on the x axis instead of an offset
    plain_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(plain_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        contour_set = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
    if imshow:
        bounds = xs[0], xs[-1], ys[0], ys[-1]
        pyplot.imshow(Z, extent=bounds, **options)
|
|
|
|
|
|
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor plot.

    xs: x coordinates, expanded to a grid with np.meshgrid
    ys: y coordinates, expanded to a grid with np.meshgrid
    zs: values to plot over the grid, passed to pyplot unchanged
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolor and/or pyplot.contour
    """
    _Underride(options, linewidth=3, cmap=matplotlib.cm.Blues)

    X, Y = np.meshgrid(xs, ys)

    # show plain tick values on the x axis instead of an offset
    plain_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(plain_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, zs, **options)

    if contour:
        contour_set = pyplot.contour(X, Y, zs, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
|
|
|
|
|
|
def Text(x, y, s, **options):
    """Puts text in a figure.

    By default the text is 16 points and anchored at its top-left
    corner.

    x: number
    y: number
    s: string
    options: keyword args passed to pyplot.text
    """
    defaults = dict(fontsize=16,
                    verticalalignment='top',
                    horizontalalignment='left')
    text_options = _Underride(options, **defaults)
    pyplot.text(x, y, s, **text_options)
|
|
|
|
|
|
# whether Config draws a legend; updated by Config's 'legend' option
LEGEND = True
# last legend location given to Config; reset to None by Clf
LOC = None


def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding pyplot functions.

    The recognized pyplot options are title, xlabel, ylabel, xscale,
    yscale, xticks, yticks, axis, xlim and ylim.

    legend: boolean, whether to draw a legend; the value is remembered
            in the module global LEGEND for later calls
    loc: legend location passed to pyplot.legend; remembered in the
         module global LOC (only read while the legend is enabled)
    """
    global LEGEND, LOC

    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    for name in names:
        if name in options:
            getattr(pyplot, name)(options[name])

    LEGEND = options.get('legend', LEGEND)

    if LEGEND:
        # pyplot.legend accepts textual location specs directly, so no
        # translation to numeric codes is needed
        LOC = options.get('loc', LOC)
        pyplot.legend(loc=LOC)
|
|
|
|
|
|
def Show(**options):
    """Shows the plot.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions;
             'clf' (default True) controls whether the figure is
             cleared after it has been shown
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    pyplot.show()
    if not clear_after:
        return
    Clf()
|
|
|
|
|
|
def Plotly(**options):
    """Passes the current figure to plotly's plot_mpl.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions;
             'clf' (default True) controls whether the figure is
             cleared afterwards

    returns: whatever plotly.plot_mpl returns (presumably the URL of
             the uploaded plot)
    """
    clear_after = options.pop('clf', True)
    Config(**options)

    # imported here so that plotly is only required when this is called
    import plotly.plotly as plotly_service
    result = plotly_service.plot_mpl(pyplot.gcf())

    if clear_after:
        Clf()
    return result
|
|
|
|
|
|
def Save(root=None, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see Config.

    If 'plotly' appears in formats, the figure is also sent through
    Plotly.  The caller's formats list is not modified.

    Args:
      root: string filename root; if empty or None, no files are written
      formats: list of string formats (default: pdf and eps)
      options: keyword args used to invoke various pyplot functions;
               'clf' (default True) controls whether the figure is
               cleared afterwards
    """
    clear_after = options.pop('clf', True)
    Config(**options)

    if formats is None:
        formats = ['pdf', 'eps']
    else:
        # work on a copy so removing 'plotly' does not alter the
        # caller's list
        formats = list(formats)

    # an explicit membership test (rather than catching ValueError from
    # remove) keeps errors raised inside Plotly from being swallowed
    if 'plotly' in formats:
        formats.remove('plotly')
        Plotly(clf=False)

    if root:
        for fmt in formats:
            SaveFormat(root, fmt)
    if clear_after:
        Clf()
|
|
|
|
|
|
def SaveFormat(root, fmt='eps'):
    """Writes the current figure to a file in the given format.

    The file is named root.fmt and saved at 300 dpi; the name is
    printed before writing.

    Args:
      root: string filename root
      fmt: string format
    """
    target = '{0}.{1}'.format(root, fmt)
    print('Writing', target)
    pyplot.savefig(target, format=fmt, dpi=300)
|
|
|
|
|
|
# provide aliases for calling functions with lower-case names

# figure setup, configuration and output
preplot = PrePlot
subplot = SubPlot
figure = Figure
clf = Clf
config = Config
show = Show
save = Save

# general-purpose plotting
plot = Plot
scatter = Scatter
text = Text
contour = Contour
pcolor = Pcolor

# distributions
hist = Hist
hists = Hists
pmf = Pmf
pmfs = Pmfs
cdf = Cdf
cdfs = Cdfs

# utilities
diff = Diff
|
|
|
|
|
|
def main():
    """Prints the colors of the seven-color Brewer sequence, one per line."""
    for shade in _Brewer.ColorGenerator(7):
        print(shade)


if __name__ == '__main__':
    main()
|