-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_model.py
More file actions
69 lines (54 loc) · 2.37 KB
/
predict_model.py
File metadata and controls
69 lines (54 loc) · 2.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
# Command-line interface.  Parsed before tensorflow/keras are imported because
# --gpu has to be exported via CUDA_VISIBLE_DEVICES ahead of those imports.
parser = argparse.ArgumentParser(description='Run a trained CNN model on the validation image patches.')
# JSON file holding the script configuration (read via read_config).
parser.add_argument('-c', '--config', help="JSON with script configuration", default='config.json')
# Prefix of the saved model files (<name>_architecture.json / <name>_weights.h5).
parser.add_argument('-m', '--model', help="input CNN model name (saved in JSON and h5 files)", default='cnn_model')
# GPU index as a string; copied verbatim into CUDA_VISIBLE_DEVICES.
parser.add_argument('-g', '--gpu', help="Which GPU index", default='0')
args = parser.parse_args()
import os
# Select the Keras backend and restrict the visible GPU.  These are set here,
# before the tensorflow/keras imports that follow, since those libraries are
# expected to read them at import / device-initialisation time.
os.environ.update({
    'KERAS_BACKEND': "tensorflow",
    "CUDA_VISIBLE_DEVICES": args.gpu,
})
import tensorflow as tf
import keras
import sys

# The script targets the Keras 2.x API; refuse to run on anything else.
# Compare the whole major component, not just the first character of the
# version string.
if keras.__version__.split('.')[0] != '2':
    # sys.exit(msg) writes msg to stderr and exits with status 1; the previous
    # quit() exited with status 0 (reporting success) and only exists when the
    # `site` module has been loaded.
    sys.exit('Please use the newest Keras 2.x.x API with the Tensorflow backend')
# Image tensors are laid out as (rows, cols, channels).
keras.backend.set_image_data_format('channels_last')
# Legacy equivalent of the call above.  NOTE(review): this function is absent
# from newer Keras releases (AttributeError there), so only call it if present.
if hasattr(keras.backend, 'set_image_dim_ordering'):
    keras.backend.set_image_dim_ordering('tf')
import numpy as np
np.random.seed(seed=2017)  # fixed seed of NumPy's global RNG so runs are reproducible
from keras.preprocessing.image import ImageDataGenerator
from keras.models import model_from_json
from keras.optimizers import SGD
from keras.utils import np_utils
from keras.callbacks import TensorBoard
from os.path import exists, isfile, join
import json
from utils import read_config, get_patch_size, count_events, shuffle_in_place, RecordHistory
def load_model(name):
    """Rebuild a Keras model from disk and load its trained weights.

    The architecture is parsed from ``<name>_architecture.json`` with
    ``model_from_json``; the weights are then read from ``<name>_weights.h5``.

    :param name: path prefix shared by the architecture and weights files.
    :return: the model object with its weights loaded.
    """
    arch_path = name + '_architecture.json'
    weights_path = name + '_weights.h5'
    with open(arch_path) as arch_file:
        architecture = arch_file.read()
    net = model_from_json(architecture)
    net.load_weights(weights_path)
    return net
############################ configuration ###################################
# Single-argument print(...) prints the same text on Python 2 and is valid
# syntax on Python 3 (the bare print statement is a SyntaxError there).
print('Reading configuration...')
config = read_config(args.config)
# Prefix of the saved model files (<name>_architecture.json / <name>_weights.h5).
cfg_name = args.model
# Patch width/depth from the 'prepare_data_em_track' section of the config.
PATCH_SIZE_W = config['prepare_data_em_track']['patch_size_w']
PATCH_SIZE_D = config['prepare_data_em_track']['patch_size_d']
img_rows, img_cols = PATCH_SIZE_W, PATCH_SIZE_D
# Root directory of the validation images; the data-reading step expects one
# sub-directory per class inside it.
input_dir = config['training_on_patches']['validation_dir']
############################## Read data #####################################
from PIL import Image
def load_images(filename):
    """Read a single image file and return its pixel data as a numpy array.

    The image is opened in a ``with`` block so the underlying file handle is
    released as soon as the pixels have been copied, instead of whenever the
    ``Image`` object happens to be garbage-collected.  This matters when the
    function is called once per file over a whole directory.

    :param filename: path of the image file.
    :return: ``numpy.ndarray`` with the image's pixel values.
    """
    with Image.open(filename) as img:
        # np.array() forces PIL to read the pixel data before the file closes.
        return np.array(img)
# Image classes; each one is a sub-directory of input_dir.
classes = ['track', 'shower', 'michel', 'none']
imagedb = {}  # class name -> list of image arrays returned by load_images
for c in classes:
    print(c)
    # join() works whether or not validation_dir ends with a path separator.
    class_dir = join(input_dir, c)
    # sorted(): os.listdir order is arbitrary, so sort for a reproducible order.
    # isfile(): skip sub-directories, which Image.open cannot read.
    images = [load_images(join(class_dir, fname))
              for fname in sorted(os.listdir(class_dir))
              if isfile(join(class_dir, fname))]
    imagedb[c] = images
# Debug output: show one sample shower image, if there are enough of them.
if len(imagedb['shower']) > 1:
    print(imagedb['shower'][1])
################################ Load model #################################
#print 'Import CNN model...'
#with tf.device('/gpu:' + args.gpu):
# model = load_model(cfg_name)