717 lines
18 KiB
Python
Raw Normal View History

"""This file contains code for use with "Think Stats",
by Allen B. Downey, available from greenteapress.com
Copyright 2014 Allen B. Downey
License: GNU GPLv3 http://www.gnu.org/licenses/gpl.html
"""
from __future__ import print_function
import math
import matplotlib
import matplotlib.pyplot as pyplot
import numpy as np
import pandas
import warnings
# customize some matplotlib attributes
#matplotlib.rc('figure', figsize=(4, 3))
#matplotlib.rc('font', size=14.0)
#matplotlib.rc('axes', labelsize=22.0, titlesize=22.0)
#matplotlib.rc('legend', fontsize=20.0)
#matplotlib.rc('xtick.major', size=6.0)
#matplotlib.rc('xtick.minor', size=3.0)
#matplotlib.rc('ytick.major', size=6.0)
#matplotlib.rc('ytick.minor', size=3.0)
class _Brewer(object):
"""Encapsulates a nice sequence of colors.
Shades of blue that look good in color and can be distinguished
in grayscale (up to a point).
Borrowed from http://colorbrewer2.org/
"""
color_iter = None
colors = ['#081D58',
'#253494',
'#225EA8',
'#1D91C0',
'#41B6C4',
'#7FCDBB',
'#C7E9B4',
'#EDF8B1',
'#FFFFD9']
# lists that indicate which colors to use depending on how many are used
which_colors = [[],
[1],
[1, 3],
[0, 2, 4],
[0, 2, 4, 6],
[0, 2, 3, 5, 6],
[0, 2, 3, 4, 5, 6],
[0, 1, 2, 3, 4, 5, 6],
]
@classmethod
def Colors(cls):
"""Returns the list of colors.
"""
return cls.colors
@classmethod
def ColorGenerator(cls, n):
"""Returns an iterator of color strings.
n: how many colors will be used
"""
for i in cls.which_colors[n]:
yield cls.colors[i]
raise StopIteration('Ran out of colors in _Brewer.ColorGenerator')
@classmethod
def InitializeIter(cls, num):
"""Initializes the color iterator with the given number of colors."""
cls.color_iter = cls.ColorGenerator(num)
@classmethod
def ClearIter(cls):
"""Sets the color iterator to None."""
cls.color_iter = None
@classmethod
def GetIter(cls):
"""Gets the color iterator."""
if cls.color_iter is None:
cls.InitializeIter(7)
return cls.color_iter
def PrePlot(num=None, rows=None, cols=None):
    """Records hints about the plots that are about to be drawn.

    num: number of lines that will be plotted
    rows: number of rows of subplots
    cols: number of columns of subplots

    If only one of rows/cols is given, the other defaults to 1.  When
    neither is given, only the color iterator is (optionally) set up.
    """
    global SUBPLOT_ROWS, SUBPLOT_COLS

    if num:
        _Brewer.InitializeIter(num)

    if rows is None:
        if cols is None:
            return
        rows = 1
    elif cols is None:
        cols = 1

    # figure size in inches for the grid shapes that have a preset
    figure_sizes = {
        (1, 1): (8, 6),
        (1, 2): (14, 6),
        (1, 3): (14, 6),
        (2, 2): (10, 10),
        (2, 3): (16, 10),
        (3, 1): (8, 10),
    }
    grid = (rows, cols)
    if grid in figure_sizes:
        width, height = figure_sizes[grid]
        pyplot.gcf().set_size_inches(width, height)

    # a multi-panel grid: select the first panel and remember the grid
    # shape so SubPlot can reuse it
    if rows > 1 or cols > 1:
        pyplot.subplot(rows, cols, 1)
        SUBPLOT_ROWS = rows
        SUBPLOT_COLS = cols
