The answer also lies in sklearn’s documentation.
You need to define two things:
- an estimator that implements the fit(X, y) method, X being the matrix of inputs and y the vector of outputs;
- a scorer function, or callable object, that can be called as scorer(estimator, X, y) and returns the score of the given model.
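To make the two pieces concrete, here is a minimal plain-Python sketch that uses no scikit-learn at all. The names MeanRegressor and neg_mse_scorer are hypothetical and chosen only for illustration. The estimator follows the fit(X, y) / predict(X) shape, and the scorer follows the scorer(estimator, X, y) shape, returning a value where higher means better:

```python
class MeanRegressor:
    """Hypothetical estimator: predicts the mean of the training targets."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self  # returning self allows MeanRegressor().fit(X, y).predict(X)

    def predict(self, X):
        return [self.mean_ for _ in X]


def neg_mse_scorer(estimator, X, y):
    """Scorer with the scorer(estimator, X, y) signature; higher is better."""
    y_pred = estimator.predict(X)
    return -sum((p - t) ** 2 for p, t in zip(y_pred, y)) / len(y)


X = [[0.0], [1.0], [2.0], [3.0]]
y = [1.0, 2.0, 3.0, 4.0]
model = MeanRegressor().fit(X, y)
score = neg_mse_scorer(model, X, y)
print(score)  # -1.25
```

The scorer only calls estimator.predict, so any object with that method can be scored this way.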
Referring to your example: first of all, the scorer shouldn't be a method of the estimator; it's a different notion. Just create a callable:
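import numpy as np
def scorer(estimator, X, y):
    # Up to you: define what makes the given estimator "good" or "bad".
    # Example: negative mean squared error (cross_val_score treats higher as better).
    y_pred = np.asarray(estimator.predict(X)).ravel()
    return -np.mean((y_pred - np.asarray(y).ravel()) ** 2)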
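A callable like this is passed to cross_val_score through its scoring argument. The sketch below assumes scikit-learn and NumPy are installed and uses LinearRegression on synthetic, noise-free data purely as a stand-in model; the metric (negative mean absolute error) is just one possible choice:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score


def scorer(estimator, X, y):
    # Negative mean absolute error: values closer to 0 are better, matching
    # the "higher is better" convention cross_val_score expects.
    return -np.mean(np.abs(estimator.predict(X) - y))


rng = np.random.default_rng(0)
X = rng.normal(size=(30, 2))
y = X @ np.array([1.5, -2.0]) + 0.5  # noise-free linear target

# cross_val_score calls scorer(fitted_estimator, X_test, y_test) once per fold.
scores = cross_val_score(LinearRegression(), X, y, cv=3, scoring=scorer)
print(scores)
```

Because the target is exactly linear, every fold should score essentially 0.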
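Or an even simpler solution: pass a string such as 'mean_squared_error' or 'accuracy' (the full list is available in this part of the documentation) as the scoring argument of the cross_val_score function to use a predefined scorer. Since scikit-learn 0.18 the mean squared error scorer is named 'neg_mean_squared_error', because scorers always follow the "higher is better" convention.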
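As a sketch of what a predefined string scorer does (assuming scikit-learn 0.18 or later for the 'neg_' name), the snippet below checks that scoring='neg_mean_squared_error' gives the same per-fold numbers as fitting each fold by hand and negating mean_squared_error. Ridge and the synthetic data are arbitrary choices for the demonstration:

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import KFold, cross_val_score

rng = np.random.default_rng(42)
X = rng.normal(size=(40, 3))
y = X @ np.array([1.0, 0.0, -1.0]) + rng.normal(scale=0.1, size=40)

cv = KFold(n_splits=4)  # unshuffled, so the folds are deterministic
scores = cross_val_score(Ridge(alpha=1.0), X, y, cv=cv,
                         scoring="neg_mean_squared_error")

# Same numbers computed by hand: the built-in scorer is -MSE on each test fold.
manual = []
for train, test in cv.split(X):
    model = Ridge(alpha=1.0).fit(X[train], y[train])
    manual.append(-mean_squared_error(y[test], model.predict(X[test])))

print(scores)
print(manual)
```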
Another possibility is to use the make_scorer factory function (sklearn.metrics.make_scorer), which turns a metric function into a scorer.
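A minimal make_scorer sketch follows. The metric max_abs_error is a hypothetical example written here (scikit-learn also ships its own max_error metric). Passing greater_is_better=False makes the resulting scorer return the negated metric, so cross_val_score can still treat higher values as better:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score


def max_abs_error(y_true, y_pred):
    # Hypothetical metric: worst-case absolute error (lower is better).
    return np.max(np.abs(np.asarray(y_true) - np.asarray(y_pred)))


# The scorer has the scorer(estimator, X, y) signature and returns -max_abs_error.
max_err_scorer = make_scorer(max_abs_error, greater_is_better=False)

rng = np.random.default_rng(1)
X = rng.normal(size=(30, 2))
y = X @ np.array([2.0, 1.0]) + rng.normal(scale=0.2, size=30)

scores = cross_val_score(LinearRegression(), X, y, cv=3, scoring=max_err_scorer)

# Calling the scorer directly on a fitted model shows the sign flip.
model = LinearRegression().fit(X, y)
direct = max_err_scorer(model, X, y)
print(scores, direct)
```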
As for the second thing, you can pass parameters to your model through the fit_params dict parameter of the cross_val_score function (as mentioned in the documentation). These parameters are passed on to the estimator's fit method. In scikit-learn 1.4 and later, this argument is called params.
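class my_estimator():
    def fit(self, X, y, **kwargs):
        # Entries of the fit_params dict arrive here as keyword arguments.
        alpha = kwargs['alpha']
        self.beta = X[1, :] + alpha
        return self  # by convention, fit returns the estimator itself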
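To see a fit-time keyword actually reach fit during cross-validation, here is a runnable sketch. ShiftedMean is a hypothetical estimator that predicts mean(y_train) + alpha, with alpha taken from the fit keyword arguments. It inherits from sklearn.base.BaseEstimator so that cross_val_score can clone it. The argument name is picked at runtime because, as an assumption about versions, recent releases accept params while older ones only accept fit_params:

```python
import inspect

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_val_score


class ShiftedMean(BaseEstimator):
    """Hypothetical estimator: predicts mean(y_train) + alpha, where alpha
    is a fit-time keyword argument rather than a constructor argument."""

    def fit(self, X, y, **kwargs):
        self.alpha_ = kwargs.get("alpha", 0.0)
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_ + self.alpha_)


# Use `params` if this scikit-learn version has it, else the older `fit_params`.
fit_kw = ("params" if "params" in inspect.signature(cross_val_score).parameters
          else "fit_params")

X = np.zeros((6, 1))
y = np.ones(6)
scores = cross_val_score(ShiftedMean(), X, y, cv=3,
                         scoring="neg_mean_absolute_error",
                         **{fit_kw: {"alpha": 0.5}})
print(scores)
```

Every prediction is 1.0 + 0.5 against a target of 1.0, so each fold should score -0.5, which only happens if alpha really reached fit.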
After reading all the error messages, which give quite a clear idea of what's missing, here is a simple example:
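import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.model_selection import cross_val_score  # formerly sklearn.cross_validation

class RegularizedRegressor(RegressorMixin, BaseEstimator):
    # BaseEstimator builds get_params()/set_params() from the __init__
    # arguments; cross_val_score needs them to clone the estimator per fold.
    def __init__(self, l=0.01):
        self.l = l

    def combine(self, inputs):
        # weights_[0] is the bias; the remaining weights multiply the inputs
        return self.weights_[0] + np.dot(inputs, self.weights_[1:])

    def predict(self, X):
        return np.array([self.combine(x) for x in np.asarray(X, dtype=float)])

    def classify(self, X):
        return np.sign(self.predict(X))

    def fit(self, X, y, **kwargs):
        # Fit params given to cross_val_score arrive here as keyword arguments.
        self.l = kwargs.get('l', self.l)
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float).ravel()
        Xb = np.hstack([np.ones((X.shape[0], 1)), X])  # prepend a bias column
        # Ridge solution W = (Xb^T Xb + l*I)^-1 Xb^T y; the l*I term also keeps
        # the system solvable on tiny folds where Xb^T Xb is singular.
        A = Xb.T @ Xb + self.l * np.eye(Xb.shape[1])
        self.weights_ = np.linalg.solve(A, Xb.T @ y)
        return self

X = np.array([[0, 0], [1, 0], [0, 1], [1, 1]])
y = np.array([0, 1, 1, 0])
# cv=2 because the default number of folds (5 in recent versions) exceeds 4 samples.
# On scikit-learn < 1.4, pass fit_params={'l': 0.1} instead of params=...
print(cross_val_score(RegularizedRegressor(),
                      X,
                      y,
                      cv=2,
                      params={'l': 0.1},
                      scoring='neg_mean_squared_error'))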
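The get_params requirement comes from cloning: cross_val_score fits a fresh copy of the estimator on each fold, built with sklearn.base.clone, which rebuilds the object from get_params(deep=False). The sketch below uses two hypothetical classes to show a duck-typed estimator with get_params cloning fine, while one without it makes clone raise TypeError:

```python
from sklearn.base import clone


class DuckTypedRegressor:
    """Hypothetical estimator that does not inherit from BaseEstimator."""

    def __init__(self, l=0.01):
        self.l = l

    def fit(self, X, y):
        return self

    def get_params(self, deep=False):
        # clone() rebuilds the object as DuckTypedRegressor(**get_params()),
        # so every constructor argument must be reported under its own name.
        return {"l": self.l}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


class NoParams:
    """Hypothetical estimator without get_params()."""

    def fit(self, X, y):
        return self


original = DuckTypedRegressor(l=0.1)
cloned = clone(original)  # a new, unfitted object with the same parameters

try:
    clone(NoParams())
    clone_failed = False
except TypeError:
    clone_failed = True  # clone refuses objects it cannot rebuild

print(cloned.l, clone_failed)
```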