如何在R中使用网格绘制非线性决策边界?

4
我在ISLR图2.13)或ESL中发现了这张特别好的图。我无法猜出作者们是如何用R制作这张图的。我知道如何很容易地获得橙色和蓝色点。主要困惑在于背景点和紫色线。
有什么想法吗?图2.13,ISLR 这里有一些示例代码,用灰色网格获取黄色和橙色点。如何获得任意非线性曲线,并根据曲线着色网格?
set.seed(pi)
points = replicate(100, runif(2))
pointsColored = ifelse(apply(points, 2, sum) <= 1, "orange", "blue")
# Confound some
pointsColored[sample.int(length(pointsColored), 10)] = "orange"
plot(x=points[1, ], y=points[2, ])
grid(nx=100, ny=100)
# Plot points over the grid.
points(x=points[1, ], y=points[2, ], col=pointsColored)

哦...被踩了?因为我没有提供自己的尝试吗? - asb
请提供您在样本数据上尝试过的内容... - vrajs5
提供一些上下文会更好,例如ISLR中图2.13周围的文本。这是一个具有两个类别的二维分类问题。圆圈是训练数据。点显示了基于某些机器学习算法的贝叶斯分类器的预测。虚线是贝叶斯决策边界。主要问题实际上是从您正在使用的任何ML算法中提取决策边界。使用类似于“dots <- expand.grid(...)”和“plot(..., col = predict(dots))”的方法很容易得到预测的点。 - Stephan Kolassa
请查看stats.stackexchange.com上的帖子。搜索“非线性决策边界”可能会得到更多结果,但乍一看,这个帖子似乎正是你想要的。 - jbaums
1
这篇文章的标题应该改为描述所讨论的图表类型。互联网搜索引擎似乎强调标题。现有的标题对于未来试图寻找创建类似图表帮助的人没有任何价值。 - Mark Miller
显示剩余4条评论
2个回答

6

如我在评论中所示,这里提供了一个解决方案,由@chl在stats.stackexchange.com上提供。这是应用于您的数据集的解决方案。

library(class)
set.seed(pi)
X <- t(replicate(1000, runif(2)))
g <- ifelse(apply(X, 1, sum) <= 1, 0, 1)
xnew <- cbind(rep(seq(0, 1, length.out=50), 50),
              rep(seq(0, 1, length.out=50), each=50))
m <- knn(X, xnew, g, k=15, prob=TRUE)
prob <- attr(m, "prob")
prob <- ifelse(m=="1", prob, 1-prob)
prob15 <- matrix(prob, 50)
par(mar=rep(3, 4))
contour(unique(xnew[, 1]), unique(xnew[, 2]), prob15, levels=0.5, 
        labels="", xlab='', ylab='', axes=FALSE, lwd=2.5, asp=1)
title(xlab=expression(italic('X')[1]), ylab=expression(italic('X')[2]), 
      line=1, family='serif', cex.lab=1.5)
points(X, bg=ifelse(g==1, "#CA002070", "#0571B070"), pch=21)
gd <- expand.grid(x=unique(xnew[, 1]), y=unique(xnew[, 2]))
points(gd, pch=20, cex=0.4, col=ifelse(prob15 > 0.5, "#CA0020", "#0571B0"))
box()

决策边界

(更新:我更改了颜色调色板,因为蓝色/黄色/紫色的组合相当难看。)


感谢您提供另一个答案的链接。 - asb

1
这只是我愚蠢的近似尝试。显然,@StephenKolassa提出的问题是有效的,而且这个近似方法没有处理好。
myCurve1 = function (x)
  abs(x[[1]] * sin(x[[1]]) + x[[2]] * sin(x[[2]]))
myCurve2 = function (x)
  abs(x[[1]] * cos(x[[1]]) + x[[2]] * cos(x[[2]]))
myCurve3 = function (x)
  abs(x[[1]] * tan(x[[1]]) + x[[2]] * tan(x[[2]]))

tmp = function (myCurve, seed=99) {
  set.seed(seed)
  points = replicate(100, runif(2))
  colors = ifelse(apply(points, 2, myCurve) > 0.5, "orange", "blue")
  # Confound some
  swapInts = sample.int(length(colors), 6)
  for (i in swapInts) {
    if (colors[[i]] == "orange") {
      colors[[i]] = "blue"
    } else {
      colors[[i]] = "orange"
    }
  }
  gridPoints = seq(0, 1, 0.005)
  gridPoints = as.matrix(expand.grid(gridPoints, gridPoints))
  gridColors = vector("character", nrow(gridPoints))
  gridPch = vector("character", nrow(gridPoints))
  for (i in 1:nrow(gridPoints)) {
    val = myCurve(gridPoints[i, ])
    if (val > 0.505) {
      gridColors[[i]] = "orange"
      gridPch[[i]] = "."
    } else if (val < 0.495) {
      gridColors[[i]] = "blue"
      gridPch[[i]] = "."
    } else {
      gridColors[[i]] = "purple"
      gridPch[[i]] = "*"
    }
  }
  plot(x=gridPoints[ , 1], y=gridPoints[ , 2], col=gridColors, pch=gridPch)
  points(x=points[1, ], y=points[2, ], col=colors, lwd=2)
}

par(mfrow=c(1, 3))
tmp(myCurve1)
tmp(myCurve2)
tmp(myCurve3)

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接