DANN

Package Introduction

DANN is a variation of k nearest neighbors in which the shape of the neighborhood takes the class of the training data into account. The neighborhood is elongated along class boundaries and shrunk in the direction orthogonal to them. See Discriminant Adaptive Nearest Neighbor Classification by Hastie and Tibshirani. This package implements DANN and sub-DANN in section 4.1 of the publication and is based on Christopher Jenness’s Python implementation.
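
To make the idea of a class-aware neighborhood concrete, the following is a minimal base-R sketch of the local metric from the publication. It is a simplification, not the code used by this package: every point gets equal weight, whereas the publication weights points with a kernel. The function name dann_metric is hypothetical.

```r
# Simplified, unweighted sketch of the DANN metric:
#   Sigma = W^(-1/2) (W^(-1/2) B W^(-1/2) + epsilon * I) W^(-1/2)
# dann_metric is a hypothetical helper, not part of the dann package.
dann_metric <- function(x, y, epsilon = 1) {
  p <- ncol(x)
  n <- nrow(x)
  overall_mean <- colMeans(x)
  W <- matrix(0, p, p)  # pooled within-class covariance
  B <- matrix(0, p, p)  # between-class covariance
  for (cls in unique(y)) {
    xc <- x[y == cls, , drop = FALSE]
    class_mean <- colMeans(xc)
    centered <- sweep(xc, 2, class_mean)
    W <- W + crossprod(centered) / n
    d <- class_mean - overall_mean
    B <- B + (nrow(xc) / n) * tcrossprod(d)
  }
  # Symmetric inverse square root of W via its eigendecomposition.
  eig <- eigen(W, symmetric = TRUE)
  W_inv_sqrt <- eig$vectors %*% diag(1 / sqrt(eig$values), p) %*% t(eig$vectors)
  B_star <- W_inv_sqrt %*% B %*% W_inv_sqrt
  W_inv_sqrt %*% (B_star + epsilon * diag(p)) %*% W_inv_sqrt
}

# Two classes separated along the first axis, so the boundary runs along the second.
set.seed(1)
x <- rbind(cbind(rnorm(50, -2), rnorm(50)),
           cbind(rnorm(50,  2), rnorm(50)))
y <- rep(1:2, each = 50)
Sigma <- dann_metric(x, y)

# Squared distance of a unit step across the boundary versus along it.
dist_across <- drop(t(c(1, 0)) %*% Sigma %*% c(1, 0))
dist_along  <- drop(t(c(0, 1)) %*% Sigma %*% c(0, 1))
dist_across > dist_along
```

A step across the boundary counts as longer than a step along it, so a fixed-radius neighborhood under this metric is narrow across the boundary and stretched along it.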

Arguments

Example: Clustered Data

This example simulates a data set of six two-dimensional Gaussian classes. There is some overlap between classes.

library(dann)
library(dplyr, warn.conflicts = FALSE)
library(ggplot2)
library(mlbench)

set.seed(1)
train <- mlbench.2dnormals(600, cl = 6, r = sqrt(2), sd = .5) %>%
  tibble::as_tibble()
colnames(train) <- c("X1", "X2", "Y")

ggplot(train, aes(x = X1, y = X2, colour = Y)) + 
  geom_point() + 
  labs(title = "Train Data")



test <- mlbench.2dnormals(600, cl = 6, r = sqrt(2), sd = .5) %>%
  tibble::as_tibble()
colnames(test) <- c("X1", "X2", "Y")
ggplot(test, aes(x = X1, y = X2, colour = Y)) + 
  geom_point() + 
  labs(title = "Test Data")

To work with the dann package, the data must be in matrices rather than data frames. Here the predictors are converted to a numeric matrix and the class labels to a numeric vector.

xTrain <- train %>%
  select(X1, X2) %>%
  as.matrix()
yTrain <- train %>%
  pull(Y) %>%
  as.numeric() %>%
  as.vector()

xTest <- test %>%
  select(X1, X2) %>%
  as.matrix()
yTest <- test %>%
  pull(Y) %>%
  as.numeric() %>%
  as.vector()
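
The same conversion can be sketched in base R. This example uses the built-in iris data as a stand-in for the simulated data, and the variable names are illustrative. Note that as.numeric() on a factor returns the integer level codes, so it helps to keep the levels if predictions need to be mapped back to the original labels.

```r
# Base-R version of the conversion, using the built-in iris data as a stand-in.
x <- as.matrix(iris[, c("Sepal.Length", "Sepal.Width")])
class_levels <- levels(iris$Species)  # keep the labels for later
y <- as.numeric(iris$Species)         # factor -> integer codes 1, 2, 3

# Map numeric codes back to the original labels.
head(class_levels[y])
```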

To train a model, the matrices and a few parameters are passed into dann. The argument neighborhood_size is the number of data points used to estimate the shape of the neighborhood. The argument k is the number of data points used in the final classification. The argument epsilon is the softening parameter of the publication's metric. Considering there is overlap between all classes and there are only about 100 data points per class, dann performs well on this data set, classifying about 80% of the test points correctly.

dannPreds <- dann(xTrain = xTrain, yTrain = yTrain, xTest = xTest,
                  k = 7, neighborhood_size = 150, epsilon = 1)
round(mean(dannPreds == yTest), 2)
#> [1] 0.8
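
Overall accuracy can hide classes that are predicted poorly. As a sketch, a per-class breakdown can be computed with base R's table(). The vectors below are small made-up stand-ins for yTest and dannPreds, not results from the model above.

```r
# Toy stand-ins for the true labels and the predictions (illustrative only).
truth <- c(1, 1, 2, 2, 3, 3)
preds <- c(1, 2, 2, 2, 3, 1)

# Use a shared set of levels so the confusion matrix is always square.
lvls <- sort(unique(c(truth, preds)))
confusion <- table(truth = factor(truth, levels = lvls),
                   predicted = factor(preds, levels = lvls))

# Share of each true class that was predicted correctly.
per_class <- diag(confusion) / rowSums(confusion)
per_class
```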