2017-06-21 3 views
1

Gibt es eine Möglichkeit, einen in R erstellten Random Forest in SAS-Code zu konvertieren, ohne alle If-Then-Anweisungen, die getTree liefert, von Hand abtippen zu müssen? (R-Paket randomForest zu Base SAS)

hatte ich 30 Bäume, die sich um die

+0

Was haben Sie bisher versucht? – Stedy

+0

Fragen zur Code-Umwandlung sind bei Stack Overflow off-topic. Recherchieren Sie selbst, versuchen Sie es, und fragen Sie dann konkret zu dem daraus entstandenen SAS-Programm. – Joe

+0

Das hängt von den Regeln und ihrem Format ab. Da die meisten Leute entweder mit R oder mit SAS vertraut sind, sollten Sie sich die Zeit nehmen, einige Beispiele zu zeigen und die Frage detailliert zu formulieren, wenn Sie eine Antwort wünschen. – Reeza

Antwort

0

Hier ist etwas, das ich geschrieben habe, als ich 1900 Zeilen aus der getTree-Funktion hatte; das sollte Ihnen den Einstieg erleichtern. Bisher wird nur Regression unterstützt, aber Klassifikation sollte sich mit ein wenig zusätzlicher Arbeit umsetzen lassen:

/* R code for exporting the randomForest object */ 
# Output dataset to csv for validation in SAS.
# All export files go to this directory (it must match the paths used in the SAS code).
out_dir <- "c:/temp"
write.csv(iris, file = file.path(out_dir, "iris.csv"), row.names = FALSE)

# Train a 2-tree random forest for testing purposes.
# library() stops with an error if randomForest is not installed; require() would
# only return FALSE with a warning and the script would fail further down.
library(randomForest)
rf2 <- randomForest(iris[, -1], iris[, 1], ntree = 2)

# Get predictions and write to csv (row names are kept as an id column).
write.csv(predict(rf2, iris), file = file.path(out_dir, "pred_rf2b.csv"))

# Export factor levels: one row per level with the variable name, the level label
# and its integer code (the code whose bit the SAS generator tests in a categorical
# split point).
# Only unordered factors are exported. class() of an ordered factor is
# c("ordered", "factor"), which a class(x) == "factor" test does not handle, and
# randomForest presumably splits ordered factors on their integer codes like numeric
# predictors (NOTE(review): confirm for your randomForest version).
mydata <- iris
is_unordered_factor <- vapply(mydata, function(v) is.factor(v) && !is.ordered(v), logical(1))
output <- lapply(names(mydata)[is_unordered_factor], function(x) {
  data.frame(VarName = x,
             Level = levels(mydata[[x]]),
             Number = seq_len(nlevels(mydata[[x]])))
})
# Keep the column layout even when there are no factors, so the SAS import still
# finds VarName/Level/Number (do.call(rbind, list()) would return NULL).
factorlevels <- if (length(output) > 0) {
  do.call(rbind, output)
} else {
  data.frame(VarName = character(), Level = character(), Number = integer())
}
write.csv(factorlevels, file = file.path(out_dir, "factorlevels.csv"), row.names = FALSE)

# Export all trees in one file. getTree's left/right daughter columns are row
# numbers within the tree, so the row number is stored as the node id.
treeoutput <- lapply(seq_len(rf2$ntree), function(x) {
  res <- getTree(rf2, x, labelVar = TRUE)
  res$node <- seq_len(nrow(res))
  res$treenum <- x
  res
})
write.csv(do.call(rbind, treeoutput), file = file.path(out_dir, "treeexport.csv"), row.names = FALSE)
/*End of R code*/ 

/*Import into SAS, replacing . with _ so we have usable variable names*/ 

/*Import the tree table exported from R. GUESSINGROWS is raised because PROC IMPORT
  derives column types and lengths from the scanned rows only (20 by default), which
  can truncate longer SPLIT_VAR values that first appear further down the file.
  32767 was the maximum in older SAS 9 releases; on SAS 9.4M3+ GUESSINGROWS=MAX also works.*/
proc import
    datafile = "c:\temp\treeexport.csv"
    out = tree
    dbms = csv
    replace;
    getnames = yes;
    guessingrows = 32767;
run;

/*Match the scoring data, whose variable names have . replaced by _*/
data tree;
    set tree;
    SPLIT_VAR = translate(SPLIT_VAR,'_','.');
    /*Display split points with full precision rather than 3 decimals*/
    format SPLIT_POINT best32.;
run;

/*Import the factor level lookup (VarName, Level, Number), scanning all rows so
  long level labels are not truncated*/
proc import
    datafile = "c:\temp\factorlevels.csv"
    out = factorlevels
    dbms = csv
    replace;
    getnames = yes;
    guessingrows = 32767;
run;

/*Apply the same . -> _ renaming to VARNAME so it keeps matching the translated
  SPLIT_VAR for factor variables whose R names contain dots*/
data factorlevels;
    set factorlevels;
    VARNAME = translate(VARNAME,'_','.');