def SubPlot(plot_number, rows=None, cols=None):
    """Configures the number of subplots and changes the current plot.

    plot_number: int, index of the subplot to make current
    rows: int, number of rows; defaults to the value recorded by PrePlot
    cols: int, number of columns; defaults to the value recorded by PrePlot

    Raises:
        ValueError: if rows or cols is omitted and no grid shape has
            been recorded by an earlier PrePlot call.
    """
    # SUBPLOT_ROWS / SUBPLOT_COLS only exist after PrePlot has set up a
    # multi-panel grid, so look them up defensively.
    if not rows:
        rows = globals().get('SUBPLOT_ROWS')
    if not cols:
        cols = globals().get('SUBPLOT_COLS')
    if not rows or not cols:
        raise ValueError('SubPlot: rows and cols must be provided '
                         'or set by a previous call to PrePlot')
    pyplot.subplot(rows, cols, plot_number)
def _Underride(d, **options):
"""Add key-value pairs to d only if key is not in d.
If d is None, create a new dictionary.
d: dictionary
options: keyword args to add to d
"""
if d is None:
d = {}
for key, val in options.items():
d.setdefault(key, val)
return d
def Clf():
    """Resets the current figure and forgets any plotting hints.

    Clears the remembered legend location and the color iterator,
    clears the current figure, and sets its size to 8 x 6 inches.
    """
    global LOC
    _Brewer.ClearIter()
    LOC = None
    pyplot.clf()
    pyplot.gcf().set_size_inches(8, 6)
def Figure(**options):
    """Calls pyplot.figure with the given options.

    options: keyword args passed to pyplot.figure; figsize defaults
             to (6, 8) when not provided
    """
    options.setdefault('figsize', (6, 8))
    pyplot.figure(**options)
def _UnderrideColor(options):
if 'color' in options:
return options
color_iter = _Brewer.GetIter()
if color_iter:
try:
options['color'] = next(color_iter)
except StopIteration:
# TODO: reconsider whether this should warn
# warnings.warn('Warning: Brewer ran out of colors.')
_Brewer.ClearIter()
return options
def Plot(obj, ys=None, style='', **options):
    """Draws a line with pyplot.plot.

    Args:
      obj: sequence of x values, or Series, or anything with Render()
      ys: sequence of y values
      style: style string passed along to pyplot.plot
      options: keyword args passed to pyplot.plot

    Defaults that are filled in unless provided: color (next Brewer
    color), linewidth=3, alpha=0.8, and label (obj.label if it exists,
    otherwise '_nolegend_').
    """
    options = _UnderrideColor(options)
    options = _Underride(options, linewidth=3, alpha=0.8,
                         label=getattr(obj, 'label', '_nolegend_'))

    xs = obj
    if ys is None:
        # no explicit ys: try to get both coordinates out of obj
        if hasattr(obj, 'Render'):
            xs, ys = obj.Render()
        if isinstance(obj, pandas.Series):
            xs, ys = obj.index, obj.values

    # with no ys at all, let pyplot treat xs as the only data sequence
    args = (xs, style) if ys is None else (xs, ys, style)
    pyplot.plot(*args, **options)
def FillBetween(xs, y1, y2=None, where=None, **options):
    """Fills the region between two curves with pyplot.fill_between.

    Args:
      xs: sequence of x values
      y1: sequence of y values
      y2: sequence of y values
      where: sequence of boolean
      options: keyword args passed to pyplot.fill_between

    Unless provided, color comes from the Brewer iterator, linewidth
    defaults to 0 and alpha to 0.5.
    """
    settings = _UnderrideColor(options)
    settings.setdefault('linewidth', 0)
    settings.setdefault('alpha', 0.5)
    pyplot.fill_between(xs, y1, y2, where, **settings)
def Bar(xs, ys, **options):
    """Draws a bar chart with pyplot.bar.

    Args:
      xs: sequence of x values
      ys: sequence of y values
      options: keyword args passed to pyplot.bar

    Unless provided, color comes from the Brewer iterator, linewidth
    defaults to 0 and alpha to 0.6.
    """
    settings = _UnderrideColor(options)
    settings.setdefault('linewidth', 0)
    settings.setdefault('alpha', 0.6)
    pyplot.bar(xs, ys, **settings)
