KuangDW
add laser tool
2aebc50
%module sentencepiece
%include exception.i
%{
#include <atomic>
#include <iostream>
#include <algorithm>
#include <functional>
#include <limits>
#include <cmath>
#include <thread>
#include <vector>
#include <sentencepiece_processor.h>
#include <sentencepiece_trainer.h>
namespace {
PyObject* kUnicodeInput = reinterpret_cast<PyObject* >(0x1);
PyObject* kByteInput = reinterpret_cast<PyObject* >(0x2);
using BytesArray = std::vector<sentencepiece::util::bytes>;
inline void ReleaseResultObject(PyObject *obj) {
if (obj != nullptr && obj != kUnicodeInput && obj != kByteInput) {
Py_XDECREF(obj);
}
}
class PyInputString {
public:
explicit PyInputString(PyObject* obj) {
if (PyUnicode_Check(obj)) {
str_ = const_cast<char *>(PyUnicode_AsUTF8AndSize(obj, &size_));
input_type_ = kUnicodeInput;
} else if (PyBytes_Check(obj)) {
PyBytes_AsStringAndSize(obj, &str_, &size_);
input_type_ = kByteInput;
} else {
str_ = nullptr;
}
}
absl::string_view str() const { return absl::string_view(data(), size()); }
const char* data() const { return str_; }
Py_ssize_t size() const { return size_; }
bool IsAvalable() const { return str_ != nullptr; }
PyObject *input_type() const { return input_type_; }
static bool IsUnicode(PyObject *resultobj) {
return (resultobj == nullptr || resultobj == kUnicodeInput);
}
private:
PyObject* input_type_ = nullptr;
char* str_ = nullptr;
Py_ssize_t size_ = 0;
};
PyObject* MakePyOutputString(const std::string& output,
PyObject *resultobj) {
if (PyInputString::IsUnicode(resultobj)) {
return PyUnicode_FromStringAndSize(output.data(), output.size());
}
return PyBytes_FromStringAndSize(output.data(), output.size());
}
PyObject* MakePyOutputBytes(const sentencepiece::util::bytes& output) {
return PyBytes_FromStringAndSize(output.data(), output.size());
}
int ToSwigError(sentencepiece::util::StatusCode code) {
switch (code) {
case sentencepiece::util::StatusCode::kNotFound:
return SWIG_IOError;
case sentencepiece::util::StatusCode::kOutOfRange:
return SWIG_IndexError;
case sentencepiece::util::StatusCode::kInvalidArgument:
return SWIG_SyntaxError;
default:
return SWIG_RuntimeError;
}
return SWIG_RuntimeError;
}
class PySentenceIterator : public sentencepiece::SentenceIterator {
public:
PySentenceIterator(PyObject *iter) : iter_(iter) {
item_ = PyIter_Next(iter_);
CopyValue();
}
~PySentenceIterator() {
// Py_XDECREF(iter_);
}
bool done() const override {
return item_ == nullptr;
}
void Next() override {
item_ = PyIter_Next(iter_);
CopyValue();
}
const std::string &value() const override {
return value_;
}
sentencepiece::util::Status status() const override {
return status_;
}
private:
void CopyValue() {
if (item_ == nullptr) return;
const PyInputString ustring(item_);
if (ustring.IsAvalable()) {
const char *data = ustring.data();
size_t size = ustring.size();
while (size > 0) {
if (data[size - 1] == '\r' || data[size - 1] == '\n')
--size;
else
break;
}
value_.assign(data, size);
} else {
status_ = sentencepiece::util::Status(sentencepiece::util::StatusCode::kInternal,
"Not a string.");
}
Py_XDECREF(item_);
}
PyObject *iter_ = nullptr;
PyObject *item_ = nullptr;
std::string value_;
sentencepiece::util::Status status_;
};
inline void RewriteIds(const sentencepiece::SentencePieceProcessor &sp,
std::vector<int> *ids,
bool add_bos, bool add_eos, bool reverse, bool emit_unk_piece) {
if (!add_bos && !add_eos && !reverse) return;
if (reverse) std::reverse(ids->begin(), ids->end());
if (add_bos) ids->insert(ids->begin(), sp.bos_id());
if (add_eos) ids->push_back(sp.eos_id());
}
inline void RewriteIds(const sentencepiece::SentencePieceProcessor &sp,
std::vector<std::string> *pieces,
bool add_bos, bool add_eos, bool reverse, bool emit_unk_piece) {
if (!add_bos && !add_eos && !reverse && !emit_unk_piece) return;
if (reverse) std::reverse(pieces->begin(), pieces->end());
if (add_bos) pieces->insert(pieces->begin(), sp.IdToPiece(sp.bos_id()));
if (add_eos) pieces->push_back(sp.IdToPiece(sp.eos_id()));
if (emit_unk_piece) {
const auto &unk = sp.IdToPiece(sp.unk_id());
for (auto &piece : *pieces) {
const int id = sp.PieceToId(piece);
if (id == sp.unk_id()) {
piece = unk;
}
}
}
}
inline void RewriteIds(const sentencepiece::SentencePieceProcessor &sp,
sentencepiece::util::bytes *proto,
bool add_bos, bool add_eos, bool reverse, bool emit_unk_piece) {
if (add_bos || add_eos || reverse || emit_unk_piece) {
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kUnimplemented,
"add_bos, add_eos, reverse, and emit_unk_piece is not supported in proto API");
}
}
inline void RewriteIds(const sentencepiece::SentencePieceProcessor &sp,
sentencepiece::ImmutableSentencePieceText *proto,
bool add_bos, bool add_eos, bool reverse, bool emit_unk_piece) {
if (add_bos || add_eos || reverse || emit_unk_piece) {
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kUnimplemented,
"add_bos, add_eos, reverse, and emit_unk_piece is not supported in proto API");
}
}
inline void CheckIds(const std::vector<int> &ids, int num_pieces) {
for (int id : ids) {
if (id < 0 || id >= num_pieces) {
throw sentencepiece::util::Status(
sentencepiece::util::StatusCode::kOutOfRange,
"piece id is out of range.");
}
}
}
inline void CheckIds(const std::vector<absl::string_view> &ids, int num_pieces) {}
template <typename T>
inline void ConvertToUnicodeSpans(T *proto) {}
template <>
inline void ConvertToUnicodeSpans(sentencepiece::ImmutableSentencePieceText *proto) {
proto->ConvertToUnicodeSpans();
}
template <>
inline void ConvertToUnicodeSpans(sentencepiece::ImmutableNBestSentencePieceText *proto) {
proto->ConvertToUnicodeSpans();
}
class ThreadPool {
public:
explicit ThreadPool(size_t request_size) :
request_size_(request_size) {}
virtual ~ThreadPool() {
for (auto &task : tasks_) {
task.join();
}
}
void Schedule(std::function<void()> closure) {
static constexpr size_t kMinThreadSize = 2;
if (request_size_ < kMinThreadSize) {
closure();
} else {
tasks_.emplace_back(closure);
}
}
private:
size_t request_size_ = 0;
std::vector<std::thread> tasks_;
};
template <typename T>
inline void InitNumThreads(const std::vector<T> &ins, int *num_threads) {
if (*num_threads < 0) {
*num_threads = std::thread::hardware_concurrency();
}
*num_threads = std::max<int>(1,
std::min<int>({*num_threads,
static_cast<int>(ins.size()), 256}));
}
#define DEFINE_ENCODE_BATCH_FUNC_IMPL(FuncName, InType, OutType) \
std::vector<OutType> outs(ins.size()); \
InitNumThreads(ins, &num_threads); \
{ \
ThreadPool pool(ins.size()); \
std::atomic<size_t> index = 0; \
for (int n = 0; n < num_threads; ++n) { \
pool.Schedule([&]() { \
size_t i = 0; \
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) { \
auto out = enable_sampling ? \
self->Sample##FuncName(ins[i], \
nbest_size, alpha) : \
self->FuncName(ins[i]); \
RewriteIds(*self, &out, add_bos, add_eos, reverse, \
emit_unk_piece); \
ConvertToUnicodeSpans(&out); \
outs[i] = std::move(out); \
} \
}); \
} \
} \
return outs;
#define DEFINE_DECODE_BATCH_FUNC_IMPL(FuncName, InType, OutType) \
std::vector<OutType> outs(ins.size()); \
InitNumThreads(ins, &num_threads); \
{ \
std::atomic<size_t> index = 0; \
ThreadPool pool(ins.size()); \
for (int n = 0; n < num_threads; ++n) { \
pool.Schedule([&]() { \
size_t i = 0; \
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) { \
CheckIds(ins[i], self->GetPieceSize()); \
auto out = self->FuncName(ins[i]); \
ConvertToUnicodeSpans(&out); \
outs[i] = std::move(out); \
} \
}); \
} \
} \
return outs;
} // namespace
%}
%exception {
try {
$action
ReleaseResultObject(resultobj);
}
catch (const sentencepiece::util::Status &status) {
SWIG_exception(ToSwigError(status.code()), status.ToString().c_str());
}
}
%apply unsigned int { uint32_t }
%ignore sentencepiece::util::Status;
%ignore sentencepiece::util::StatusCode;
%ignore absl::string_view;
%ignore std::string_view;
%ignore sentencepiece::SentencePieceText;
%ignore sentencepiece::NormalizerSpec;
%ignore sentencepiece::TrainerSpec;
%ignore sentencepiece::SentencePieceProcessor::status;
%ignore sentencepiece::ImmutableSentencePieceText::mutable_proto;
%ignore sentencepiece::ImmutableSentencePieceText::pieces() const;
%ignore sentencepiece::ImmutableSentencePieceText::ConvertToUnicodeSpans;
%ignore sentencepiece::ImmutableNBestSentencePieceText::mutable_proto;
%ignore sentencepiece::ImmutableNBestSentencePieceText::nbests() const;
%ignore sentencepiece::ImmutableNBestSentencePieceText::ConvertToUnicodeSpans;
%ignore sentencepiece::SentencePieceProcessor::Encode;
%ignore sentencepiece::SentencePieceProcessor::SampleEncode;
%ignore sentencepiece::SentencePieceProcessor::NBestEncode;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAndScore;
%ignore sentencepiece::SentencePieceProcessor::Decode;
%ignore sentencepiece::SentencePieceProcessor::EncodeAsPieces;
%ignore sentencepiece::SentencePieceProcessor::EncodeAsIds;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAsIds;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAsPieces;
%ignore sentencepiece::SentencePieceProcessor::NBestEncodeAsIds;
%ignore sentencepiece::SentencePieceProcessor::NBestEncodeAsPieces;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAndScoreAsIds;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAndScoreAsPieces;
%ignore sentencepiece::SentencePieceProcessor::DecodeIds;
%ignore sentencepiece::SentencePieceProcessor::DecodePieces;
%ignore sentencepiece::SentencePieceProcessor::EncodeAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::NBestEncodeAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAndScoreAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::DecodePiecesAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsSerializedProto;
%ignore sentencepiece::SentencePieceProcessor::EncodeAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::NBestEncodeAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::SampleEncodeAndScoreAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::DecodePiecesAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::DecodeIdsAsImmutableProto;
%ignore sentencepiece::SentencePieceProcessor::Normalize;
%ignore sentencepiece::SentencePieceProcessor::NormalizeWithOffsets;
%ignore sentencepiece::SentencePieceProcessor::model_proto;
%ignore sentencepiece::SentencePieceProcessor::mutable_normalizer_spec;
%ignore sentencepiece::SentencePieceProcessor::Load;
%ignore sentencepiece::SentencePieceProcessor::LoadOrDie;
%ignore sentencepiece::SentencePieceProcessor::SetModel;
%ignore sentencepiece::SentencePieceProcessor::SetNormalizer;
%ignore sentencepiece::pretokenizer::PretokenizerForTrainingInterface;
%ignore sentencepiece::SentenceIterator;
%ignore sentencepiece::ConvertToUnicodeSpans;
%ignore sentencepiece::SentencePieceTrainer::Train;
%ignore sentencepiece::SentencePieceTrainer::GetNormalizerSpec;
%ignore sentencepiece::SentencePieceTrainer::PopulateNormalizerSpec;
%ignore sentencepiece::SentencePieceTrainer::MergeSpecsFromArgs;
%ignore sentencepiece::SentencePieceTrainer::SetProtoField;
%ignore sentencepiece::SentencePieceTrainer::PopulateModelTypeFromString;
%ignore sentencepiece::SentencePieceTrainer::PieceProcecssor;
%ignore sentencepiece::SentencePieceTrainer::SetPretokenizerForTraining;
%ignore sentencepiece::SentencePieceTrainer::GetPretokenizerForTraining;
%ignore sentencepiece::ConvertToUnicodeAlignment;
%ignore sentencepiece::SentencePieceNormalizer::Load;
%ignore sentencepiece::SentencePieceNormalizer::Normalize;
%ignore sentencepiece::SentencePieceNormalizer::mutable_normalizer_spec;
%ignore sentencepiece::io::LoadModelProto;
%ignore sentencepiece::io::SaveModelProto;
%extend sentencepiece::SentencePieceProcessor {
sentencepiece::util::Status LoadFromFile(absl::string_view arg) {
return $self->Load(arg);
}
/////////////////////////////////////////////////////////////////////////////
// EncodeAs* (Single request)
std::vector<int> _EncodeAsIds(absl::string_view text,
bool enable_sampling,
int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto ids = enable_sampling ?
$self->SampleEncodeAsIds(text, nbest_size, alpha) :
$self->EncodeAsIds(text);
RewriteIds(*$self, &ids, add_bos, add_eos, reverse, emit_unk_piece);
return ids;
}
std::vector<std::string> _EncodeAsPieces(absl::string_view text,
bool enable_sampling,
int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto pieces = enable_sampling ?
$self->SampleEncodeAsPieces(text, nbest_size, alpha) :
$self->EncodeAsPieces(text);
RewriteIds(*$self, &pieces, add_bos, add_eos, reverse, emit_unk_piece);
return pieces;
}
sentencepiece::util::bytes _EncodeAsSerializedProto(absl::string_view text,
bool enable_sampling,
int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto proto = enable_sampling ?
$self->SampleEncodeAsSerializedProto(text, nbest_size, alpha) :
$self->EncodeAsSerializedProto(text);
RewriteIds(*$self, &proto, add_bos, add_eos, reverse, emit_unk_piece);
return proto;
}
sentencepiece::ImmutableSentencePieceText
_EncodeAsImmutableProto(absl::string_view text,
bool enable_sampling,
int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto proto = enable_sampling ?
$self->SampleEncodeAsImmutableProto(text, nbest_size, alpha) :
$self->EncodeAsImmutableProto(text);
proto.ConvertToUnicodeSpans();
RewriteIds(*$self, &proto, add_bos, add_eos, reverse, emit_unk_piece);
return proto;
}
/////////////////////////////////////////////////////////////////////////////
// EncodeAs* (Batch request)
std::vector<std::vector<int>> _EncodeAsIdsBatch(
const std::vector<absl::string_view> &ins, int num_threads,
bool enable_sampling, int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
DEFINE_ENCODE_BATCH_FUNC_IMPL(EncodeAsIds,
absl::string_view, std::vector<int>);
}
std::vector<std::vector<std::string>> _EncodeAsPiecesBatch(
const std::vector<absl::string_view> &ins, int num_threads,
bool enable_sampling, int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
DEFINE_ENCODE_BATCH_FUNC_IMPL(EncodeAsPieces,
absl::string_view, std::vector<std::string>);
}
BytesArray _EncodeAsSerializedProtoBatch(
const std::vector<absl::string_view> &ins, int num_threads,
bool enable_sampling, int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
DEFINE_ENCODE_BATCH_FUNC_IMPL(EncodeAsSerializedProto,
absl::string_view,
sentencepiece::util::bytes);
}
std::vector<sentencepiece::ImmutableSentencePieceText>
_EncodeAsImmutableProtoBatch(
const std::vector<absl::string_view> &ins, int num_threads,
bool enable_sampling, int nbest_size, float alpha,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
DEFINE_ENCODE_BATCH_FUNC_IMPL(EncodeAsImmutableProto,
absl::string_view,
sentencepiece::ImmutableSentencePieceText);
}
/////////////////////////////////////////////////////////////////////////////
// DecodeAs* (Single request)
std::string _DecodeIds(const std::vector<int> &ids) const {
CheckIds(ids, $self->GetPieceSize());
return $self->DecodeIds(ids);
}
sentencepiece::util::bytes _DecodeIdsAsBytes(const std::vector<int> &ids) const {
CheckIds(ids, $self->GetPieceSize());
return $self->DecodeIds(ids);
}
std::string _DecodePieces(const std::vector<absl::string_view> &pieces) const {
return $self->DecodePieces(pieces);
}
sentencepiece::util::bytes _DecodeIdsAsSerializedProto(
const std::vector<int> &ids) const {
CheckIds(ids, $self->GetPieceSize());
return $self->DecodeIdsAsSerializedProto(ids);
}
sentencepiece::util::bytes _DecodePiecesAsSerializedProto(
const std::vector<absl::string_view> &pieces) const {
CheckIds(pieces, $self->GetPieceSize());
return $self->DecodePiecesAsSerializedProto(pieces);
}
sentencepiece::ImmutableSentencePieceText _DecodeIdsAsImmutableProto(
const std::vector<int> &ids) const {
CheckIds(ids, $self->GetPieceSize());
auto proto = $self->DecodeIdsAsImmutableProto(ids);
proto.ConvertToUnicodeSpans();
return proto;
}
sentencepiece::ImmutableSentencePieceText _DecodePiecesAsImmutableProto(
const std::vector<absl::string_view> &pieces) const {
CheckIds(pieces, $self->GetPieceSize());
auto proto= $self->DecodePiecesAsImmutableProto(pieces);
proto.ConvertToUnicodeSpans();
return proto;
}
/////////////////////////////////////////////////////////////////////////////
// DecodeAs* (Batch request)
std::vector<std::string> _DecodeIdsBatch(
const std::vector<std::vector<int>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIds, int, std::string);
}
BytesArray _DecodeIdsAsBytesBatch(
const std::vector<std::vector<int>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIds, int, std::string);
}
BytesArray _DecodeIdsAsSerializedProtoBatch(
const std::vector<std::vector<int>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsSerializedProto, int,
sentencepiece::util::bytes);
}
std::vector<sentencepiece::ImmutableSentencePieceText>
_DecodeIdsAsImmutableProtoBatch(
const std::vector<std::vector<int>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodeIdsAsImmutableProto, int,
sentencepiece::ImmutableSentencePieceText);
}
std::vector<std::string> _DecodePiecesBatch(
const std::vector<std::vector<absl::string_view>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePieces, std::string, std::string);
}
BytesArray _DecodePiecesAsSerializedProtoBatch(
const std::vector<std::vector<absl::string_view>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePiecesAsSerializedProto, std::string,
sentencepiece::util::bytes);
}
std::vector<sentencepiece::ImmutableSentencePieceText>
_DecodePiecesAsImmutableProtoBatch(
const std::vector<std::vector<absl::string_view>> &ins, int num_threads) const {
DEFINE_DECODE_BATCH_FUNC_IMPL(DecodePiecesAsImmutableProto, std::string,
sentencepiece::ImmutableSentencePieceText);
}
////////////////////////////////////////////////////////////////////////////
// NBestEncodeAs* (Single request)
std::vector<std::vector<int>>
_NBestEncodeAsIds(absl::string_view text,
int nbest_size,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto idss = $self->NBestEncodeAsIds(text, nbest_size);
for (auto &ids : idss) {
RewriteIds(*$self, &ids, add_bos, add_eos, reverse, emit_unk_piece);
}
return idss;
}
std::vector<std::vector<std::string>>
_NBestEncodeAsPieces(absl::string_view text,
int nbest_size,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto piecess = $self->NBestEncodeAsPieces(text, nbest_size);
for (auto &pieces : piecess) {
RewriteIds(*$self, &pieces, add_bos, add_eos, reverse, emit_unk_piece);
}
return piecess;
}
sentencepiece::util::bytes
_NBestEncodeAsSerializedProto(absl::string_view text,
int nbest_size,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
RewriteIds(*$self, static_cast<sentencepiece::util::bytes *>(nullptr),
add_bos, add_eos, reverse, emit_unk_piece);
return $self->NBestEncodeAsSerializedProto(text, nbest_size);
}
sentencepiece::ImmutableNBestSentencePieceText
_NBestEncodeAsImmutableProto(absl::string_view text,
int nbest_size,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
RewriteIds(*$self, static_cast<sentencepiece::ImmutableSentencePieceText *>(nullptr),
add_bos, add_eos, reverse, emit_unk_piece);
auto proto = $self->NBestEncodeAsImmutableProto(text, nbest_size);
proto.ConvertToUnicodeSpans();
return proto;
}
/////////////////////////////////////////////////////////////////////////////
// SampleEncodeAndScoreAs* (Single request)
std::vector<std::pair<std::vector<int>, float>>
_SampleEncodeAndScoreAsIds(absl::string_view text,
int num_samples, float alpha, bool wor,
bool include_best,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto idss = $self->SampleEncodeAndScoreAsIds(text, num_samples,
alpha, wor, include_best);
for (auto &ids : idss) {
RewriteIds(*$self, &ids.first, add_bos, add_eos, reverse, emit_unk_piece);
}
return idss;
}
std::vector<std::pair<std::vector<std::string>, float>>
_SampleEncodeAndScoreAsPieces(absl::string_view text,
int num_samples, float alpha, bool wor,
bool include_best,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
auto piecess = $self->SampleEncodeAndScoreAsPieces(text, num_samples,
alpha, wor, include_best);
for (auto &pieces : piecess) {
RewriteIds(*$self, &pieces.first, add_bos, add_eos, reverse, emit_unk_piece);
}
return piecess;
}
sentencepiece::util::bytes
_SampleEncodeAndScoreAsSerializedProto(absl::string_view text,
int num_samples, float alpha, bool wor,
bool include_best,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
RewriteIds(*$self, static_cast<sentencepiece::util::bytes *>(nullptr),
add_bos, add_eos, reverse, emit_unk_piece);
return $self->SampleEncodeAndScoreAsSerializedProto(text, num_samples,
alpha, wor, include_best);
}
sentencepiece::ImmutableNBestSentencePieceText
_SampleEncodeAndScoreAsImmutableProto(absl::string_view text,
int num_samples, float alpha, bool wor,
bool include_best,
bool add_bos, bool add_eos, bool reverse,
bool emit_unk_piece) const {
RewriteIds(*$self, static_cast<sentencepiece::util::bytes *>(nullptr),
add_bos, add_eos, reverse, emit_unk_piece);
auto proto = $self->SampleEncodeAndScoreAsImmutableProto(text, num_samples,
alpha, wor, include_best);
proto.ConvertToUnicodeSpans();
return proto;
}
// Normalize
std::string _Normalize(absl::string_view text) {
return $self->Normalize(text);
}
std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) {
std::pair<std::string, std::vector<size_t>> result;
$self->Normalize(text, &result.first, &result.second).IgnoreError();
return result;
}
// Calculate Entropy
float _CalculateEntropy(absl::string_view text, float alpha) {
return $self->CalculateEntropy(text, alpha);
}
std::vector<float> _CalculateEntropyBatch(const std::vector<absl::string_view> &ins,
float alpha, int num_threads) {
std::vector<float> outs(ins.size());
InitNumThreads(ins, &num_threads);
{
ThreadPool pool(ins.size());
std::atomic<size_t> index = 0;
for (int n = 0; n < num_threads; ++n) {
pool.Schedule([&]() {
size_t i = 0;
while ((i = std::atomic_fetch_add(&index, 1)) < outs.size()) {
outs[i] = self->CalculateEntropy(ins[i], alpha);
}
});
}
}
return outs;
}
// override normalizer_spec
sentencepiece::util::Status _OverrideNormalizerSpec(
const std::unordered_map<std::string, std::string> &args) {
sentencepiece::util::Status status;
for (const auto &[key, value] : args) {
status = sentencepiece::SentencePieceTrainer::SetProtoField(
key, value,
$self->mutable_normalizer_spec());
if (!status.ok()) return status;
}
return status;
}
%pythoncode {
def Init(self,
model_file=None,
model_proto=None,
out_type=int,
add_bos=False,
add_eos=False,
reverse=False,
emit_unk_piece=False,
enable_sampling=False,
nbest_size=-1,
alpha=0.1,
num_threads=-1):
"""Initialzie sentencepieceProcessor.
Args:
model_file: The sentencepiece model file path.
model_proto: The sentencepiece model serialized proto.
out_type: output type. int or str.
add_bos: Add <s> to the result (Default = false)
add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
reversing (if enabled).
reverse: Reverses the tokenized sequence (Default = false)
emit_unk_piece: Emits the unk literal string (Default = false)
nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
nbest_size = {0,1}: No sampling is performed.
nbest_size > 1: samples from the nbest_size results.
nbest_size < 0: assuming that nbest_size is infinite and samples
from the all hypothesis (lattice) using
forward-filtering-and-backward-sampling algorithm.
alpha: Soothing parameter for unigram sampling, and dropout probability of
merge operations for BPE-dropout.
num_threads: number of threads in batch processing (Default = -1, auto-detected)
"""
_sentencepiece_processor_init_native(self)
self._out_type = out_type
self._add_bos = add_bos
self._add_eos = add_eos
self._reverse = reverse
self._emit_unk_piece = emit_unk_piece
self._enable_sampling = enable_sampling
self._nbest_size = nbest_size
self._alpha = alpha
self._num_threads = num_threads
if model_file or model_proto:
self.Load(model_file=model_file, model_proto=model_proto)
def Encode(self,
input,
out_type=None,
add_bos=None,
add_eos=None,
reverse=None,
emit_unk_piece=None,
enable_sampling=None,
nbest_size=None,
alpha=None,
num_threads=None):
"""Encode text input to segmented ids or tokens.
Args:
input: input string. accepsts list of string.
out_type: output type. int or str.
add_bos: Add <s> to the result (Default = false)
add_eos: Add </s> to the result (Default = false) <s>/</s> is added after
reversing (if enabled).
reverse: Reverses the tokenized sequence (Default = false)
emit_unk_piece: Emits the unk literal string (Default = false)
nbest_size: sampling parameters for unigram. Invalid in BPE-Dropout.
nbest_size = {0,1}: No sampling is performed.
nbest_size > 1: samples from the nbest_size results.
nbest_size < 0: assuming that nbest_size is infinite and samples
from the all hypothesis (lattice) using
forward-filtering-and-backward-sampling algorithm.
alpha: Soothing parameter for unigram sampling, and merge probability for
BPE-dropout (probablity 'p' in BPE-dropout paper).
num_threads: the number of threads used in the batch processing (Default = -1).
"""
if out_type is None:
out_type = self._out_type
if add_bos is None:
add_bos = self._add_bos
if add_eos is None:
add_eos = self._add_eos
if reverse is None:
reverse = self._reverse
if emit_unk_piece is None:
emit_unk_piece = self._emit_unk_piece
if enable_sampling is None:
enable_sampling = self._enable_sampling
if nbest_size is None:
nbest_size = self._nbest_size
if alpha is None:
alpha = self._alpha
if num_threads is None:
num_threads = self._num_threads
if enable_sampling == True and (nbest_size is None or nbest_size == 0 or
nbest_size == 1 or alpha is None):
raise RuntimeError(
'When enable_sampling is True, We must specify "nbest_size > 1" or "nbest_size = -1", '
'and "alpha". "nbest_size" is enabled only on unigram mode ignored in BPE-dropout. '
'when "nbest_size = -1" , this method samples from all candidates on the lattice '
'instead of nbest segmentations.'
)
if num_threads is None or type(num_threads) is not int:
raise RuntimeError('num_threads must be int')
if type(input) is list:
if out_type is int:
return self._EncodeAsIdsBatch(input, num_threads, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type is str:
return self._EncodeAsPiecesBatch(input, num_threads, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'serialized_proto' or out_type == 'proto':
return self._EncodeAsSerializedProtoBatch(input, num_threads, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'immutable_proto':
return self._EncodeAsImmutableProtoBatch(input, num_threads, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type is int:
return self._EncodeAsIds(input, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type is str:
return self._EncodeAsPieces(input, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'serialized_proto' or out_type == 'proto':
return self._EncodeAsSerializedProto(input, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'immutable_proto':
return self._EncodeAsImmutableProto(input, enable_sampling, nbest_size,
alpha, add_bos, add_eos, reverse, emit_unk_piece)
raise RuntimeError('unknown out_type={}'.format(out_type))
return None
def EncodeAsPieces(self, input, **kwargs):
return self.Encode(input=input, out_type=str, **kwargs)
def EncodeAsIds(self, input, **kwargs):
return self.Encode(input=input, out_type=int, **kwargs)
def EncodeAsSerializedProto(self, input, **kwargs):
return self.Encode(input=input, out_type='serialized_proto', **kwargs)
def EncodeAsImmutableProto(self, input, **kwargs):
return self.Encode(input=input, out_type='immutable_proto', **kwargs)
def SampleEncodeAsPieces(self, input, nbest_size=None, alpha=None, **kwargs):
return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
out_type=str, enable_sampling=True, **kwargs)
def SampleEncodeAsIds(self, input, nbest_size=None, alpha=None,**kwargs):
return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
out_type=int, enable_sampling=True, **kwargs)
def SampleEncodeAsSerializedProto(self, input, nbest_size=None, alpha=None, **kwargs):
return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
out_type='serialized_proto', enable_sampling=True, **kwargs)
def SampleEncodeAsImmutableProto(self, input, nbest_size=None, alpha=None, **kwargs):
return self.Encode(input=input, nbest_size=nbest_size, alpha=alpha,
out_type='immutable_proto', enable_sampling=True, **kwargs)
def NBestEncode(self,
input,
out_type=None,
add_bos=None,
add_eos=None,
reverse=None,
emit_unk_piece=None,
nbest_size=None):
"""NBestEncode text input to segmented ids or tokens.
Args:
input: input string. accepsts list of string.
out_type: output type. int or str.
add_bos: Add <s> to the result (Default = false)
add_eos: Add </s> to the result (Default = false) <s>/</s> is added after reversing (if enabled).
reverse: Reverses the tokenized sequence (Default = false)
emit_unk_piece: Emits the unk literal string (Default = false)
nbest_size: nbest size
"""
if out_type is None:
out_type = self._out_type
if add_bos is None:
add_bos = self._add_bos
if add_eos is None:
add_eos = self._add_eos
if reverse is None:
reverse = self._reverse
if emit_unk_piece is None:
emit_unk_piece = self._emit_unk_piece
if nbest_size is None:
nbest_size = self._nbest_size
if nbest_size <= 0:
nbest_size=1
def _encode(text):
if out_type is int:
return self._NBestEncodeAsIds(text, nbest_size,
add_bos, add_eos, reverse, emit_unk_piece)
if out_type is str:
return self._NBestEncodeAsPieces(text, nbest_size,
add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'serialized_proto' or out_type == 'proto':
return self._NBestEncodeAsSerializedProto(text, nbest_size,
add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'immutable_proto':
return self._NBestEncodeAsImmutableProto(text, nbest_size,
add_bos, add_eos, reverse, emit_unk_piece)
raise RuntimeError('unknown out_type')
if type(input) is list:
return [_encode(n) for n in input]
return _encode(input)
def NBestEncodeAsPieces(self, input, nbest_size=None, **kwargs):
return self.NBestEncode(input=input, nbest_size=nbest_size,
out_type=str, **kwargs)
def NBestEncodeAsIds(self, input, nbest_size=None, **kwargs):
return self.NBestEncode(input=input, nbest_size=nbest_size,
out_type=int, **kwargs)
def NBestEncodeAsSerializedProto(self, input, nbest_size=None, **kwargs):
return self.NBestEncode(input=input, nbest_size=nbest_size,
out_type='serialized_proto', **kwargs)
def NBestEncodeAsImmutableProto(self, input, nbest_size=None, **kwargs):
return self.NBestEncode(input=input, nbest_size=nbest_size,
out_type='immutable_proto', **kwargs)
def SampleEncodeAndScore(self,
input,
out_type=None,
add_bos=None,
add_eos=None,
reverse=None,
emit_unk_piece=None,
num_samples=None,
alpha=None,
wor=None,
include_best=None):
"""SampleEncodeAndScore text input to segmented ids or tokens.
Args:
input: input string. accepsts list of string.
out_type: output type. int or str or 'serialized_proto' or 'immutable_proto'
add_bos: Add <s> to the result (Default = false)
add_eos: Add </s> to the result (Default = false) <s>/</s> is added after reversing (if enabled).
reverse: Reverses the tokenized sequence (Default = false)
emit_unk_piece: Emits the unk literal string (Default = false)
num_samples: How many samples to return (Default = 1)
alpha: inverse temperature for sampling
wor: whether to sample without replacement (Default = false)
include_best: whether to include the best tokenization, requires wor=True (Default = false)
"""
if out_type is None:
out_type = self._out_type
if add_bos is None:
add_bos = self._add_bos
if add_eos is None:
add_eos = self._add_eos
if reverse is None:
reverse = self._reverse
if emit_unk_piece is None:
emit_unk_piece = self._emit_unk_piece
if num_samples is None:
num_samples = 1
if alpha is None:
alpha = 1.
if wor is None:
wor = False
if include_best is None:
include_best = False
if num_samples <= 0:
raise RuntimeError('num_examples must be positive')
if include_best and not wor:
raise RuntimeError('When include_best is True, We must specify "wor = True".')
def _encode(text):
if out_type is int:
return self._SampleEncodeAndScoreAsIds(text, num_samples, alpha, wor, include_best,
add_bos, add_eos, reverse, emit_unk_piece)
if out_type is str:
return self._SampleEncodeAndScoreAsPieces(text, num_samples, alpha, wor, include_best,
add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'serialized_proto' or out_type == 'proto':
return self._SampleEncodeAndScoreAsSerializedProto(text, num_samples, alpha, wor, include_best,
add_bos, add_eos, reverse, emit_unk_piece)
if out_type == 'immutable_proto':
return self._SampleEncodeAndScoreAsImmutableProto(text, num_samples, alpha, wor, include_best,
add_bos, add_eos, reverse, emit_unk_piece)
raise RuntimeError('unknown output type')
if type(input) is list:
return [_encode(n) for n in input]
return _encode(input)
def SampleEncodeAndScoreAsPieces(self, input, num_samples=None, alpha=None, **kwargs):
return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
out_type=str, **kwargs)
def SampleEncodeAndScoreAsIds(self, input, num_samples=None, alpha=None, **kwargs):
return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
out_type=int, **kwargs)
def SampleEncodeAndScoreAsSerializedProto(self, input, num_samples=None, alpha=None, **kwargs):
return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
out_type='serialized_proto', **kwargs)
def SampleEncodeAndScoreAsImmutableProto(self, input, num_samples=None, alpha=None, **kwargs):
return self.SampleEncodeAndScore(input=input, num_samples=num_samples, alpha=alpha,
out_type='immutable_proto', **kwargs)
def Decode(self, input, out_type=str, num_threads=None):
"""Decode processed id or token sequences.
Args:
out_type: output type. str, bytes or 'serialized_proto' or 'immutable_proto' (Default = str)
num_threads: the number of threads used in the batch processing (Default = -1).
"""
if num_threads is None:
num_threads = self._num_threads
if num_threads is None or type(num_threads) is not int:
raise RuntimeError('num_threads must be int')
if not input:
return ''
if out_type is str:
if type(input) is int:
return self._DecodeIds([input])
if type(input) is str:
return self._DecodePieces([input])
if type(input) is list:
if len(input) == 0 or type(input[0]) is int:
return self._DecodeIds(input)
if type(input[0]) is str:
return self._DecodePieces(input)
if type(input[0]) is list:
if len(input[0]) == 0 or type(input[0][0]) is int:
return self._DecodeIdsBatch(input, num_threads)
if type(input[0][0]) is str:
return self._DecodePiecesBatch(input, num_threads)
if out_type is bytes:
if type(input) is int:
return self._DecodeIdsAsBytes([input])
if type(input) is str:
return self._DecodePieces([input])
if type(input) is list:
if len(input) == 0 or type(input[0]) is int:
return self._DecodeIdsAsBytes(input)
if type(input[0]) is str:
return self._DecodePieces(input)
if type(input[0]) is list:
if len(input[0]) == 0 or type(input[0][0]) is int:
return self._DecodeIdsAsBytesBatch(input, num_threads)
if type(input[0][0]) is str:
return self._DecodePiecesBatch(input, num_threads)
if out_type == 'serialized_proto':
if type(input) is int:
return self._DecodeIdsAsSerializedProto([input])
if type(input) is str:
return self._DecodePiecesAsSerializedProto([input])
if type(input) is list:
if len(input) == 0 or type(input[0]) is int:
return self._DecodeIdsAsSerializedProto(input)
if type(input[0]) is str:
return self._DecodePiecesAsSerializedProto(input)
if type(input[0]) is list:
if len(input[0]) == 0 or type(input[0][0]) is int:
return self._DecodeIdsAsSerializedProtoBatch(input, num_threads)
if type(input[0][0]) is str:
return self._DecodePiecesAsSerializedProtoBatch(input, num_threads)
if out_type == 'immutable_proto':
if type(input) is int:
return self._DecodeIdsAsImmutableProto([input])
if type(input) is str:
return self._DecodePiecesAsImmutableProto([input])
if type(input) is list:
if len(input) == 0 or type(input[0]) is int:
return self._DecodeIdsAsImmutableProto(input)
if type(input[0]) is str:
return self._DecodePiecesAsImmutableProto(input)
if type(input[0]) is list:
if len(input[0]) == 0 or type(input[0][0]) is int:
return self._DecodeIdsAsImmutableProtoBatch(input, num_threads)
if type(input[0][0]) is str:
return self._DecodePiecesAsImmutableProtoBatch(input, num_threads)
raise RuntimeError('unknown output or input type')
return None
def DecodePieces(self, input, out_type=str, **kwargs):
return self.Decode(input=input, out_type=out_type, **kwargs)
def DecodeIds(self, input, out_type=str, **kwargs):
return self.Decode(input=input, out_type=out_type, **kwargs)
def DecodePiecesAsSerializedProto(self, input, out_type='serialized_proto', **kwargs):
return self.Decode(input=input, out_type=out_type, **kwargs)
def DecodeIdsAsSerializedProto(self, input, out_type='serialized_proto', **kwargs):
return self.Decode(input=input, out_type=out_type, **kwargs)
def DecodePiecesAsImmutableProto(self, input, out_type='immutable_proto', **kwargs):
return self.Decode(input=input, out_type=out_type, **kwargs)
def DecodeIdsAsImmutableProto(self, input, out_type='immutable_proto', **kwargs):
return self.Decode(input=input, out_type=out_type, **kwargs)
def CalculateEntropy(self, input, alpha, num_threads=None):
"""Calculate sentence entropy"""
if type(input) is list:
if num_threads is None:
num_threads = self._num_threads
if num_threads is None or type(num_threads) is not int:
raise RuntimeError('num_threads must be int')
return self._CalculateEntropyBatch(input, alpha, num_threads)
return self._CalculateEntropy(input, alpha)
def Normalize(self, input, with_offsets=None):
def _normalize(text):
if with_offsets:
return self._NormalizeWithOffsets(text)
return self._Normalize(text)
if type(input) is list:
return [_normalize(x) for x in input]
return _normalize(input)
def OverrideNormalizerSpec(self, **kwargs):
new_kwargs = {}
for key, value in kwargs.items():
new_kwargs[key] = str(value)
return self._OverrideNormalizerSpec(new_kwargs)
def piece_size(self):
return self.GetPieceSize()
def vocab_size(self):
return self.GetPieceSize()
def __getstate__(self):
return self.serialized_model_proto()
def __setstate__(self, serialized_model_proto):
self.__init__()
self.LoadFromSerializedProto(serialized_model_proto)
def __len__(self):
return self.GetPieceSize()
def __getitem__(self, piece):
return self.PieceToId(piece)
def Load(self, model_file=None, model_proto=None):
"""Overwride SentencePieceProcessor.Load to support both model_file and model_proto.
Args:
model_file: The sentencepiece model file path.
model_proto: The sentencepiece model serialized proto. Either `model_file`
or `model_proto` must be set.
"""
if model_file and model_proto:
raise RuntimeError('model_file and model_proto must be exclusive.')
if model_proto:
return self.LoadFromSerializedProto(model_proto)
return self.LoadFromFile(model_file)
}
}
%extend sentencepiece::SentencePieceTrainer {
static void _TrainFromString(absl::string_view arg) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(arg);
if (!_status.ok()) throw _status;
return;
}
static void _TrainFromMap(const std::unordered_map<std::string, std::string> &args) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(args);
if (!_status.ok()) throw _status;
return;
}
static void _TrainFromMap2(const std::unordered_map<std::string, std::string> &args,
SentenceIterator *iter) {
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter);
if (!_status.ok()) throw _status;
return;
}
static sentencepiece::util::bytes _TrainFromMap3(const std::unordered_map<std::string, std::string> &args) {
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, nullptr, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
static sentencepiece::util::bytes _TrainFromMap4(const std::unordered_map<std::string, std::string> &args,
SentenceIterator *iter) {
sentencepiece::util::bytes model_proto;
const auto _status = sentencepiece::SentencePieceTrainer::Train(args, iter, &model_proto);
if (!_status.ok()) throw _status;
return model_proto;
}
%pythoncode {
@staticmethod
def _Train(arg=None, **kwargs):
"""Train Sentencepiece model. Accept both kwargs and legacy string arg."""
if arg is not None and type(arg) is str:
return SentencePieceTrainer._TrainFromString(arg)
def _encode(value):
"""Encode value to CSV.."""
if type(value) is list:
if sys.version_info[0] == 3:
f = StringIO()
else:
f = BytesIO()
writer = csv.writer(f, lineterminator='')
writer.writerow([str(v) for v in value])
return f.getvalue()
else:
return str(value)
sentence_iterator = None
model_writer = None
new_kwargs = {}
for key, value in kwargs.items():
if key in ['sentence_iterator', 'sentence_reader']:
sentence_iterator = value
elif key in ['model_writer']:
model_writer = value
else:
new_kwargs[key] = _encode(value)
if model_writer:
if sentence_iterator:
model_proto = SentencePieceTrainer._TrainFromMap4(new_kwargs,
sentence_iterator)
else:
model_proto = SentencePieceTrainer._TrainFromMap3(new_kwargs)
model_writer.write(model_proto)
else:
if sentence_iterator:
return SentencePieceTrainer._TrainFromMap2(new_kwargs, sentence_iterator)
else:
return SentencePieceTrainer._TrainFromMap(new_kwargs)
return None
@staticmethod
def Train(arg=None, logstream=None, **kwargs):
with _LogStream(ostream=logstream):
SentencePieceTrainer._Train(arg=arg, **kwargs)
}
}
%extend sentencepiece::SentencePieceNormalizer {
sentencepiece::util::Status LoadFromFile(absl::string_view arg) {
return $self->Load(arg);
}
std::string _Normalize(absl::string_view text) {
std::string result;
const auto _status = $self->Normalize(text, &result);
if (!_status.ok()) throw _status;
return result;
}
std::pair<std::string, std::vector<size_t>> _NormalizeWithOffsets(absl::string_view text) {
std::pair<std::string, std::vector<size_t>> result;
const auto _status = $self->Normalize(text, &result.first, &result.second);
if (!_status.ok()) throw _status;
return result;
}
void _SetProtoField(absl::string_view name, bool value) {
sentencepiece::SentencePieceTrainer::SetProtoField(
name,
value ? "1" : "0",
$self->mutable_normalizer_spec()).IgnoreError();
}
%pythoncode %{
def Init(self,
model_file=None,
model_proto=None,
rule_tsv=None,
rule_name=None,
add_dummy_prefix=False,
escape_whitespaces=False,
remove_extra_whitespaces=False):
"""Initialzie sentencePieceNormalizer.
Args:
model_file: The sentencepiece model file path.
model_proto: The sentencepiece model serialized proto.
rule_tsv: The normalization rule file in TSV format.
rule_name: Pre-defined normalization name.
add_dummy_prefix: add dummy prefix.
escape_whitespaces: escape whitespaces.
remove_extra_whitespaces: remove extra whitespaces.
"""
_sentencepiece_normalizer_init_native(self)
if model_file:
status = self.LoadFromFile(model_file)
elif model_proto:
status = self.LoadFromSerializedProto(model_proto)
elif rule_tsv:
status = self.LoadFromRuleTSV(rule_tsv)
elif rule_name:
status = self.LoadFromRuleName(rule_name)
else:
raise RuntimeError('no model is specified')
if status:
self._SetProtoField('add_dummy_prefix', add_dummy_prefix)
self._SetProtoField('escape_whitespaces', escape_whitespaces)
self._SetProtoField('remove_extra_whitespaces', remove_extra_whitespaces)
def Normalize(self, input, with_offsets=None):
def _normalize(text):
if with_offsets:
return self._NormalizeWithOffsets(text)
return self._Normalize(text)
if type(input) is list:
return [_normalize(x) for x in input]
return _normalize(input)
def __getstate__(self):
return self.serialized_model_proto()
def __setstate__(self, serialized_model_proto):
self.__init__()
self.LoadFromSerializedProto(serialized_model_proto)
%}
}
%extend sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece {
const sentencepiece::util::bytes& _surface_as_bytes() const {
return $self->surface();
}
const sentencepiece::util::bytes& _piece_as_bytes() const {
return $self->piece();
}
%rename(_piece) piece;
%rename(_piece_as_bytes) piece_as_bytes;
%rename(_id) id;
%rename(_surface) surface;
%rename(_surface_as_bytes) surface_as_bytes;
%rename(_begin) begin;
%rename(_end) end;
%pythoncode %{
piece = property(_piece)
piece_as_bytes = property(_piece_as_bytes)
surface = property(_surface)
surface_as_bytes = property(_surface_as_bytes)
id = property(_id)
begin = property(_begin)
end = property(_end)
def __str__(self):
return ('piece: \"{}\"\n'
'id: {}\n'
'surface: \"{}\"\n'
'begin: {}\n'
'end: {}\n').format(self.piece, self.id, self.surface,
self.begin, self.end)
def __eq__(self, other):
return self.piece == other.piece and self.id == other.id and self.surface == other.surface and self.begin == other.begin and self.end == other.end
def __hash__(self):
return hash(str(self))
__repr__ = __str__
%}
}
%extend sentencepiece::ImmutableSentencePieceText {
const sentencepiece::util::bytes& _text_as_bytes() const {
return $self->text();
}
%rename(_text) text;
%rename(_text_as_bytes) text_as_bytes;
%rename(_score) score;
%rename(_pieces) pieces;
%rename(_pieces_size) pieces_size;
%pythoncode %{
text = property(_text)
text_as_bytes = property(_text_as_bytes)
score = property(_score)
class ImmutableSentencePieceIterator:
def __init__(self, proto):
self.proto = proto
self.len = self.proto._pieces_size()
def __len__(self):
return self.len
def __getitem__(self, index):
if isinstance(index, slice):
return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step]
if index < 0:
index = index + self.len
if index < 0 or index >= self.len:
raise IndexError('piece index is out of range')
return self.proto._pieces(index)
def __str__(self):
return '\n'.join(['pieces {{\n{}}}'.format(str(x)) for x in self])
__repr__ = __str__
@property
def pieces(self):
return ImmutableSentencePieceText.ImmutableSentencePieceIterator(self)
def __eq__(self, other):
return self.SerializeAsString() == other.SerializeAsString()
def __hash__(self):
return hash(self.SerializeAsString())
def __str__(self):
return ('text: \"{}\"\n'
'score: {}\n'
'{}').format(self.text, self.score,
'\n'.join(['pieces {{\n{}}}'.format(str(x)) for x in self.pieces]))
__repr__ = __str__
%}
}
%extend sentencepiece::ImmutableNBestSentencePieceText {
%rename(_nbests) nbests;
%rename(_nbests_size) nbests_size;
%pythoncode %{
class ImmutableSentencePieceTextIterator:
def __init__(self, proto):
self.proto = proto
self.len = self.proto._nbests_size()
def __len__(self):
return self.len
def __getitem__(self, index):
if isinstance(index, slice):
return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step]
if index < 0:
index = index + self.len
if index < 0 or index >= self.len:
raise IndexError('nbests index is out of range')
return self.proto._nbests(index)
def __str__(self):
return '\n'.join(['nbests {{\n{}}}'.format(str(x)) for x in self])
__repr__ = __str__
@property
def nbests(self):
return ImmutableNBestSentencePieceText.ImmutableSentencePieceTextIterator(self)
def __eq__(self, other):
return self.SerializeAsString() == other.SerializeAsString()
def __hash__(self):
return hash(self.SerializeAsString())
def __str__(self):
return '\n'.join(['nbests {{\n{}}}'.format(str(x)) for x in self.nbests])
__repr__ = __str__
%}
}
%typemap(out) std::vector<int> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, PyInt_FromLong(static_cast<long>($1[i])));
}
}
%typemap(out) std::vector<float> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, PyFloat_FromDouble(static_cast<double>($1[i])));
}
}
%typemap(out) std::vector<std::vector<int>> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].size());
for (size_t j = 0; j < $1[i].size(); ++j) {
PyList_SET_ITEM(obj, j, PyInt_FromLong(static_cast<long>($1[i][j])));
}
PyList_SET_ITEM($result, i, obj);
}
}
%typemap(out) std::vector<std::string> {
PyObject *input_type = resultobj;
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, MakePyOutputString($1[i], input_type));
}
}
%typemap(out) BytesArray {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyList_SET_ITEM($result, i, MakePyOutputBytes($1[i]));
}
}
%typemap(out) std::vector<std::vector<std::string>> {
PyObject *input_type = resultobj;
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].size());
for (size_t j = 0; j < $1[i].size(); ++j) {
PyList_SET_ITEM(obj, j, MakePyOutputString($1[i][j], input_type));
}
PyList_SET_ITEM($result, i, obj);
}
}
%typemap(out) sentencepiece::util::bytes {
$result = MakePyOutputBytes($1);
}
%typemap(out) const sentencepiece::util::bytes& {
$result = MakePyOutputBytes(*$1);
}
%typemap(out) std::string {
PyObject *input_type = resultobj;
$result = MakePyOutputString($1, input_type);
}
%typemap(out) const std::string& {
PyObject *input_type = resultobj;
$result = MakePyOutputString(*$1, input_type);
}
%typemap(out) sentencepiece::util::Status {
if (!$1.ok()) {
SWIG_exception(ToSwigError($1.code()), $1.ToString().c_str());
}
$result = SWIG_From_bool($1.ok());}
%typemap(in) const std::string & {
const PyInputString ustring($input);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
$1 = new std::string(ustring.data(), ustring.size());
}
%typemap(typecheck) absl::string_view = char *;
%typemap(in) absl::string_view {
const PyInputString ustring($input);
if (!ustring.IsAvalable()) {
PyErr_SetString(PyExc_TypeError, "not a string");
SWIG_fail;
}
resultobj = ustring.input_type();
$1 = ustring.str();
}
%typemap(in) const std::vector<absl::string_view>& {
std::vector<absl::string_view> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<absl::string_view>(size);
for (size_t i = 0; i < size; ++i) {
const PyInputString ustring(PyList_GetItem($input, i));
if (ustring.IsAvalable()) {
(*out)[i] = ustring.str();
} else {
PyErr_SetString(PyExc_TypeError, "list must contain strings");
SWIG_fail;
}
resultobj = ustring.input_type();
}
} else {
PyErr_SetString(PyExc_TypeError, "not a list");
SWIG_fail;
}
$1 = out;
}
%typemap(in) const std::vector<int>& {
std::vector<int> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<int>(size);
for (size_t i = 0; i < size; ++i) {
PyObject *o = PyList_GetItem($input, i);
if (PyInt_Check(o)) {
(*out)[i] = static_cast<int>(PyInt_AsLong(o));
} else {
PyErr_SetString(PyExc_TypeError,"list must contain integers");
SWIG_fail;
}
}
} else {
PyErr_SetString(PyExc_TypeError,"not a list");
SWIG_fail;
}
$1 = out;
}
%typemap(in) const std::vector<std::vector<absl::string_view>>& {
std::vector<std::vector<absl::string_view>> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<std::vector<absl::string_view>>(size);
for (size_t i = 0; i < size; ++i) {
PyObject *o = PyList_GetItem($input, i);
if (PyList_Check(o)) {
const size_t size2 = PyList_Size(o);
(*out)[i].resize(size2);
for (size_t j = 0; j < size2; ++j) {
const PyInputString ustring(PyList_GetItem(o, j));
if (ustring.IsAvalable()) {
(*out)[i][j] = ustring.str();
} else {
PyErr_SetString(PyExc_TypeError,"list must contain integers");
SWIG_fail;
}
resultobj = ustring.input_type();
}
} else {
PyErr_SetString(PyExc_TypeError,"not a list");
SWIG_fail;
}
}
} else {
PyErr_SetString(PyExc_TypeError,"not a list");
SWIG_fail;
}
$1 = out;
}
%typemap(in) const std::vector<std::vector<int>>& {
std::vector<std::vector<int>> *out = nullptr;
if (PyList_Check($input)) {
const size_t size = PyList_Size($input);
out = new std::vector<std::vector<int>>(size);
for (size_t i = 0; i < size; ++i) {
PyObject *o = PyList_GetItem($input, i);
if (PyList_Check(o)) {
const size_t size2 = PyList_Size(o);
(*out)[i].resize(size2);
for (size_t j = 0; j < size2; ++j) {
PyObject *o2 = PyList_GetItem(o, j);
if (PyInt_Check(o2)) {
(*out)[i][j] = static_cast<int>(PyInt_AsLong(o2));
} else {
PyErr_SetString(PyExc_TypeError, "list must contain strings");
SWIG_fail;
}
}
} else {
PyErr_SetString(PyExc_TypeError, "not a list");
SWIG_fail;
}
}
} else {
PyErr_SetString(PyExc_TypeError,"not a list");
SWIG_fail;
}
$1 = out;
}
%typemap(in) const std::unordered_map<std::string, std::string> & {
std::unordered_map<std::string, std::string> *out = nullptr;
if (PyDict_Check($input)) {
PyObject *key, *value;
Py_ssize_t pos = 0;
out = new std::unordered_map<std::string, std::string>;
while (PyDict_Next($input, &pos, &key, &value)) {
const PyInputString key_ustring(key);
const PyInputString value_ustring(value);
if (key_ustring.IsAvalable() && value_ustring.IsAvalable()) {
out->emplace(std::string(key_ustring.data(), key_ustring.size()),
std::string(value_ustring.data(), value_ustring.size()));
} else {
PyErr_SetString(PyExc_TypeError, "map must contain strings.");
SWIG_fail;
}
resultobj = key_ustring.input_type();
}
} else {
PyErr_SetString(PyExc_TypeError, "not a dictionary");
SWIG_fail;
}
$1 = out;
}
%typemap(out) std::vector<std::pair<std::vector<std::string>, float>> {
PyObject *input_type = resultobj;
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].first.size());
for (size_t j = 0; j < $1[i].first.size(); ++j) {
PyList_SET_ITEM(obj, j, MakePyOutputString($1[i].first[j], input_type));
}
PyList_SET_ITEM($result, i, PyTuple_Pack(2, obj, PyFloat_FromDouble(static_cast<double>($1[i].second))));
}
}
%typemap(out) std::vector<std::pair<std::vector<int>, float>> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = PyList_New($1[i].first.size());
for (size_t j = 0; j < $1[i].first.size(); ++j) {
PyList_SET_ITEM(obj, j, PyInt_FromLong(static_cast<long>($1[i].first[j])));
}
PyList_SET_ITEM($result, i, PyTuple_Pack(2, obj, PyFloat_FromDouble(static_cast<double>($1[i].second))));
}
}
%typemap(out) std::vector<sentencepiece::ImmutableSentencePieceText> {
$result = PyList_New($1.size());
for (size_t i = 0; i < $1.size(); ++i) {
PyObject *obj = SWIG_NewPointerObj(new sentencepiece::ImmutableSentencePieceText($1.at(i)), SWIGTYPE_p_sentencepiece__ImmutableSentencePieceText, SWIG_POINTER_OWN | 0);
PyList_SET_ITEM($result, i, obj);
}
}
// Types for normalized string and offset
%typemap(out) std::pair<std::string, std::vector<size_t>> {
PyObject *input_type = resultobj;
if (PyInputString::IsUnicode(input_type)) {
sentencepiece::ConvertToUnicodeAlignment(arg2, $1.first, &$1.second);
}
PyObject *obj = PyList_New($1.second.size());
for (size_t i = 0; i < $1.second.size(); ++i) {
PyList_SET_ITEM(obj, i, PyInt_FromLong(static_cast<long>($1.second[i])));
}
$result = PyTuple_Pack(2, MakePyOutputString($1.first, input_type), obj);
}
%typemap(in) sentencepiece::SentenceIterator * {
sentencepiece::SentenceIterator *out = nullptr;
if (PyIter_Check($input)) {
out = new PySentenceIterator($input);
} else {
PyErr_SetString(PyExc_TypeError, "not a iterator");
SWIG_fail;
}
$1 = out;
}
%typemap(freearg) const std::string& {
delete $1;
}
%typemap(freearg) const std::vector<std::string>& {
delete $1;
}
%typemap(freearg) const std::vector<absl::string_view>& {
delete $1;
}
%typemap(freearg) const std::vector<std::vector<std::string>>& {
delete $1;
}
%typemap(freearg) const std::vector<int>& {
delete $1;
}
%typemap(freearg) const std::vector<float>& {
delete $1;
}
%typemap(freearg) const std::vector<std::vector<int>>& {
delete $1;
}
%typemap(freearg) const std::unordered_map<std::string, std::string> & {
delete $1;
}
%typemap(freearg) sentencepiece::SentenceIterator * {
delete $1;
}
%typemap(freearg) sentencepiece::ImmutableSentencePieceText_ImmutableSentencePiece {
delete $1;
}
%typemap(freearg) sentencepiece::ImmutableSentencePieceText {
delete $1;
}
%typemap(freearg) sentencepiece::ImmutableNBestSentencePieceText {
delete $1;
}
%include <sentencepiece_processor.h>
%include <sentencepiece_trainer.h>
%pythoncode %{
import re
import csv
import sys
import os
from io import StringIO
from io import BytesIO
def _add_snake_case(classname):
"""Added snake_cased method from CammelCased method."""
snake_map = {}
for k, v in classname.__dict__.items():
if re.match(r'^[A-Z]+', k):
snake = re.sub(r'(?<!^)(?=[A-Z])', '_',
k).lower().replace('n_best', 'nbest')
snake_map[snake] = v
for k, v in snake_map.items():
setattr(classname, k, v)
def _batchnize(classname, name):
"""Enables batch request for the method classname.name."""
func = getattr(classname, name, None)
def _func(v, n):
if type(n) is int and (n < 0 or n >= v.piece_size()):
raise IndexError('piece id is out of range.')
return func(v, n)
def _batched_func(self, arg):
if type(arg) is list:
return [_func(self, n) for n in arg]
else:
return _func(self, arg)
setattr(classname, name, _batched_func)
_sentencepiece_processor_init_native = SentencePieceProcessor.__init__
_sentencepiece_normalizer_init_native = SentencePieceNormalizer.__init__
setattr(SentencePieceProcessor, '__init__', SentencePieceProcessor.Init)
setattr(SentencePieceNormalizer, '__init__', SentencePieceNormalizer.Init)
SentencePieceProcessor.Tokenize = SentencePieceProcessor.Encode
SentencePieceProcessor.Detokenize = SentencePieceProcessor.Decode
for m in [
'PieceToId', 'IdToPiece', 'GetScore', 'IsUnknown', 'IsControl', 'IsUnused',
'IsByte'
]:
_batchnize(SentencePieceProcessor, m)
_add_snake_case(SentencePieceProcessor)
_add_snake_case(SentencePieceTrainer)
_add_snake_case(SentencePieceNormalizer)
set_random_generator_seed = SetRandomGeneratorSeed
set_min_log_level = SetMinLogLevel
from ._version import __version__
class _LogStream(object):
def __init__(self, ostream=None):
self.ostream = ostream
if self.ostream is not None:
self.orig_stream_fileno = sys.stderr.fileno()
def __enter__(self):
if self.ostream is not None:
self.orig_stream_dup = os.dup(self.orig_stream_fileno)
os.dup2(self.ostream.fileno(), self.orig_stream_fileno)
def __exit__(self, type, value, traceback):
if self.ostream is not None:
os.close(self.orig_stream_fileno)
os.dup2(self.orig_stream_dup, self.orig_stream_fileno)
os.close(self.orig_stream_dup)
self.ostream.close()
%}