scikit-learn - How to correctly override and call a super method in Python
First, the problem at hand. I am writing a wrapper for a scikit-learn class, and I am having problems with the right syntax. What I am trying to achieve is an override of the fit_transform function that alters the input slightly and then calls its super method with the new parameters:
    from sklearn.feature_extraction.text import TfidfVectorizer

    class TidfVectorizerWrapper(TfidfVectorizer):
        def __init__(self):
            TfidfVectorizer.__init__(self)  # necessary?

        def fit_transform(self, x, y=None, **fit_params):
            x = [content.split('\t')[0] for content in x]  # filtering the input
            # critical part: the IDE flags fit_params as 'unexpected arguments'
            return TfidfVectorizer.fit_transform(self, x, y, fit_params)
The program crashes on the spot, starting with a multiprocessing exception that does not tell me anything useful. How do I do this correctly?
Additional info: the reason I need to wrap it this way is that I use sklearn.pipeline.FeatureUnion to collect my feature extractors before putting them into a sklearn.pipeline.Pipeline. A consequence of doing it this way is that I can only feed a single data set across all feature extractors -- but different extractors need different data. My solution was to feed the data in a separable format and filter out different parts in different extractors. If there is a better solution to this problem, I'd be happy to hear it.
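The separable-format idea described above can be sketched without scikit-learn. This is only an illustration: the record layout (text, a tab, then a second field) and the names `records`, `select_field`, `text_part`, and `meta_part` are assumptions, not part of the original code.

```python
# Hypothetical records: two tab-separated fields per sample.
records = ["some document text\tauthor=alice", "another document\tauthor=bob"]


def select_field(samples, index):
    """Return the tab-separated field at `index` for every sample."""
    return [sample.split('\t')[index] for sample in samples]


# Each extractor receives the same records but keeps only its own part.
text_part = select_field(records, 0)
meta_part = select_field(records, 1)

print(text_part)
print(meta_part)
```

In the wrapper from the question, the equivalent of `select_field(x, 0)` is the list comprehension at the top of fit_transform.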
Edit 1: Adding ** to unpack the dict does not seem to change anything.
Edit 2: I solved the remaining problem -- I needed to remove the constructor overload. Apparently, by trying to call the parent constructor, hoping to have all instance variables initialized correctly, I did the exact opposite. My wrapper had no idea what kind of parameters it could expect. Once I removed the superfluous call, everything worked out perfectly.
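A plausible explanation for why the constructor override hurt: scikit-learn estimators report their parameters by inspecting the signature of `__init__`, so a subclass whose `__init__` takes no arguments hides every parameter of the parent. The sketch below imitates that mechanism with the standard library only. `ToyEstimator`, `param_names`, and the two subclasses are hypothetical stand-ins, not scikit-learn code.

```python
import inspect


class ToyEstimator:
    def __init__(self, lowercase=True, max_features=None):
        self.lowercase = lowercase
        self.max_features = max_features

    @classmethod
    def param_names(cls):
        """List constructor parameters by inspecting __init__, skipping self."""
        signature = inspect.signature(cls.__init__)
        return sorted(name for name in signature.parameters if name != 'self')


class WrapperWithInit(ToyEstimator):
    def __init__(self):  # parameterless override hides the parent's parameters
        ToyEstimator.__init__(self)


class WrapperWithoutInit(ToyEstimator):
    pass  # inherits __init__, so the parameters stay visible


print(WrapperWithInit.param_names())     # []
print(WrapperWithoutInit.param_names())  # ['lowercase', 'max_features']
```

If the wrapper really needs its own constructor, it would have to accept and forward the same keyword parameters as the parent. Simply dropping the override, as in Edit 2, is the lighter fix.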
You forgot to unpack fit_params. Inside the method it is a dict, and you want to pass it through as keyword arguments, which requires the unpacking operator **.
    from sklearn.feature_extraction.text import TfidfVectorizer

    class TidfVectorizerWrapper(TfidfVectorizer):
        def fit_transform(self, x, y=None, **fit_params):
            x = [content.split('\t')[0] for content in x]  # filtering the input
            # ** unpacks the dict back into keyword arguments
            return TfidfVectorizer.fit_transform(self, x, y, **fit_params)
One other thing: instead of calling TfidfVectorizer's fit_transform directly, you can call the parent's version through super:
    from sklearn.feature_extraction.text import TfidfVectorizer

    class TidfVectorizerWrapper(TfidfVectorizer):
        def fit_transform(self, x, y=None, **fit_params):
            x = [content.split('\t')[0] for content in x]  # filtering the input
            # super() resolves the parent class, so it is not named explicitly
            return super(TidfVectorizerWrapper, self).fit_transform(x, y, **fit_params)
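For a version that runs without scikit-learn installed, the same pattern can be sketched with a toy parent class. `ToyVectorizer`, `ToyWrapper`, and the `suffix` keyword are made up for illustration. `super()` without arguments is the Python 3 spelling of the two-argument form used above.

```python
class ToyVectorizer:
    def fit_transform(self, x, y=None, **fit_params):
        # Hypothetical behaviour: upper-case each item, append an optional suffix.
        suffix = fit_params.get('suffix', '')
        return [item.upper() + suffix for item in x]


class ToyWrapper(ToyVectorizer):
    def fit_transform(self, x, y=None, **fit_params):
        # Keep only the part before the first tab, then delegate to the parent.
        x = [content.split('\t')[0] for content in x]
        return super().fit_transform(x, y, **fit_params)


result = ToyWrapper().fit_transform(["abc\tignored", "de\tdropped"], suffix="!")
print(result)  # ['ABC!', 'DE!']
```

The keyword argument `suffix` reaches the parent only because of the `**fit_params` unpacking in the wrapper's call.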
To understand this, check the following example:
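    def foo1(**kargs):
        print(kargs)

    def foo2(**kargs):
        foo1(**kargs)  # unpacked: forwarded as keyword arguments
        print('foo2')

    def foo3(**kargs):
        foo1(kargs)  # not unpacked: the dict is sent as one positional argument
        print('foo3')

    foo1(a=1, b=2)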
It prints the dictionary {'a': 1, 'b': 2}.
    foo2(a=1, b=2)
prints both the dictionary and foo2, but
    foo3(a=1, b=2)
raises an error, because we sent foo1 a positional argument equal to our dictionary, and foo1 does not accept such a thing. We could, however, do
    def foo4(**kargs):
        foo1(x=kargs)  # the whole dict becomes the value of keyword x
        print('foo4')
which works fine: calling foo4(a=1, b=2) prints a new dictionary {'x': {'a': 1, 'b': 2}}, followed by foo4.
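The failure mode of foo3 can also be checked directly by catching the exception. In this sketch foo1 is redefined to return its keyword dictionary instead of printing it (a small change made here so the result can be inspected), and `call_with_dict` is a hypothetical helper.

```python
def foo1(**kargs):
    return kargs


def call_with_dict(**kargs):
    """Pass the dict positionally, as foo3 does; this is expected to fail."""
    try:
        return foo1(kargs)
    except TypeError as error:
        return 'TypeError: {}'.format(error)


unpacked = foo1(**{'a': 1, 'b': 2})   # dict unpacked into keyword arguments
failed = call_with_dict(a=1, b=2)     # dict passed as one positional argument

print(unpacked)
print(failed)
```

The first call returns the dictionary unchanged, while the second returns the text of a TypeError, since foo1 takes no positional arguments.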