def Scatter(xs, ys=None, **options):
    """Draws a scatter plot with pyplot.scatter.

    xs: x values, or a pandas Series when ys is omitted
    ys: y values
    options: options passed to pyplot.scatter

    Defaults filled in unless provided: color='blue', alpha=0.2, s=30,
    edgecolors='none'.
    """
    defaults = dict(color='blue', alpha=0.2, s=30, edgecolors='none')
    options = _Underride(options, **defaults)

    # a lone Series supplies both coordinates: index -> x, values -> y
    series_only = ys is None and isinstance(xs, pandas.Series)
    if series_only:
        xs, ys = xs.index, xs.values

    pyplot.scatter(xs, ys, **options)
def HexBin(xs, ys, **options):
    """Draws a hexagonal binning plot with pyplot.hexbin.

    xs: x values
    ys: y values
    options: options passed to pyplot.hexbin; cmap defaults to
             matplotlib.cm.Blues
    """
    options.setdefault('cmap', matplotlib.cm.Blues)
    pyplot.hexbin(xs, ys, **options)
def Pdf(pdf, **options):
    """Plots a Pdf, Pmf, or Hist as a line.

    Args:
      pdf: Pdf, Pmf, or Hist object; must provide Render(low, high, n)
           and a label attribute
      options: keyword args passed to pyplot.plot, except 'low', 'high'
               and 'n' (default 101), which are removed and forwarded
               to pdf.Render
    """
    render_args = {
        'low': options.pop('low', None),
        'high': options.pop('high', None),
        'n': options.pop('n', 101),
    }
    xs, ps = pdf.Render(**render_args)
    options.setdefault('label', pdf.label)
    Plot(xs, ps, **options)
def Pdfs(pdfs, **options):
    """Plots each PDF in a sequence by calling Pdf on it.

    The same options are used for every PDF.  To give each one its
    own options, call Pdf once per object instead.

    Args:
      pdfs: sequence of PDF objects
      options: keyword args passed to pyplot.plot
    """
    for each_pdf in pdfs:
        Pdf(each_pdf, **options)
def Hist(hist, **options):
    """Plots a Pmf or Hist with a bar plot.

    The default width of the bars is based on the minimum difference
    between values in the Hist.  If that's too small, you can override
    it by providing a width keyword argument, in the same units
    as the values.

    Args:
      hist: Hist or Pmf object; must provide Render() and a label
      options: keyword args passed to pyplot.bar.  'align' may be
               'center' (default), 'left' or 'right'; the last two are
               translated to align='edge', with a negative width for
               'right'.
    """
    xs, ys = hist.Render()

    if 'width' not in options:
        try:
            # find the minimum distance between adjacent values
            gaps = np.diff(xs)
            # with fewer than two values there is no gap to measure;
            # leave width unset so pyplot.bar uses its own default
            if gaps.size:
                options['width'] = 0.9 * gaps.min()
        except TypeError:
            warnings.warn("Hist: Can't compute bar width automatically. "
                          "Check for non-numeric types in Hist. "
                          "Or try providing width option.")

    options = _Underride(options, label=hist.label)
    options = _Underride(options, align='center')
    if options['align'] == 'left':
        options['align'] = 'edge'
    elif options['align'] == 'right':
        options['align'] = 'edge'
        # a negative width makes edge-aligned bars extend to the left;
        # 0.8 is pyplot.bar's default width, used when none was computed
        options['width'] = -options.get('width', 0.8)

    Bar(xs, ys, **options)
def Hists(hists, **options):
    """Plots each histogram in a sequence by calling Hist on it.

    The same options are used for every histogram.  To give each one
    its own options, call Hist once per object instead.

    Args:
      hists: sequence of Hist or Pmf objects
      options: keyword args passed along to Hist
    """
    for each_hist in hists:
        Hist(each_hist, **options)
