EDIT: The `freeze_graph.py` script, which is part of the TensorFlow repository, now serves as a tool that generates a protocol buffer representing a "frozen" trained model, from an existing TensorFlow GraphDef and a saved checkpoint. It uses the same steps as described below, but it is much easier to use.
Currently the process isn’t very well documented (and subject to refinement), but the approximate steps are as follows:
1. Build and train your model as a `tf.Graph` called `g_1`.
2. Fetch the final values of each of the variables and store them as numpy arrays (using `Session.run()`).
3. In a new `tf.Graph` called `g_2`, create `tf.constant()` tensors for each of the variables, using the value of the corresponding numpy array fetched in step 2.
4. Use `tf.import_graph_def()` to copy nodes from `g_1` into `g_2`, and use the `input_map` argument to replace each variable in `g_1` with the corresponding `tf.constant()` tensors created in step 3. You may also want to use `input_map` to specify a new input tensor (e.g. replacing an input pipeline with a `tf.placeholder()`). Use the `return_elements` argument to specify the name of the predicted output tensor.
5. Call `g_2.as_graph_def()` to get a protocol buffer representation of the graph.
(NOTE: The generated graph will contain extra nodes that were only needed for training. Although it is not part of the public API, you may wish to use the internal `graph_util.extract_sub_graph()` function to strip these nodes from the graph.)
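As a rough sketch of that last point, assuming the internal TF 1.x module path `tensorflow.python.framework.graph_util` (internal APIs may move between releases), `extract_sub_graph()` keeps only the nodes that the named destination nodes transitively depend on:

```python
import tensorflow as tf
from tensorflow.python.framework import graph_util  # internal API

# Build a GraphDef containing a node that the output does not depend on,
# standing in for a leftover training-only node.
g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    out = tf.add(a, b, name="out")
    unused = tf.constant(3.0, name="unused")  # not reachable from "out"

# Keep only the nodes that "out" transitively depends on.
stripped = graph_util.extract_sub_graph(g.as_graph_def(),
                                        dest_nodes=["out"])
print([n.name for n in stripped.node])  # "unused" has been stripped
```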