Skip to content

Commit 7a52b0e

Browse files
committed
Get the fields dynamically
1 parent 38483bc commit 7a52b0e

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

sparklingml/feature/python_pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ class SpacyAdvancedTokenizeTransformer(Model, HasInputCol, HasOutputCol):
142142
>>> str(tr.getLang())
143143
'en'
144144
>>> tr.getSpacyFields()
145-
[u'ancestors', ...
145+
['_', 'ancestors', ...
146146
>>> tr.setSpacyFields(["text", "lang_"])
147147
SpacyAdvancedTokenizeTransformer_...
148148
>>> r = tr.transform(df).head().c

sparklingml/transformation_functions.py

Lines changed: 6 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
11
from __future__ import unicode_literals
22

3+
import inspect
4+
import spacy
5+
36
from pyspark.rdd import ignore_unicode_prefix
47
from pyspark.sql.types import *
58

@@ -114,20 +117,9 @@ class SpacyAdvancedTokenize(TransformationFunction):
114117
[(u'a', None), (u'lang', '...'), (u'lower_', 'boo'), (u'text', 'boo')]
115118
"""
116119

117-
default_fields = [
118-
'ancestors', 'check_flag', 'children', 'cluster', 'conjuncts', 'dep',
119-
'ent_id', 'ent_iob', 'ent_type', 'has_repvec', 'has_vector', 'head',
120-
'i', 'idx', 'is_alpha', 'is_ancestor', 'is_ancestor_of', 'is_ascii',
121-
'is_bracket', 'is_digit', 'is_left_punct', 'is_lower', 'is_oov',
122-
'is_punct', 'is_quote', 'is_right_punct', 'is_space', 'is_stop',
123-
'is_title', 'lang', 'lang_', 'left_edge', 'lefts', 'lemma',
124-
'lemma_', 'lex_id', 'like_email', 'like_num', 'like_url',
125-
'lower', 'lower_', 'n_lefts', 'n_rights', 'nbor', 'norm',
126-
'norm_', 'orth', 'orth_', 'pos', 'pos_', 'prefix', 'prefix_',
127-
'prob', 'rank', 'repvec', 'right_edge', 'rights', 'sentiment', 'shape',
128-
'shape_', 'similarity', 'string', 'subtree', 'suffix', 'suffix_',
129-
'tag', 'tag_', 'text', 'text_with_ws', 'vector', 'vector_norm',
130-
'vocab', 'whitespace_']
120+
default_fields = map(
121+
lambda x: x[0],
122+
inspect.getmembers(spacy.tokens.Token, lambda x: "<attribute '" in repr(x)))
131123

132124
@classmethod
133125
def setup(cls, sc, session, *args):

0 commit comments

Comments
 (0)