Last active
September 21, 2020 17:10
-
-
Save thmavri/8eb7eeda0d6491c54cfc1b8cc0cb5c00 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#determine the labels | |
import pyvw #vw python interface | |
# Integer class labels used as oracle/prediction values by the sequence labeler.
# NOTE(review): the names presumably stand for destination, property and
# facility (e.g. "amsterdam", "hotel", "wifi") -- confirm with the author.
DEST = 1
PROP = 2
FAC = 3
... | |
#create the class for the Sequence Labeler | |
class SequenceLabeler(pyvw.SearchTask):
    """Sequence labeling task for the pyvw learning-to-search interface.

    Each input sentence is a list of ``(label, word)`` pairs; one label is
    predicted per word, in order.
    """

    def __init__(self, vw, sch, num_actions):
        """Initialize the search task and configure its options.

        vw, sch and num_actions are forwarded unchanged to
        ``pyvw.SearchTask.__init__``.
        """
        # You must initialize the parent class; this automatically stores
        # self.sch <- sch and self.vw <- vw.
        pyvw.SearchTask.__init__(self, vw, sch, num_actions)
        # Set whatever options you want.
        sch.set_options(sch.AUTO_HAMMING_LOSS | sch.AUTO_CONDITION_FEATURES)

    def _run(self, sentence):
        """Predict one label per ``(label, word)`` pair in ``sentence``.

        It's called _run to remind you that you shouldn't call it directly.
        Returns the list of predictions, in sentence order.
        """
        output = []
        for n, (pos, word) in enumerate(sentence):
            # Use "with ... as ..." to guarantee that the example is finished
            # properly.
            with self.vw.example({'w': [word]}) as ex:
                # Tags are 1-based (my_tag=n+1), so tags n and n-1 refer to
                # the predictions made for the previous two words.
                pred = self.sch.predict(examples=ex,
                                        my_tag=n + 1,
                                        oracle=pos,
                                        condition=[(n, 'p'), (n - 1, 'q')])
                output.append(pred)
        return output
... | |
#build the training set | |
... | |
... | |
... | |
my_dataset #training set | |
test_set | |
# train commands
# 3 is the number of labels
vw = pyvw.vw("--search 3 --search_task hook --ring_size 1024")
sequenceLabeler = vw.init_search_task(SequenceLabeler)

# actual training: two passes over the training set
for _ in range(2):
    sequenceLabeler.learn(my_dataset)

# predict
# NOTE(review): the 0 in each pair is a placeholder label; presumably it is
# not used at prediction time -- confirm against the pyvw documentation.
test_example = [(0, w) for w in "hotel amsterdam wifi".split()]
print(test_example)
# [(0, 'hotel'), (0, 'amsterdam'), (0, 'wifi')]
out = sequenceLabeler.predict(test_example)
print(out)
# [2, 1, 3]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi Themis
Have you put up the complete code somewhere?
We have been trying learning2search on the ATIS dataset and are facing a weird problem.
We are trying to use Learning2Search from vowpal-wabbit for Named Entity Recognition on the ATIS dataset. We are using the Python version from pyvw.
In ATIS there are 127 Entities (including an 'Others' category). The training set has 4978 and test has 893 sentences.
Sample sentence: 'i want to fly from boston at DIGITDIGITDIGIT am and arrive in denver at DIGITDIGITDIGITDIGIT in the morning'.
Sample labels: 'O O O O O B-fromloc.city_name O B-depart_time.time I-depart_time.time O O O B-toloc.city_name O B-arrive_time.time O O B-arrive_time.period_of_day'.
Labels are mapped to numbers from 0 - 126.
However, when we train a SequenceLabeler object on the training set and use the predict() method on sentences from the test set, we only see labels 1 & 2 as outputs, which doesn't make sense (see attached image). We tried using this on a different dataset, and had similar results.
Here is a link to our IPython Notebook with the entire flow
Any pointers would be greatly appreciated - we're clearly doing something incorrect, but we're not sure what.
Thanks!