@@ -234,38 +234,39 @@ PYBIND11_MODULE(flashlight_lib_text_decoder, m) {
234234 .def_readwrite (" sil_score" , &LexiconDecoderOptions::silScore)
235235 .def_readwrite (" log_add" , &LexiconDecoderOptions::logAdd)
236236 .def_readwrite (" criterion_type" , &LexiconDecoderOptions::criterionType)
237- .def (py::pickle (
238- [](const LexiconDecoderOptions& p) { // __getstate__
239- return py::make_tuple (
240- p.beamSize ,
241- p.beamSizeToken ,
242- p.beamThreshold ,
243- p.lmWeight ,
244- p.wordScore ,
245- p.unkScore ,
246- p.silScore ,
247- p.logAdd ,
248- p.criterionType );
249- },
250- [](py::tuple t) { // __setstate__
251- if (t.size () != 9 ) {
252- throw std::runtime_error (
253- " Cannot run __setstate__ on LexiconDecoderOptions - "
254- " insufficient arguments provided." );
255- }
256- LexiconDecoderOptions opts = {
257- t[0 ].cast <int >(), // beamSize
258- t[1 ].cast <int >(), // beamSizeToken
259- t[2 ].cast <double >(), // beamThreshold
260- t[3 ].cast <double >(), // lmWeight
261- t[4 ].cast <double >(), // wordScore
262- t[5 ].cast <double >(), // unkScore
263- t[6 ].cast <double >(), // silScore
264- t[7 ].cast <bool >(), // logAdd
265- t[8 ].cast <CriterionType>() // criterionType
266- };
267- return opts;
268- }));
237+ .def (
238+ py::pickle (
239+ [](const LexiconDecoderOptions& p) { // __getstate__
240+ return py::make_tuple (
241+ p.beamSize ,
242+ p.beamSizeToken ,
243+ p.beamThreshold ,
244+ p.lmWeight ,
245+ p.wordScore ,
246+ p.unkScore ,
247+ p.silScore ,
248+ p.logAdd ,
249+ p.criterionType );
250+ },
251+ [](py::tuple t) { // __setstate__
252+ if (t.size () != 9 ) {
253+ throw std::runtime_error (
254+ " Cannot run __setstate__ on LexiconDecoderOptions - "
255+ " insufficient arguments provided." );
256+ }
257+ LexiconDecoderOptions opts = {
258+ t[0 ].cast <int >(), // beamSize
259+ t[1 ].cast <int >(), // beamSizeToken
260+ t[2 ].cast <double >(), // beamThreshold
261+ t[3 ].cast <double >(), // lmWeight
262+ t[4 ].cast <double >(), // wordScore
263+ t[5 ].cast <double >(), // unkScore
264+ t[6 ].cast <double >(), // silScore
265+ t[7 ].cast <bool >(), // logAdd
266+ t[8 ].cast <CriterionType>() // criterionType
267+ };
268+ return opts;
269+ }));
269270
270271 py::class_<LexiconFreeDecoderOptions>(m, " LexiconFreeDecoderOptions" )
271272 .def (
@@ -294,34 +295,35 @@ PYBIND11_MODULE(flashlight_lib_text_decoder, m) {
294295 .def_readwrite (" log_add" , &LexiconFreeDecoderOptions::logAdd)
295296 .def_readwrite (
296297 " criterion_type" , &LexiconFreeDecoderOptions::criterionType)
297- .def (py::pickle (
298- [](const LexiconFreeDecoderOptions& p) { // __getstate__
299- return py::make_tuple (
300- p.beamSize ,
301- p.beamSizeToken ,
302- p.beamThreshold ,
303- p.lmWeight ,
304- p.silScore ,
305- p.logAdd ,
306- p.criterionType );
307- },
308- [](py::tuple t) { // __setstate__
309- if (t.size () != 7 ) {
310- throw std::runtime_error (
311- " Cannot run __setstate__ on LexiconFreeDecoderOptions - "
312- " insufficient arguments provided." );
313- }
314- LexiconFreeDecoderOptions opts = {
315- t[0 ].cast <int >(), // beamSize
316- t[1 ].cast <int >(), // beamSizeToken
317- t[2 ].cast <double >(), // beamThreshold
318- t[3 ].cast <double >(), // lmWeight
319- t[4 ].cast <double >(), // silScore
320- t[5 ].cast <bool >(), // logAdd
321- t[6 ].cast <CriterionType>() // criterionType
322- };
323- return opts;
324- }));
298+ .def (
299+ py::pickle (
300+ [](const LexiconFreeDecoderOptions& p) { // __getstate__
301+ return py::make_tuple (
302+ p.beamSize ,
303+ p.beamSizeToken ,
304+ p.beamThreshold ,
305+ p.lmWeight ,
306+ p.silScore ,
307+ p.logAdd ,
308+ p.criterionType );
309+ },
310+ [](py::tuple t) { // __setstate__
311+ if (t.size () != 7 ) {
312+ throw std::runtime_error (
313+ " Cannot run __setstate__ on LexiconFreeDecoderOptions - "
314+ " insufficient arguments provided." );
315+ }
316+ LexiconFreeDecoderOptions opts = {
317+ t[0 ].cast <int >(), // beamSize
318+ t[1 ].cast <int >(), // beamSizeToken
319+ t[2 ].cast <double >(), // beamThreshold
320+ t[3 ].cast <double >(), // lmWeight
321+ t[4 ].cast <double >(), // silScore
322+ t[5 ].cast <bool >(), // logAdd
323+ t[6 ].cast <CriterionType>() // criterionType
324+ };
325+ return opts;
326+ }));
325327
326328 py::class_<DecodeResult>(m, " DecodeResult" )
327329 .def (py::init<int >(), " length" _a)
@@ -404,38 +406,39 @@ PYBIND11_MODULE(flashlight_lib_text_decoder, m) {
404406 .def (
405407 " get_options" ,
406408 &LexiconFreeDecoder::getOptions)
407- .def (py::pickle (
408- [](const LexiconFreeDecoder& p) { // __getstate__
409- if (p.getAllFinalHypothesis ().size () != 0 ) {
410- throw std::runtime_error (
411- " LexiconFreeDecoder: cannot pickle decoder that has state" );
412- }
413- if (!std::dynamic_pointer_cast<ZeroLM>(p.getLMPtr ())) {
414- throw std::runtime_error (
415- " LexiconFreeDecoder: cannot pickle a decoder with an "
416- " integrated language model that is not ZeroLM" );
417- }
418- return py::make_tuple (
419- p.getOptions (),
420- p.getSilIdx (),
421- p.getBlankIdx (),
422- p.getTransitions ());
423- },
424- [](py::tuple t) { // __setstate__
425- if (t.size () != 4 ) {
426- throw std::runtime_error (
427- " Cannot run __setstate__ on LexiconFreeDecoder - "
428- " insufficient arguments provided." );
429- }
430-
431- return LexiconFreeDecoder (
432- t[0 ].cast <LexiconFreeDecoderOptions>(), // options
433- std::make_shared<ZeroLM>(), // lm
434- t[1 ].cast <int >(), // silIdx
435- t[2 ].cast <int >(), // blankIdx
436- t[3 ].cast <std::vector<float >>() // transitions
437- );
438- }));
409+ .def (
410+ py::pickle (
411+ [](const LexiconFreeDecoder& p) { // __getstate__
412+ if (p.getAllFinalHypothesis ().size () != 0 ) {
413+ throw std::runtime_error (
414+ " LexiconFreeDecoder: cannot pickle decoder that has state" );
415+ }
416+ if (!std::dynamic_pointer_cast<ZeroLM>(p.getLMPtr ())) {
417+ throw std::runtime_error (
418+ " LexiconFreeDecoder: cannot pickle a decoder with an "
419+ " integrated language model that is not ZeroLM" );
420+ }
421+ return py::make_tuple (
422+ p.getOptions (),
423+ p.getSilIdx (),
424+ p.getBlankIdx (),
425+ p.getTransitions ());
426+ },
427+ [](py::tuple t) { // __setstate__
428+ if (t.size () != 4 ) {
429+ throw std::runtime_error (
430+ " Cannot run __setstate__ on LexiconFreeDecoder - "
431+ " insufficient arguments provided." );
432+ }
433+
434+ return LexiconFreeDecoder (
435+ t[0 ].cast <LexiconFreeDecoderOptions>(), // options
436+ std::make_shared<ZeroLM>(), // lm
437+ t[1 ].cast <int >(), // silIdx
438+ t[2 ].cast <int >(), // blankIdx
439+ t[3 ].cast <std::vector<float >>() // transitions
440+ );
441+ }));
439442
440443 // Seq2seq Decoding
441444 py::class_<LexiconSeq2SeqDecoderOptions>(m, " LexiconSeq2SeqDecoderOptions" )
0 commit comments