'''
See LICENCE_BSD for licensing information

@author: Steven Davies 2012

'''

from Plugin import run
from sklearn import naive_bayes
from SklearnEvaluator import SklearnEvaluator

class NbEvaluator(SklearnEvaluator):
    '''
    Evaluator backed by scikit-learn's Gaussian naive Bayes classifier.

    Model details and predictions both refer to the class at index 1 of the
    fitted model's classes (NOTE(review): presumably the "positive" class of
    a binary problem -- confirm against SklearnEvaluator's label setup).
    '''

    def __init__(self):
        SklearnEvaluator.__init__(self)

    def create_model(self):
        '''Return a new, unfitted naive_bayes.GaussianNB model.'''
        return naive_bayes.GaussianNB()

    def get_model_details(self, model):
        '''
        Yield (label, value) pairs describing a fitted GaussianNB model.

        First yields the per-feature mean learned for class index 1
        (labelled "<name> mean"), then the per-feature variance for the same
        class (labelled "<name> variance"). Feature names are taken from
        self.evaluations, which is assumed to be populated by
        SklearnEvaluator with one entry per feature column.

        @param model: a fitted naive_bayes.GaussianNB instance
        '''
        for i, c in enumerate(model.theta_[1]):
            yield (self.evaluations[i] + ' mean', c)
        # scikit-learn 1.0 renamed the per-class variance attribute from
        # sigma_ to var_ (sigma_ was removed in 1.2). Prefer var_ and fall
        # back to sigma_ so both old and new scikit-learn versions work.
        variances = getattr(model, 'var_', None)
        if variances is None:
            variances = model.sigma_
        for i, c in enumerate(variances[1]):
            yield (self.evaluations[i] + ' variance', c)

    def predict(self, model, testing):
        '''
        Return the predicted probability of class index 1 for each sample.

        @param model: a fitted naive_bayes.GaussianNB instance
        @param testing: the samples to score, in the form accepted by
            model.predict_proba
        '''
        # predict_proba returns one column per class; keep column 1 only.
        return model.predict_proba(testing)[..., 1]


if __name__ == '__main__': # pragma: no cover
    # Run this evaluator as a standalone plugin via the shared Plugin runner.
    evaluator = NbEvaluator()
    run(evaluator)
