2017-04-25 2 views
3

Ich habe einige Zeit an Tensorflow gearbeitet, aber eines der Dinge, die ich nicht herausfinden kann, ist, wie man die Kategorische Zielspalte für ein Modell in tf.contrib codiert .learn-Modell.Codierungszielspalte für Klassifizierung in Tensorflow

Ich bin mir bewusst, dass wir eine Eingabefunktion definieren, die unten wie der Code ähnelt:

def input_fn(joined): 
    continuous_cols = {k: tf.constant(joined[k].values) 
        for k in CONTINUOUS_COLUMNS} 

    categorical_cols = {k: tf.SparseTensor(
     indices=[[i, 0] for i in range(joined[k].size)], 
     values=joined[k].values, 
     dense_shape=[joined[k].size, 1]) 
         for k in CATEGORICAL_COLUMNS} 

    # Merges the two dictionaries into one. 
    feature_cols = dict(continuous_cols.items() | categorical_cols.items()) 
    target = tf.constant(joined[target_col].values) 
    return feature_cols, target 

def train_input_fn(): 
    return input_fn(train_frame) 
def test_input_fn(): 
    return input_fn(test_frame) 

Das funktioniert völlig in Ordnung für Binary Klassifizierung oder für Fälle, in denen wir vorher encodieren die Zielvariable entweder mit sklearn derLabelEncoder oder eine andere Methode. Aber wie kann man diese Variable mit tensflow verschlüsseln, damit tf.contrib.learn sie akzeptiert?

Ich habe versucht, den Code für die Zielspalte auf die folgende Veränderung:

target = tf.SparseTensor(
     indices=[[i, 0] for i in range(joined[target_col].size)], 
     values=joined[target_col].values, 
     dense_shape=[joined[target_col].size, 1]) 

Da es sich um eine String-Variable ist, so dass ich dachte, spärlich Tensor es tun sollten, aber dies gibt den Fehler:

ValueError: SparseTensor is not supported. 

Kann mir jemand helfen zu spezifizieren, was ich Platzhalter sollte ich in der Eingabefunktion für das DNNClassifier Modell für die Zielkategorie Variable verwenden.

Antwort

0

DNNClassifier Unterstützung Etiketten als Tensoren keine Sparse. Erwartete Bezeichnung ist etwas wie [[4], [5], [1]], wobei jeder Wert einer ganzzahligen ID entspricht. Wenn Sie Zeichenfolgenbeschriftungen haben, können Sie das Argument label_keys verwenden.

Verwandte Themen