def Pmf(pmf, **options):
    """Plots a Pmf or Hist as a line (a step function).

    Args:
      pmf: Hist or Pmf object; must provide Render() and a label
      options: keyword args passed to pyplot.plot, except:
          width: width of each step; defaults to the minimum gap
                 between adjacent values, or 1.0 when there are fewer
                 than two values
          align: 'center' (default), 'right', or anything else for
                 left alignment
    """
    xs, ys = pmf.Render()

    width = options.pop('width', None)
    if width is None:
        try:
            gaps = np.diff(xs)
            # a single value has no gaps to measure; use unit width
            width = gaps.min() if gaps.size else 1.0
        except TypeError:
            # NOTE: width stays None here; the step construction below
            # still needs numeric values and a numeric width
            warnings.warn("Pmf: Can't compute bar width automatically. "
                          "Check for non-numeric types in Pmf. "
                          "Or try providing width option.")

    points = []

    # lastx starts as nan so the gap test below is False for the
    # first value
    lastx = np.nan
    lasty = 0
    for x, y in zip(xs, ys):
        # a gap since the end of the previous step: drop to zero
        # across the gap
        if (x - lastx) > 1e-5:
            points.append((lastx, 0))
            points.append((x, 0))

        # vertical move from the previous height, then the flat top
        points.append((x, lasty))
        points.append((x, y))
        points.append((x + width, y))

        lastx = x + width
        lasty = y

    # close the final step back down to zero
    points.append((lastx, 0))
    pxs, pys = zip(*points)

    # steps are built left-aligned; shift them for other alignments
    align = options.pop('align', 'center')
    if align == 'center':
        pxs = np.array(pxs) - width / 2.0
    if align == 'right':
        pxs = np.array(pxs) - width

    options = _Underride(options, label=pmf.label)
    Plot(pxs, pys, **options)
def Pmfs(pmfs, **options):
    """Plots each PMF in a sequence by calling Pmf on it.

    The same options are used for every PMF.  To give each one its
    own options, call Pmf once per object instead.

    Args:
      pmfs: sequence of PMF objects
      options: keyword args passed to pyplot.plot
    """
    for each_pmf in pmfs:
        Pmf(each_pmf, **options)
def Diff(t):
    """Returns the differences between adjacent elements of a sequence.

    Args:
      t: sequence of numbers

    Returns:
      list of differences t[i+1] - t[i] (length one less than t)
    """
    deltas = []
    for i in range(1, len(t)):
        deltas.append(t[i] - t[i - 1])
    return deltas
def Cdf(cdf, complement=False, transform=None, **options):
    """Plots a CDF as a line.

    Args:
      cdf: Cdf object; must provide Render() and a label
      complement: boolean, whether to plot the complementary CDF
      transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
      options: keyword args passed to pyplot.plot; 'xscale' and 'yscale'
               are removed and returned in the scale dictionary instead

    Returns:
      dictionary with the scale options that should be passed to
      Config, Show or Save.
    """
    xs, ps = cdf.Render()
    xs = np.asarray(xs)
    ps = np.asarray(ps)

    scale = dict(xscale='linear', yscale='linear')

    # explicit scale options take the place of the linear defaults
    for s in ['xscale', 'yscale']:
        if s in options:
            scale[s] = options.pop(s)

    if transform == 'exponential':
        complement = True
        scale['yscale'] = 'log'

    if transform == 'pareto':
        complement = True
        scale['yscale'] = 'log'
        scale['xscale'] = 'log'

    if complement:
        ps = [1.0-p for p in ps]

    if transform == 'weibull':
        # drop the last point before taking log(1-p); presumably the
        # final CDF value is 1, where the log is undefined
        xs = np.delete(xs, -1)
        ps = np.delete(ps, -1)
        ps = [-math.log(1.0-p) for p in ps]
        scale['xscale'] = 'log'
        scale['yscale'] = 'log'

    if transform == 'gumbel':
        # drop the first point before taking log(p)
        xs = np.delete(xs, 0)
        ps = np.delete(ps, 0)
        ps = [-math.log(p) for p in ps]
        scale['yscale'] = 'log'

    options = _Underride(options, label=cdf.label)
    Plot(xs, ps, **options)
    return scale