run;

/*Copy iris.csv, replacing . with _ in the header line only, so that the imported
  variable names match the translated SPLIT_VAR values*/
data _null_;
    infile "c:\temp\iris.csv";
    file "c:\temp\iris2.csv";
    input;
    if _n_ = 1 then _infile_=translate(_infile_,'_','.');
    put _infile_;
run;

/*Import the renamed scoring data. Scanning all rows matters here: iris is sorted by
  Species, so a 20-row scan sees only "setosa" and may size Species at 6 characters,
  truncating "versicolor" and "virginica" so the generated comparisons never match.*/
proc import
    datafile = "c:\temp\iris2.csv"
    out = iris
    dbms = csv
    replace;
    getnames = yes;
    guessingrows = 32767;
run;


/*Generate SAS if-then-else scoring code from WORK.TREE and WORK.FACTORLEVELS.
  For each tree, a "treeN = <prediction>;" if-then-else statement is written to the
  log (FILE LOG), followed - for regression - by RFprediction, the average of the
  tree outputs. Nodes are visited left daughter first; right daughters wait in the
  queue Q (keyed by insertion position) and are flagged as processed once emitted.*/
data _null_;
    debug = 0;
    type = "regression";
    maxiterations = 10000;
    /*LEVEL_TEXT holds a quoted factor level; SPLIT_TEXT holds a numeric split point as text*/
    length level_text $ 260 split_text $ 32;
    file log;
    if 0 then set tree factorlevels;
    /*Hash to hold the whole tree, keyed by tree number and node (row) number*/
    declare hash t(dataset:'tree');
    rc = t.definekey('treenum');
    rc = t.definekey('node');
    rc = t.definedata(all:'yes');
    rc = t.definedone();

    /*Hash for looking up factor levels: (variable name, level code) -> level label*/
    declare hash fl(dataset:'factorlevels');
    rc = fl.definekey('VARNAME','NUMBER');
    rc = fl.definedata('LEVEL');
    rc = fl.definedone();

    /*Loop over trees until there is no root node (node 1) for the next tree number*/
    do treenum = 1 by 1 while(t.find(key:treenum,key:1)=0);
        /*Hash to hold the queue for current tree*/
        length position qnode processed 8;
        declare hash q(ordered:'a');
        rc = q.definekey('position');
        rc = q.definedata('qnode','position','processed');
        rc = q.definedone();
        declare hiter qi('q');
        /*Hash for reverse queue lookup: node -> queue position*/
        declare hash q2();
        rc = q2.definekey('qnode');
        rc = q2.definedata('position');
        rc = q2.definedone();

        /*Load the starting node for the current tree*/
        node = 1;
        nodetype = 'L'; /*Track whether current node is a Left or Right node*/
        complete = 0;
        length treename $10;
        treename = cats('tree',treenum);

        do iteration = 1 by 1 while(complete = 0 and iteration <= maxiterations);
            rc = t.find();
            if debug then put "Processing node " node;

            /*Logic for terminal nodes*/
            if status = -1 then do;
                if type ne "regression" then prediction = cats('"',prediction,'"');
                put treename '=' prediction ';';
                /*If current node is a right node, mark it as processed in the queue*/
                if nodetype = 'R' then do;
                    rc = q2.find();
                    if debug then put "Unqueueing node " qnode "in position " position;
                    processed = 1;
                    rc = q.replace();
                end;
                /*An empty queue means the root itself was terminal (single-node tree), so
                  the tree is done. Without this check the scan below finds nothing,
                  COMPLETE stays 0 and traversal continues from a stale or missing QNODE.*/
                if q.num_items = 0 then complete = 1;
                else do;
                    /*Scan back from the newest queued node; if every node back to
                      position 1 is processed, the tree is complete*/
                    rc = qi.last();
                    do while(rc = 0 and processed = 1);
                        if position = 1 then complete = 1;
                        rc = qi.prev();
                    end;
                end;
                /*Otherwise, process the most recently queued unprocessed node*/
                if complete = 0 then do;
                    put "else ";
                    node = qnode;
                    nodetype = 'R';
                end;
            end;

            /*Logic for split nodes - status ne -1*/
            else do;
                /*Queue the right daughter; it is emitted after the left subtree*/
                position = q.num_items + 1;
                qnode = right_daughter;
                processed = 0;
                rc = q.add();
                rc = q2.add();
                if debug then put "Queueing node " qnode "in position " position;

                /*Check whether current split var is a (categorical) factor*/
                rc = fl.find(key:split_var,key:1);
                /*If yes, factor levels corresponding to 1s in the binary representation of the split point go left*/
                if rc = 0 then do;
                    /*Get binary representation of split point (least significant bit first)*/
                    /*binaryw. format behaves very differently above width 58 - only 58 levels per factor supported here*/
                    /*This is sufficient as the R randomForest package only supports 53 levels per factor anyway*/
                    binarysplit = reverse(put(split_point,binary58.));
                    put 'if ' @;
                    j=0; /*Track how many levels have been written for this split*/
                    /*Stop at the first level code missing from FACTORLEVELS and never
                      read past the 58 characters of BINARYSPLIT*/
                    do i = 1 to 58 while(rc = 0);
                        if i > 1 then rc = fl.find(key:split_var,key:i);
                        if debug then put _all_;
                        if rc = 0 and substr(binarysplit,i,1) = '1' then do;
                            /*QUOTE() adds double quotes (doubling any embedded ones) into the
                              wider LEVEL_TEXT, so the longest level cannot be cut off as it
                              could be when the quoted text was written back into LEVEL*/
                            level_text = quote(strip(LEVEL));
                            if j > 0 then put ' or ' @;
                            put split_var ' = ' level_text @;
                            j + 1;
                        end;
                    end;
                    put 'then';
                end;
                /*If not, anything <= split point goes to the left child, matching
                  randomForest's prediction rule; BEST32. keeps the split point's
                  precision instead of rounding it to 3 decimals*/
                else do;
                    split_text = strip(put(split_point,best32.));
                    put "if " split_var "<= " split_text " then ";
                end;
                if nodetype = 'R' then do;
                    qnode = node;
                    rc = q2.find();
                    if debug then put "Unqueueing node " qnode "in position " position;
                    processed = 1;
                    rc = q.replace();
                end;
                node = left_daughter;
                nodetype = 'L';
            end;
        end;
        /*Warn rather than silently emit incomplete code for this tree*/
        if complete = 0 then putlog "WARNING: " treename "was not completed within " maxiterations "iterations; its generated code is incomplete.";
        /*End of tree function definition!*/
        put ';';
        /*Clear the queue between trees*/
        rc = q.delete();
        rc = q2.delete();
    end;

    /*We end up going 1 past the actual number of trees after the end of the do loop*/
    treenum = treenum - 1;

    if type = "regression" then do;
        put 'RFprediction=(';
        do i = 1 to treenum;
            treename = cats('tree',i);
            put treename +1 @;
            if i < treenum then put '+' +1 @;
        end;
        put ')/' treenum ';';
    end;

    /*To do - write code to aggregate predictions from multiple trees for classification*/

    stop;
