由于我需要一个适用于分类数据的版本,所以编写了下面的函数,这也算是对 @JoãoDaniel 问题的回答(我只测试了分类预测变量):
# Strip leading and trailing whitespace from every element of `x`.
# Returns a character vector of the same length as `x`; NA stays NA.
trim <- function(x) {
  edge_ws <- "^\\s+|\\s+$"
  gsub(pattern = edge_ws, replacement = "", x = x)
}
# Return the first whitespace-delimited token of each string, i.e. the
# variable name of a rule such as "ME = { Never }" -> "ME".
# Strings containing no whitespace are returned unchanged; NA stays NA.
# Only greedy quantifiers are used, so the result does not depend on how the
# regex engine (TRE by default) treats mixed lazy/greedy quantifiers.
getVariable <- function(x) {
  sub("[[:space:]].*$", "", x)
}
# Return the second whitespace-delimited token of each string, i.e. the
# operator of a rule such as "ME = { Never }" -> "=" or
# "SYMPT <= { Agree }" -> "<=".
# Tokens are delimited by single whitespace characters. Strings with fewer
# than two whitespace characters do not match and are returned unchanged.
# Explicit non-whitespace classes replace lazy quantifiers so the match is
# well defined under R's default (TRE) regex engine.
getSimbolo <- function(x) {
  sub("^[^[:space:]]*[[:space:]]([^[:space:]]*)[[:space:]].*$", "\\1", x)
}
# Simplify a rule path.
#
# The path is split on ";". For every (variable, operator) pair, only the
# LAST condition in the chain is kept (presumably the most specific one,
# since later splits on a path refine earlier ones). Pairs are emitted in
# order of their first appearance and joined with "; ".
#
# Args:
#   elemento: a path string, e.g.
#             "ME = { Never, Sometimes }; SYMPT <= { Agree }; ME = { Never }".
#             If a longer vector is given, all of its conditions are treated
#             as one chain.
# Returns:
#   A single character string; "" when there are no non-empty conditions.
#
# Uses the helpers trim(), getVariable() and getSimbolo() from this file.
getReglaFinal <- function(elemento) {
  reglas <- trim(unlist(strsplit(as.character(elemento), ";", fixed = TRUE)))
  reglas <- reglas[nzchar(reglas)]  # ignore empty pieces such as "a;;b"
  if (length(reglas) == 0) {
    return("")
  }
  # Grouping key: variable name immediately followed by the operator.
  claves <- paste0(getVariable(reglas), getSimbolo(reglas))
  # For each key (in first-appearance order), the index of its last rule.
  ultima <- vapply(unique(claves),
                   function(k) max(which(claves == k)),
                   integer(1), USE.NAMES = FALSE)
  paste(reglas[ultima], collapse = "; ")
}
# Extract every root-to-leaf rule from a party::ctree() model and render the
# rules as a SQL CASE expression. Handles categorical, ordered and numeric
# split variables.
#
# Args:
#   ct: a fitted tree from party::ctree().
#
# Returns a list with three elements:
#   PreTable  data.frame with one row per terminal node, ordered by Means:
#               Node      node id
#               Path      every split condition on the way to the node,
#                         "; "-separated
#               Means     the model prediction for the node
#               Counts    sum of weights(ct) for an observation in the node
#               TruePath  Path reduced by getReglaFinal()
#   Table     the same rows with Path replaced by TruePath.
#   SQL       one string of the form " CASE WHEN ... THEN 'Nodo <id>' ... END ".
#
# Relies on party's where(), nodes(), weights() and Predict(), and on the
# helper getReglaFinal() from this file.
CtreePathFuncAllCat <- function(ct) {
  nodoObs <- where(ct)
  terminales <- unique(nodoObs)
  ResulTable <- data.frame(Node = character(), Path = character())
  for (Node in terminales) {
    # Candidate ancestors are all non-terminal ids below Node. This assumes
    # party numbers parents before their children (TODO confirm).
    NonTerminalNodes <- setdiff(seq_len(Node - 1),
                                terminales[terminales < Node])
    NodeWeights <- nodes(ct, Node)[[1]]$weights
    Path <- NULL
    for (i in NonTerminalNodes) {
      # Keep a candidate when its element 2 (presumably its case weights)
      # equals 1 for some observation that also has weight in Node.
      if (any(NodeWeights & nodes(ct, i)[[1]][2][[1]] == 1)) {
        Path <- append(Path, i)
      }
    }
    Path2 <- SB <- NULL
    variablesNombres <- array()
    variablesPuntos <- list()
    for (i in seq_along(Path)) {
      n <- nodes(ct, Path[i])[[1]]
      # The child this path descends into: the next ancestor, or Node itself.
      nextNodeID <- if (i == length(Path)) Node else Path[i + 1]
      # Element 5 of the node is read as its primary split
      # (variableName, splitpoint, optional "levels" attribute).
      psplit <- n[[5]]
      vec_puntos <- as.vector(psplit$splitpoint)
      vec_nombre <- psplit$variableName
      vec_niveles <- attr(psplit$splitpoint, "levels")
      # Set-membership split ("=") when there is one indicator per level.
      # Otherwise it is a threshold split ("<=" / ">"). An explicit flag is
      # used because a numeric split point may legitimately equal 0, and an
      # index split point may have length > 1.
      es_categorica <- length(vec_niveles) != 0 &&
        length(vec_puntos) == length(vec_niveles)
      if (length(vec_niveles) != 0 && !es_categorica) {
        # Levels are present, but the split point indexes into them:
        # convert it to a logical mask over the levels.
        index <- vec_puntos
        vec_puntos <- vector(length = length(vec_niveles))
        vec_puntos[index] <- TRUE
      }
      if (length(vec_niveles) == 0) {
        # Numeric split: keep the raw split point for the description.
        vec_puntos <- psplit$splitpoint
      }
      vaDerecha <- nextNodeID == n$right$nodeID
      if (es_categorica) {
        # Levels that lead to the chosen child.
        vec_puntos <- if (vaDerecha) !vec_puntos else !!vec_puntos
        if (i != 1) {
          # Intersect with earlier splits on the same variable in this path.
          for (j in seq_len(i - 1)) {
            if (variablesNombres[j] == vec_nombre) {
              vec_puntos <- vec_puntos * variablesPuntos[[j]]
            }
          }
          vec_puntos <- vec_puntos == 1
        }
        SB <- "="
      } else {
        SB <- if (vaDerecha) ">" else "<="
      }
      variablesPuntos[[i]] <- vec_puntos
      variablesNombres[i] <- vec_nombre
      descripcion <- if (length(vec_niveles) == 0) {
        vec_puntos
      } else {
        paste(vec_niveles[vec_puntos], collapse = ", ")
      }
      condicion <- paste(c(variablesNombres[i], SB, "{", descripcion, "}"),
                         collapse = " ")
      Path2 <- paste(c(Path2, condicion), collapse = "; ")
    }
    if (is.null(Path2)) {
      Path2 <- ""  # terminal root: no conditions on the path
    }
    ResulTable <- rbind(ResulTable, cbind(Node = Node, Path = Path2))
  }

  # Per-node statistics, aligned with the rows of ResulTable (terminal nodes
  # in order of first appearance in where(ct)) through each node's first
  # observation. Looking values up by node keeps Means aligned even when two
  # leaves share the same predicted value.
  primeraObs <- match(terminales, nodoObs)
  obsCounts <- sapply(weights(ct), function(w) sum(w))
  Counts <- as.matrix(obsCounts[primeraObs])
  c2 <- drop(Predict(ct))
  Means <- as.matrix(c2[primeraObs])
  ResulTable <- data.frame(ResulTable, Means, Counts)
  ResulTable <- ResulTable[order(ResulTable$Means), ]
  ResulTable$TruePath <- vapply(as.character(ResulTable$Path), getReglaFinal,
                                character(1), USE.NAMES = FALSE)

  # Turn each TruePath into a SQL predicate:
  #   ";"         -> " AND "
  #   "{ a, b }"  -> "('a','b')"
  #   '3.5'       -> 3.5 (purely numeric values are unquoted)
  # NOTE(review): level names are not escaped, so a level containing a single
  # quote produces invalid SQL.
  ResulTable2 <- ResulTable
  sqlCond <- gsub("\\;", " AND ", ResulTable2$TruePath)
  sqlCond <- gsub("\\{ ", "('", sqlCond)
  sqlCond <- gsub(" \\}", "')", sqlCond)
  sqlCond <- gsub("\\, ", "','", sqlCond)
  sqlCond <- gsub("\\'([-+]?([0-9]*\\.[0-9]+|[0-9]+))\\'", "\\1", sqlCond)
  ResulTable2$SQL <- paste("WHEN ", sqlCond, " THEN ")
  cols <- c("SQL", "Node")
  ResulTable2$SQL <- apply(ResulTable2[, cols], 1, paste,
                           collapse = "'Nodo ")
  ResulTable2$SQL <- gsub("THEN'", "THEN '",
                          gsub(" '", "'", paste(ResulTable2$SQL, "'")))

  ResultadoFinal <- list()
  ResultadoFinal$PreTable <- ResulTable
  ResultadoFinal$Table <- ResulTable
  ResultadoFinal$Table$Path <- ResultadoFinal$Table$TruePath
  ResultadoFinal$Table$TruePath <- NULL
  ResultadoFinal$SQL <- paste(" CASE ",
                              paste(ResulTable2$SQL, sep = "", collapse = " "),
                              " END ", collapse = "")
  ResultadoFinal
}
下面是几个测试示例:
library(party)

# Example 1: mammoexp data, predictors ME, SYMPT, HIST, BSE and DECT.
TreeModel1 <- ctree(PB ~ ME + SYMPT + HIST + BSE + DECT, data = mammoexp)
Result2 <- CtreePathFuncAllCat(TreeModel1)
Result2

# Example 2: InsectSprays, a single categorical predictor (spray).
TreeModel2 <- ctree(count ~ spray, data = InsectSprays)
plot(TreeModel2, type = "simple")
Result2 <- CtreePathFuncAllCat(TreeModel2)
Result2

# Example 3: airquality without missing Ozone, all remaining columns as
# predictors, up to three surrogate splits.
airq <- subset(airquality, !is.na(Ozone))
TreeModel3 <- ctree(Ozone ~ ., data = airq,
                    controls = ctree_control(maxsurrogate = 3))
Result2 <- CtreePathFuncAllCat(TreeModel3)
Result2
更新:该函数现在支持分类变量与数值变量混合使用!
# Attach the party package before fitting trees.
library("party")
这是我在 SO(Stack Overflow)上的第一个回答,所以当时还是个新手。—— David Arenburg