Comment voulez-vous faire TensorFlow + Keras rapide avec un TFRecord dataset?
Ce qui est un exemple de comment utiliser un TensorFlow TFRecord avec un Keras Modèle et de la tf.session.run() tout en gardant la base de données dans tenseurs w/file d'attente des coureurs?
Ci-dessous est un extrait de code qui fonctionne mais il faut les améliorations suivantes:
- Utiliser le Le modèle de l'API
- spécifier une Entrée()
- Charger un jeu de données à partir d'un TFRecord
- Courir à travers un jeu de données en parallèle (comme avec un queuerunner)
Voici l'extrait de code, il y a plusieurs TODO lignes indiquant ce qui est nécessaire:
from keras.models import Model
import tensorflow as tf
from keras import backend as K
from keras.layers import Dense, Input
from keras.objectives import categorical_crossentropy
from tensorflow.examples.tutorials.mnist import input_data
sess = tf.Session()
K.set_session(sess)
# Can this be done more efficiently than placeholders w/TFRecords?
img = tf.placeholder(tf.float32, shape=(None, 784))
labels = tf.placeholder(tf.float32, shape=(None, 10))
# TODO: Use Input()
x = Dense(128, activation='relu')(img)
x = Dense(128, activation='relu')(x)
preds = Dense(10, activation='softmax')(x)
# TODO: Construct model = Model(input=inputs, output=preds)
loss = tf.reduce_mean(categorical_crossentropy(labels, preds))
# TODO: handle TFRecord data, is it the same?
mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
sess.run(tf.global_variables_initializer())
# TODO remove default, add queuerunner
with sess.as_default():
for i in range(1000):
batch = mnist_data.train.next_batch(50)
train_step.run(feed_dict={img: batch[0],
labels: batch[1]})
print(loss.eval(feed_dict={img: mnist_data.test.images,
labels: mnist_data.test.labels}))
Pourquoi cette question est-elle pertinente?
- Pour entraînement de haute performance sans revenir à python
- pas TFRecord de numpy à tenseur des conversions
- Keras feront bientôt partie de tensorflow
- Démontrer comment Keras (Modèle) classes peuvent accepter des tenseurs pour les données d'entrée correctement.
Voici quelques démarreur pour une sémantique problème de segmentation exemple:
- exemple unet Keras modèle unet.py, il arrive à être pour la sémantique de la segmentation.
- Keras + Tensorflow Post De Blog
- Un tentative à l'exécution de l'unet un modèle tf session avec TFRecords et Keras modèle (pas de travail)
- Code pour créer le TFRecords: tf_records.py
- Une tentative d'exécution de l'unet un modèle tf session avec TFRecords et Keras modèle est en densenet_fcn.py (pas de travail)
- github.com/tensorflow/tensorflow/issues/8787 sera pour le travail vers le support de cette fonctionnalité au-delà de la solution fournie dans la accepté de répondre.
- mise à jour de pull request github.com/fchollet/keras/pull/6928
Vous devez vous connecter pour publier un commentaire.
Je n'utilise pas tfrecord dataset format afin de ne pas argumenter sur les avantages et les inconvénients, mais je me suis intéressé à l'extension de Keras à l'appui de la même chose.
github.com/indraforyou/keras_tfrecord est le référentiel. Va vous expliquer brièvement les principaux changements.
data_to_tfrecord
etread_and_decode
ici prend soin de créer tfrecord jeu de données et le chargement de la même. Une attention particulière doit être pour mettre en œuvre lesread_and_decode
sinon vous ferez face à cryptique des erreurs au cours de la formation.Maintenant deux
tf.train.shuffle_batch
et KerasInput
couche retourne tenseur. Mais celui retourné partf.train.shuffle_batch
n'ont pas de métadonnées requises par Keras en interne. Comme il s'avère, tout tenseur peut être facilement transformé en un tenseur avec keras métadonnées en appelantInput
couche avectensor
param.Donc cela prend en charge l'initialisation:
Maintenant avec
x_train_inp
tout keras modèle peut être développé.Permet de dire
train_out
est la sortie du tenseur de votre keras modèle. Vous pouvez facilement écrire une formation personnalisée boucle sur les lignes de:Une des caractéristiques de keras qui le rend si lucrative est son généralisée mécanisme de formation avec les fonctions de rappel.
Mais à l'appui de tfrecords type de formation il y a plusieurs changements qui sont nécessaires dans le
fit
fonctionfeed_dict
Mais tout cela peut être facilement pris en charge par un autre paramètre flag. Ce qui rend les choses déconner sont les keras caractéristiques
sample_weight
etclass_weight
ils sont utilisés pour peser chaque échantillon et de peser chaque classe. Pour cela, danscompile()
keras crée des espaces réservés (ici) et les espaces réservés sont également implicitement créé pour les cibles (ici) qui n'est pas nécessaire dans notre cas, les étiquettes sont déjà nourris par tfrecord lecteurs. Ces espaces réservés besoin d'être nourris en cours de session qui est inutile dans notre cae.Donc prendre en compte ces modifications,
compile_tfrecord
(ici) etfit_tfrecord
(ici) sont l'extension decompile
etfit
et actions à dire 95% du code.Ils peuvent être utilisés de la façon suivante:
Vous êtes les bienvenus pour améliorer le code et pull requests.
Mise à jour 2018-08-29 c'est maintenant pris en charge directement dans keras, voir l'exemple suivant:
https://github.com/keras-team/keras/blob/master/examples/mnist_tfrecord.py
Réponse Originale À Cette Question:
TFRecords sont pris en charge par l'aide d'une perte. Voici les principales lignes de la construction d'un externe de perte:
Voici un exemple pour Keras 2. Il fonctionne après l'application du patch petit #7060:
J'ai aussi travaillé à améliorer le soutien aux TFRecords dans le numéro suivant, et tirez sur demande:
Enfin, il est possible d'utiliser
tf.contrib.learn.Experiment
pour former Keras modèles dans TensorFlow.data_flow_ops.RecordInput
retourne seulement le premier lot, puis Keras pense que cette époque est fait, et redémarrer une autre époque. Je ne peux pas comprendre pourquoi. Je sais que c'est dur pour vous de voir ce qui se passe, mais avez-vous des suggestions sur la façon de débogage? Merci beaucoup. Je suis sûr que tfrecord fichier que je passe est correct (avoir plus de 60k images).