run;


/*Sample of generated if-then-else code */ 

data predictions;
    set iris;
    /*Generated scoring code: each treeN holds that tree's prediction and RFprediction
      averages them. Numeric splits use <= so that a value equal to a split point goes
      to the left daughter, matching randomForest's prediction rule.*/
if Petal_Length <= 4.150 then
if Petal_Width <= 1.050 then
if Petal_Width <= 0.350 then
tree1 =4.91702127659574 ;
else
if Petal_Width <= 0.450 then
tree1 =5.18333333333333 ;
else
if Species = "versicolor" then
tree1 =5.08888888888889 ;
else
tree1 =5.1 ;
else
if Sepal_Width <= 2.550 then
tree1 =5.525 ;
else
if Petal_Length <= 4.050 then
tree1 =5.8 ;
else
tree1 =5.63333333333333 ;
else
if Petal_Width <= 1.950 then
if Sepal_Width <= 3.050 then
if Species = "setosa" or Species = "virginica" then
if Petal_Length <= 5.700 then
tree1 =6.05833333333333 ;
else
tree1 =7.2 ;
else
tree1 =6.176 ;
else
if Sepal_Width <= 3.250 then
if Sepal_Width <= 3.150 then
tree1 =6.62 ;
else
tree1 =6.66666666666667 ;
else
tree1 =6.3 ;
else
if Petal_Length <= 6.050 then
if Petal_Width <= 2.050 then
tree1 =6.275 ;
else
tree1 =6.65 ;
else
if Petal_Length <= 6.550 then
tree1 =7.76666666666667 ;
else
tree1 =7.7 ;
;
if Petal_Width <= 1.150 then
if Species = "setosa" then
tree2 =5.08947368421053 ;
else
tree2 =5.55714285714286 ;
else
if Species = "setosa" or Species = "versicolor" then
if Sepal_Width <= 2.750 then
if Petal_Length <= 4.450 then
tree2 =5.44 ;
else
tree2 =6.06666666666667 ;
else
if Petal_Width <= 1.350 then
tree2 =5.85294117647059 ;
else
if Petal_Width <= 1.750 then
if Petal_Width <= 1.650 then
tree2 =6.3625 ;
else
tree2 =6.7 ;
else
tree2 =5.9 ;
else
if Petal_Length <= 5.850 then
if Sepal_Width <= 2.650 then
if Petal_Length <= 4.750 then
tree2 =4.9 ;
else
if Sepal_Width <= 2.350 then
tree2 =6 ;
else
if Sepal_Width <= 2.550 then
tree2 =6.14 ;
else
tree2 =6.1 ;
else
tree2 =6.49166666666667 ;
else
if Petal_Length <= 6.350 then
tree2 =7.125 ;
else
tree2 =7.775 ;
;
RFprediction=(
tree1 + tree2 )/2 ;
run;