mirror of
https://github.com/donnemartin/data-science-ipython-notebooks.git
synced 2024-03-22 13:30:56 +08:00
44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
import numpy as np
|
|
import json
|
|
|
|
from keras.utils.data_utils import get_file
|
|
from keras import backend as K
|
|
|
|
# Remote location of the ImageNet class-index JSON that decode_predictions
# downloads (via get_file) the first time it needs it.
CLASS_INDEX_PATH = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'

# Module-level cache for the parsed class index; stays None until
# decode_predictions loads it.
CLASS_INDEX = None
|
|
|
|
|
|
def preprocess_input(x, dim_ordering='default'):
    """Prepare a batch of RGB images for the ImageNet models (Caffe style).

    The channels are reordered from RGB to BGR, and then the per-channel
    mean pixel (B=103.939, G=116.779, R=123.68) is subtracted.

    # Arguments
        x: 4D float numpy array of images in RGB channel order, laid out as
            (samples, channels, rows, cols) for 'th' or
            (samples, rows, cols, channels) for 'tf'.
        dim_ordering: 'th', 'tf', or 'default' to use
            `K.image_dim_ordering()`.

    # Returns
        The preprocessed array, which is a BGR-ordered view of `x`.

    # Raises
        ValueError: if `dim_ordering` is not 'tf' or 'th' once 'default'
            has been resolved.

    Note: the subtraction is performed in place on a view of `x`, so the
    caller's array is modified as well.
    """
    if dim_ordering == 'default':
        dim_ordering = K.image_dim_ordering()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if dim_ordering not in {'tf', 'th'}:
        raise ValueError('Unknown dim_ordering: ' + str(dim_ordering))

    # Flip 'RGB'->'BGR' *before* subtracting: the constants are the B, G, R
    # means, so they must be applied to the channels in BGR order for each
    # mean to land on its matching channel.
    if dim_ordering == 'th':
        x = x[:, ::-1, :, :]
        x[:, 0, :, :] -= 103.939
        x[:, 1, :, :] -= 116.779
        x[:, 2, :, :] -= 123.68
    else:
        x = x[:, :, :, ::-1]
        x[:, :, :, 0] -= 103.939
        x[:, :, :, 1] -= 116.779
        x[:, :, :, 2] -= 123.68
    return x
|
|
|
|
|
|
def decode_predictions(preds):
    """Map each row of ImageNet class scores to its top-1 class entry.

    On first use, the class-index JSON is fetched from CLASS_INDEX_PATH with
    `get_file` (cache_subdir='models') and memoized in the module-level
    CLASS_INDEX, so later calls skip the download and parse.

    # Arguments
        preds: 2D array-like of shape (samples, 1000) holding one score per
            ImageNet class.

    # Returns
        A list with one element per sample. Each element is the CLASS_INDEX
        entry, exactly as loaded from the JSON, for that row's
        highest-scoring class.

    # Raises
        ValueError: if `preds` is not 2D with exactly 1000 columns.
    """
    global CLASS_INDEX
    preds = np.asarray(preds)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if preds.ndim != 2 or preds.shape[1] != 1000:
        raise ValueError(
            '`decode_predictions` expects a batch of predictions of shape '
            '(samples, 1000); got array with shape ' + str(preds.shape))
    if CLASS_INDEX is None:
        fpath = get_file('imagenet_class_index.json',
                         CLASS_INDEX_PATH,
                         cache_subdir='models')
        # Close the file deterministically instead of leaking the handle.
        with open(fpath) as f:
            CLASS_INDEX = json.load(f)
    # The JSON is keyed by the class index rendered as a string.
    return [CLASS_INDEX[str(i)] for i in np.argmax(preds, axis=-1)]
|