def Cdfs(cdfs, complement=False, transform=None, **options):
    """Plots each CDF in a sequence by calling Cdf on it.

    cdfs: sequence of CDF objects
    complement: boolean, whether to plot the complementary CDF
    transform: string, one of 'exponential', 'pareto', 'weibull', 'gumbel'
    options: keyword args passed to pyplot.plot
    """
    for each_cdf in cdfs:
        Cdf(each_cdf, complement=complement, transform=transform, **options)
def Contour(obj, pcolor=False, contour=True, imshow=False, **options):
    """Makes a contour plot from a map of (x, y) pairs to z values.

    obj: map from (x, y) to z, or object that provides GetDict
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    imshow: boolean, whether to use pyplot.imshow
    options: keyword args passed to pyplot.pcolor and/or pyplot.contour;
             linewidth defaults to 3 and cmap to matplotlib.cm.Blues
    """
    try:
        mapping = obj.GetDict()
    except AttributeError:
        mapping = obj

    options.setdefault('linewidth', 3)
    options.setdefault('cmap', matplotlib.cm.Blues)

    # build the grid from the distinct, sorted x and y coordinates
    raw_xs, raw_ys = zip(*mapping.keys())
    xs = sorted(set(raw_xs))
    ys = sorted(set(raw_ys))
    X, Y = np.meshgrid(xs, ys)

    def lookup(x, y):
        # grid points that are missing from the map count as 0
        return mapping.get((x, y), 0)

    Z = np.vectorize(lookup)(X, Y)

    # show x tick labels as plain values, without an offset
    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, Z, **options)
    if contour:
        contour_set = pyplot.contour(X, Y, Z, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
    if imshow:
        bounds = xs[0], xs[-1], ys[0], ys[-1]
        pyplot.imshow(Z, extent=bounds, **options)
def Pcolor(xs, ys, zs, pcolor=True, contour=False, **options):
    """Makes a pseudocolor and/or contour plot from gridded values.

    xs: x coordinates passed to np.meshgrid
    ys: y coordinates passed to np.meshgrid
    zs: z values plotted over the meshgrid of xs and ys
    pcolor: boolean, whether to make a pseudocolor plot
    contour: boolean, whether to make a contour plot
    options: keyword args passed to pyplot.pcolor and/or pyplot.contour;
             linewidth defaults to 3 and cmap to matplotlib.cm.Blues
    """
    options.setdefault('linewidth', 3)
    options.setdefault('cmap', matplotlib.cm.Blues)

    X, Y = np.meshgrid(xs, ys)

    # show x tick labels as plain values, without an offset
    x_formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
    pyplot.gca().xaxis.set_major_formatter(x_formatter)

    if pcolor:
        pyplot.pcolormesh(X, Y, zs, **options)

    if contour:
        contour_set = pyplot.contour(X, Y, zs, **options)
        pyplot.clabel(contour_set, inline=1, fontsize=10)
def Text(x, y, s, **options):
    """Writes a string into the figure with pyplot.text.

    x: number
    y: number
    s: string
    options: keyword args passed to pyplot.text; fontsize defaults to 16,
             verticalalignment to 'top', horizontalalignment to 'left'
    """
    defaults = dict(fontsize=16,
                    verticalalignment='top',
                    horizontalalignment='left')
    options = _Underride(options, **defaults)
    pyplot.text(x, y, s, **options)
# whether Config draws a legend; updated by the 'legend' option and
# remembered across calls
LEGEND = True
# legend location passed to pyplot.legend; updated by the 'loc' option
# and remembered across calls
LOC = None


def Config(**options):
    """Configures the plot.

    Pulls options out of the option dictionary and passes them to
    the corresponding pyplot functions.

    options: keyword args; recognized names are title, xlabel, ylabel,
             xscale, yscale, xticks, yticks, axis, xlim, ylim, plus
             'legend' (whether to draw a legend) and 'loc' (legend
             location).  'legend' and 'loc' persist in the module-level
             LEGEND and LOC until changed.
    """
    global LEGEND, LOC

    names = ['title', 'xlabel', 'ylabel', 'xscale', 'yscale',
             'xticks', 'yticks', 'axis', 'xlim', 'ylim']

    # each option name matches the pyplot function that applies it
    for name in names:
        if name in options:
            getattr(pyplot, name)(options[name])

    LEGEND = options.get('legend', LEGEND)

    if LEGEND:
        # the location is passed through unchanged; it is only updated
        # when a legend is actually drawn
        LOC = options.get('loc', LOC)
        pyplot.legend(loc=LOC)
def Show(**options):
    """Configures the plot, displays it, and optionally clears the figure.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions;
             'clf' (default True) controls whether Clf is called
             after the plot is shown
    """
    clear_after = options.pop('clf', True)
    Config(**options)
    pyplot.show()
    if not clear_after:
        return
    Clf()
def Plotly(**options):
    """Configures the plot and sends the current figure to plotly.

    For options, see Config.

    options: keyword args used to invoke various pyplot functions;
             'clf' (default True) controls whether Clf is called
             afterwards

    returns: the value returned by plotly.plot_mpl
    """
    clear_after = options.pop('clf', True)
    Config(**options)

    # imported here so plotly is only needed when this function is used
    import plotly.plotly as plotly
    result = plotly.plot_mpl(pyplot.gcf())

    if clear_after:
        Clf()
    return result
def Save(root=None, formats=None, **options):
    """Saves the plot in the given formats and clears the figure.

    For options, see Config.

    Args:
      root: string filename root; nothing is written to disk if empty
      formats: list of string formats (default ['pdf', 'eps']), or a
               single format string.  The special format 'plotly' sends
               the figure to Plotly instead of writing a file.
      options: keyword args used to invoke various pyplot functions;
               'clf' (default True) controls whether Clf is called
               at the end
    """
    clf = options.pop('clf', True)
    Config(**options)

    if formats is None:
        formats = ['pdf', 'eps']
    elif isinstance(formats, str):
        formats = [formats]
    else:
        # work on a copy so the caller's sequence is never modified
        formats = list(formats)

    # 'plotly' is handled separately from the file formats; checking
    # membership first keeps errors raised inside Plotly visible
    if 'plotly' in formats:
        formats.remove('plotly')
        Plotly(clf=False)

    if root:
        for fmt in formats:
            SaveFormat(root, fmt)
    if clf:
        Clf()
def SaveFormat(root, fmt='eps'):
    """Saves the current figure as <root>.<fmt> at 300 dpi.

    Args:
      root: string filename root
      fmt: string format, also used as the file extension
    """
    target = '%s.%s' % (root, fmt)
    print('Writing', target)
    pyplot.savefig(target, format=fmt, dpi=300)
# provide aliases for calling functions with lower-case names
# (each alias is bound to the same function object as the capitalized name)
preplot = PrePlot
subplot = SubPlot
clf = Clf
figure = Figure
plot = Plot
text = Text
scatter = Scatter
pmf = Pmf
pmfs = Pmfs
hist = Hist
hists = Hists
diff = Diff
cdf = Cdf
cdfs = Cdfs
contour = Contour
pcolor = Pcolor
config = Config
show = Show
save = Save
def main():
    """Prints the colors produced by a 7-color Brewer generator."""
    for shade in _Brewer.ColorGenerator(7):
        print(shade)


if __name__ == '__main__':
    main()