2017-05-23 7 views
1

Ich sehe aus dem tutorial, dass wir dies tun können:Zugriff auf Werte in Protos in TensorFlow?

for node in tf.get_default_graph().as_graph_def().node: print node

Wenn auf einem beliebigen Netzwerk getan, wir haben viele Schlüsselwertepaare erhalten. Beispiel:

name: "conv2d_2/convolution" 
op: "Conv2D" 
input: "max_pooling2d/MaxPool" 
input: "conv2d_1/kernel/read" 
device: "/device:GPU:0" 
attr { 
    key: "T" 
    value { 
    type: DT_FLOAT 
    } 
} 
attr { 
    key: "data_format" 
    value { 
    s: "NHWC" 
    } 
} 
attr { 
    key: "padding" 
    value { 
    s: "SAME" 
    } 
} 
attr { 
    key: "strides" 
    value { 
    list { 
     i: 1 
     i: 1 
     i: 1 
     i: 1 
    } 
    } 
} 
attr { 
    key: "use_cudnn_on_gpu" 
    value { 
    b: true 
    } 
} 

Wie kann ich auf alle diese Werte zugreifen und sie in Python-Listen einfügen? Wie können wir das "strides" -Attribut erhalten und alle 1en in [1, 1, 1, 1] konvertieren?

Antwort

1

TLDR: Der folgende Code ist, was Sie vielleicht verwenden möchten:

for n in tf.get_default_graph().as_graph_def().node: if 'strides' in n.attr.keys(): print n.name, [int(a) for a in n.attr['strides'].list.i] if 'shape' in n.attr.keys(): print n.name, [int(a.size) for a in n.attr['shape'].shape.dim]

Der Trick, um dies zu tun, ist zu verstehen, was protobufs sind. Lassen Sie uns durch die oben genannten tutorial gehen.

Zunächst einmal gibt es eine Erklärung:

for node in graph_def.node

Jeder Knoten ist ein NodeDef Objekt, definiert in tensorflow/core/Rahmen/node_def.proto. Dies sind die grundlegenden Bausteine ​​von TensorFlow-Diagrammen, wobei jeder einen einzelnen -Vorgang mit seinen Eingangsverbindungen definiert. Hier sind die Mitglieder eines NodeDef, und was sie bedeuten.

Hinweis nach dem in node_def.proto:

  • Es importiert attr_value.proto.
  • Es gibt Attribute wie Name, Op, Eingabe, Gerät, attr. Insbesondere gibt es einen repeated Begriff vor dem Eingang. Wir können das jetzt ignorieren.

Dieses genau wie eine Python-Klasse funktioniert und wir können somit node.name, node.op nennen, node.input, node.device, node.attr usw.

Was wir zugreifen möchten jetzt wäre der Inhalt in node.attr. Wenn wir erneut auf das Tutorial verweisen, gibt es Folgendes an:

Dies ist ein Schlüssel/Wert-Speicher, der alle Attribute eines Knotens enthält. Diese sind die permanenten Eigenschaften von Knoten, Dinge, die sich bei Laufzeit nicht ändern, wie die Größe von Filtern für Faltungen oder die Werte von Konstante Ops. Weil es so viele verschiedene Typen von Attributwerten geben kann, von Zeichenfolgen bis zu Arrays von Tensorwerten, gibt es eine separate protobuf-Datei, die die Datenstruktur definiert, die sie in Tensorfluss/core/framework/attr_value.proto enthält .

Jedes Attribut hat eine eindeutige Namenszeichenfolge, und die erwarteten Attribute werden aufgelistet, wenn die Operation definiert ist.Wenn ein Attribut in einem Knoten nicht ist, aber in der Operation eine Standarddefinition vorhanden ist, wird dieser Standardwert verwendet, wenn das Diagramm erstellt wird.

Sie können auf alle diese Member zugreifen, indem Sie node.name, node.op, usw. in Python aufrufen. Die Liste der in GraphDef gespeicherten Knoten ist eine vollständige Definition der Modellarchitektur.

Da dies ein Schlüssel-Wert-Speicher wir n.attr.keys(), um eine Liste der Schlüssel dieses Attribut aufrufen können. Wir können weiter gehen, um vielleicht n.attr['strides'] aufzurufen, um auf die Schritte zuzugreifen, wenn ein solcher Schlüssel verfügbar ist. Wenn wir versuchen, diese zu drucken, erhalten wir folgendes:

list { 
    i: 1 
    i: 2 
    i: 2 
    i: 1 
} 

Und das ist, wo es beginnt verwirrend zu bekommen, weil wir versuchen könnten, list(n.attr['strides']) oder etwas dieser Art zu tun. Wenn wir uns attr_value.proto ansehen, können wir verstehen, was vor sich geht. Wir sehen, dass es oneof value ist und in diesem Fall ist es ein ListValue list, so dass wir n.attr['strides'].list anrufen können. Und wenn wir diese zu drucken, erhalten wir folgendes:

i: 1 
i: 1 
i: 1 
i: 1 

Wir könnten nächstes versuchen, dies zu tun: [a for a in n.attr['strides'].list] oder [a.i for a in n.attr['strides'].list]. Nichts funktioniert jedoch. Dies ist, wo repeated ist ein wichtiger Begriff zu verstehen. Es bedeutet im Grunde, dass es eine Int64-Liste gibt, und Sie müssen darauf mit dem Attribut i zugreifen. Doing [int(a) for a in n.attr['strides'].list.i] gibt uns dann, was wir wollen, eine Python-Liste, die wir verwenden können:

[1, 1, 1, 1]