diff --git a/+bert/+internal/convertModelNameToDirectories.m b/+bert/+internal/convertModelNameToDirectories.m
index bb1b93a..a9cc200 100644
--- a/+bert/+internal/convertModelNameToDirectories.m
+++ b/+bert/+internal/convertModelNameToDirectories.m
@@ -2,12 +2,16 @@
 % convertModelNameToDirectories Converts the user facing model name to
 % the directory name used by support files.

-% Copyright 2021 The MathWorks, Inc.
+% Copyright 2021-2023 The MathWorks, Inc.

 arguments
     name (1,1) string
 end
 modelName = userInputToSupportFileName(name);
-dirpath = {"data","networks","bert",modelName};
+bertBaseLocation = "bert";
+if contains(name,"japanese")
+    bertBaseLocation = "ja_" + bertBaseLocation;
+end
+dirpath = {"data","networks",bertBaseLocation,modelName};
 end

 function supportfileName = userInputToSupportFileName(name)
@@ -26,5 +30,7 @@
     "medium", "uncased_L8_H512_A8";
     "small", "uncased_L4_H512_A8";
     "mini", "uncased_L4_H256_A4";
-    "tiny", "uncased_L2_H128_A2"];
+    "tiny", "uncased_L2_H128_A2";
+    "japanese-base-wwm", "";
+    "japanese-base", ""];
 end
\ No newline at end of file
diff --git a/+bert/+tokenizer/+internal/BasicTokenizer.m b/+bert/+tokenizer/+internal/BasicTokenizer.m
index a710aa3..ff7655f 100644
--- a/+bert/+tokenizer/+internal/BasicTokenizer.m
+++ b/+bert/+tokenizer/+internal/BasicTokenizer.m
@@ -1,7 +1,7 @@
 classdef BasicTokenizer < bert.tokenizer.internal.Tokenizer
     % BasicTokenizer Perform basic tokenization.

-    % Copyright 2020 The MathWorks, Inc.
+    % Copyright 2020-2023 The MathWorks, Inc.

     properties(SetAccess=private)
         IgnoreCase
@@ -28,24 +28,29 @@
         function tokens = tokenize(this,text)
             arguments
                 this (1,1) bert.tokenizer.internal.BasicTokenizer
-                text (1,1) string
+                text (1,:) string
             end
-            u = textanalytics.unicode.UTF32(text);
-            u = this.cleanText(u);
-            u = this.tokenizeCJK(u);
-            text = u.string();
-            if this.IgnoreCase
-                text = lower(text);
-                text = textanalytics.unicode.nfd(text);
-            end
-            u = textanalytics.unicode.UTF32(text);
-            cats = u.characterCategories('Granularity','detailed');
-            if this.IgnoreCase
-                [u,cats] = this.stripAccents(u,cats);
+            tokens = cell(1,numel(text));
+            for i = 1:numel(text)
+                thisText = text(i);
+                u = textanalytics.unicode.UTF32(thisText);
+                u = this.cleanText(u);
+                u = this.tokenizeCJK(u);
+                thisText = u.string();
+                if this.IgnoreCase
+                    thisText = lower(thisText);
+                    thisText = textanalytics.unicode.nfd(thisText);
+                end
+                u = textanalytics.unicode.UTF32(thisText);
+                cats = u.characterCategories('Granularity','detailed');
+                if this.IgnoreCase
+                    [u,cats] = this.stripAccents(u,cats);
+                end
+                theseTokens = this.splitOnPunc(u,cats);
+                theseTokens = join(cat(2,theseTokens{:})," ");
+                theseTokens = this.whiteSpaceTokenize(theseTokens);
+                tokens{i} = theseTokens;
             end
-            tokens = this.splitOnPunc(u,cats);
-            tokens = join(cat(2,tokens{:})," ");
-            tokens = this.whiteSpaceTokenize(tokens);
         end
     end

@@ -160,4 +165,4 @@
     inRange(udata,123,126);
 cats = string(cats);
 tf = (tf)|(cats.startsWith("P"));
-end
\ No newline at end of file
+end
diff --git a/+bert/+tokenizer/+internal/FullTokenizer.m b/+bert/+tokenizer/+internal/FullTokenizer.m
index 07acd8d..7055edd 100644
--- a/+bert/+tokenizer/+internal/FullTokenizer.m
+++ b/+bert/+tokenizer/+internal/FullTokenizer.m
@@ -5,9 +5,16 @@
     % using the vocabulary specified in the newline delimited txt file
     % vocabFile.
     %
-    % tokenizer = FullTokenizer(vocabFile,'IgnoreCase',tf) controls if
-    % the FullTokenizer is case sensitive or not. The default value for
-    % tf is true.
+ % tokenizer = FullTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...) + % specifies the optional parameter name/value pairs: + % + % 'BasicTokenizer' - Tokenizer used to split text into words. + % If not specified, a default + % BasicTokenizer is constructed. + % + % 'IgnoreCase' - A logical value to control if the + % FullTokenizer is case sensitive or not. + % The default value is true. % % FullTokenizer methods: % tokenize - tokenize text @@ -15,7 +22,7 @@ % decode - decode encoded tokens % % Example: - % % Save a file named vocab.txt with the text on the next 3 lines: + % % Save a file named fakeVocab.txt with the text on the next 3 lines: % fake % vo % ##cab @@ -30,7 +37,7 @@ % % This returns the encoded form of the tokens - each token is % % replaced by its corresponding line number in the fakeVocab.txt - % Copyright 2021 The MathWorks, Inc. + % Copyright 2021-2023 The MathWorks, Inc. properties(Access=private) Basic @@ -46,9 +53,16 @@ % using the vocabulary specified in the newline delimited txt file % vocabFile. % - % tokenizer = FullTokenizer(vocabFile,'IgnoreCase',tf) controls if - % the FullTokenizer is case sensitive or not. The default value for - % tf is true. + % tokenizer = FullTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...) specifies + % the optional parameter name/value pairs: + % + % 'BasicTokenizer' - Tokenizer used to split text into words. + % If not specified, a default + % BasicTokenizer is constructed. + % + % 'IgnoreCase' - A logical value to control if the + % FullTokenizer is case sensitive or not. + % The default value is true. % % FullTokenizer methods: % tokenize - tokenize text @@ -56,7 +70,7 @@ % decode - decode encoded tokens % % Example: - % % Save a file named vocab.txt with the text on the next 3 lines: + % % Save a file named fakeVocab.txt with the text on the next 3 lines: % fake % vo % ##cab @@ -72,9 +86,16 @@ % % replaced by its corresponding line number in the fakeVocab.txt arguments vocab + nvp.BasicTokenizer = [] nvp.IgnoreCase = true end - this.Basic = bert.tokenizer.internal.BasicTokenizer('IgnoreCase',nvp.IgnoreCase); + if isempty(nvp.BasicTokenizer) + % Default case + this.Basic = bert.tokenizer.internal.BasicTokenizer('IgnoreCase',nvp.IgnoreCase); + else + mustBeA(nvp.BasicTokenizer,'bert.tokenizer.internal.Tokenizer'); + this.Basic = nvp.BasicTokenizer; + end this.WordPiece = bert.tokenizer.internal.WordPieceTokenizer(vocab); this.Encoding = this.WordPiece.Vocab; end @@ -85,12 +106,15 @@ % tokens = tokenize(tokenizer,text) tokenizes the input % string text using the FullTokenizer specified by tokenizer. 
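+            %
+            % For string array input, the output is a cell array with one
+            % cell of tokens per element of text. For example, with the
+            % fakeVocab.txt vocabulary described above (inputs chosen for
+            % illustration):
+            %   toks = tokenize(tokenizer,["fake vocab","fake"])
+            % returns a cell array containing ["fake","vo","##cab"] and
+            % "fake".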
basicToks = this.Basic.tokenize(txt); - basicToksUnicode = textanalytics.unicode.UTF32(basicToks); - subToks = cell(numel(basicToks),1); - for i = 1:numel(basicToks) - subToks{i} = this.WordPiece.tokenize(basicToksUnicode(i)); + toks = cell(numel(txt),1); + for i = 1:numel(txt) + theseBasicToks = textanalytics.unicode.UTF32(basicToks{i}); + theseSubToks = cell(numel(theseBasicToks),1); + for j = 1:numel(theseBasicToks) + theseSubToks{j} = this.WordPiece.tokenize(theseBasicToks(j)); + end + toks{i} = cat(2,theseSubToks{:}); end - toks = cat(2,subToks{:}); end function idx = encode(this,tokens) @@ -109,4 +133,4 @@ tokens = this.Encoding.ind2word(x); end end -end \ No newline at end of file +end diff --git a/+bert/+tokenizer/+internal/TokenizedDocumentTokenizer.m b/+bert/+tokenizer/+internal/TokenizedDocumentTokenizer.m new file mode 100644 index 0000000..244aeeb --- /dev/null +++ b/+bert/+tokenizer/+internal/TokenizedDocumentTokenizer.m @@ -0,0 +1,36 @@ +classdef TokenizedDocumentTokenizer < bert.tokenizer.internal.Tokenizer + % TokenizedDocumentTokenizer Implements a word-level tokenizer using + % tokenizedDocument. + + % Copyright 2023 The MathWorks, Inc. + + properties + TokenizedDocumentOptions + IgnoreCase + end + + methods + function this = TokenizedDocumentTokenizer(varargin,args) + arguments(Repeating) + varargin + end + arguments + args.IgnoreCase (1,1) logical = true + end + this.IgnoreCase = args.IgnoreCase; + this.TokenizedDocumentOptions = varargin; + end + + function toks = tokenize(this,txt) + arguments + this + txt (1,:) string + end + if this.IgnoreCase + txt = lower(txt); + end + t = tokenizedDocument(txt,this.TokenizedDocumentOptions{:}); + toks = doc2cell(t); + end + end +end \ No newline at end of file diff --git a/+bert/+tokenizer/BERTTokenizer.m b/+bert/+tokenizer/BERTTokenizer.m index 68e76d3..2f405ea 100644 --- a/+bert/+tokenizer/BERTTokenizer.m +++ b/+bert/+tokenizer/BERTTokenizer.m @@ -9,9 +9,16 @@ % case-insensitive BERTTokenizer using the file vocabFile as % the vocabulary. % - % tokenizer = BERTTokenizer(vocabFile,'IgnoreCase',tf) - % Constructs a BERTTokenizer which is case-sensitive or not - % according to the scalar logical tf. The default is true. + % tokenizer = BERTTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...) + % specifies the optional parameter name/value pairs: + % + % 'IgnoreCase' - A logical value to control if the + % BERTTokenizer is case sensitive or not. + % The default value is true. + % + % 'FullTokenizer' - The underlying word-piece tokenizer. + % If not specified, a default + % FullTokenizer is constructed. % % BERTTokenizer properties: % FullTokenizer - The underlying word-piece tokenizer. @@ -34,7 +41,7 @@ % tokenizer = bert.tokenizer.BERTTokenizer(); % sequences = tokenizer.encode("Hello World!") - % Copyright 2021 The MathWorks, Inc. + % Copyright 2021-2023 The MathWorks, Inc. properties(Constant) PaddingToken = "[PAD]" @@ -63,9 +70,16 @@ % case-insensitive BERTTokenizer using the file vocabFile as % the vocabulary. % - % tokenizer = BERTTokenizer(vocabFile,'IgnoreCase',tf) - % Constructs a BERTTokenizer which is case-sensitive or not - % according to the scalar logical tf. The default is true. + % tokenizer = BERTTokenizer(vocabFile,'PARAM1', VAL1, 'PARAM2', VAL2, ...) + % specifies the optional parameter name/value pairs: + % + % 'IgnoreCase' - A logical value to control if the + % BERTTokenizer is case sensitive or not. + % The default value is true. + % + % 'FullTokenizer' - The underlying word-piece tokenizer. 
+            %                     If not specified, a default
+            %                     FullTokenizer is constructed.
             %
             % BERTTokenizer properties:
             %   FullTokenizer - The underlying word-piece tokenizer.
             %
             %   tokenizer = bert.tokenizer.BERTTokenizer();
             %   sequences = tokenizer.encode("Hello World!")
@@ -90,9 +104,15 @@
             arguments
                 vocabFile (1,1) string {mustBeFile} = bert.internal.getSupportFilePath("base","vocab.txt")
                 nvp.IgnoreCase (1,1) logical = true
+                nvp.FullTokenizer = []
+            end
+            if isempty(nvp.FullTokenizer)
+                ignoreCase = nvp.IgnoreCase;
+                this.FullTokenizer = bert.tokenizer.internal.FullTokenizer(vocabFile,'IgnoreCase',ignoreCase);
+            else
+                mustBeA(nvp.FullTokenizer,'bert.tokenizer.internal.FullTokenizer');
+                this.FullTokenizer = nvp.FullTokenizer;
             end
-            ignoreCase = nvp.IgnoreCase;
-            this.FullTokenizer = bert.tokenizer.internal.FullTokenizer(vocabFile,'IgnoreCase',ignoreCase);
             this.PaddingCode = this.FullTokenizer.encode(this.PaddingToken);
             this.SeparatorCode = this.FullTokenizer.encode(this.SeparatorToken);
             this.StartCode = this.FullTokenizer.encode(this.StartToken);
@@ -131,10 +151,9 @@
             inputShape = size(text_a);
             text_a = reshape(text_a,[],1);
             text_b = reshape(text_b,[],1);
-            tokenize = @(text) this.FullTokenizer.tokenize(text);
-            tokens = arrayfun(tokenize,text_a,'UniformOutput',false);
+            tokens = this.FullTokenizer.tokenize(text_a);
             if ~isempty(text_b)
-                tokens_b = arrayfun(tokenize,text_b,'UniformOutput',false);
+                tokens_b = this.FullTokenizer.tokenize(text_b);
                 tokens = cellfun(@(tokens_a,tokens_b) [tokens_a,this.SeparatorToken,tokens_b], tokens, tokens_b, 'UniformOutput', false);
             end
             tokens = cellfun(@(tokens) [this.StartToken, tokens, this.SeparatorToken], tokens, 'UniformOutput', false);
@@ -218,4 +237,4 @@
             text = cellfun(@(x) join(x," "), tokens);
         end
     end
-end
\ No newline at end of file
+end
diff --git a/FineTuneBERTJapanese.m b/FineTuneBERTJapanese.m
new file mode 100644
index 0000000..bd21416
--- /dev/null
+++ b/FineTuneBERTJapanese.m
@@ -0,0 +1,359 @@
+%% Fine-Tune Pretrained BERT Model
+% This example shows how to fine-tune a pretrained BERT model for text
+% classification.
+%
+% To get the most out of a pretrained BERT model, you can retrain and
+% fine-tune the BERT parameter weights for your task.
+%
+% This example shows how to fine-tune a pretrained BERT model to classify
+% failure events given a data set of factory reports.
+
+%% Load Pretrained BERT Model
+% Load a pretrained BERT model using the |bert| function. The model
+% consists of a tokenizer that encodes text as sequences of integers, and a
+% structure of parameters.
+mdl = bert(Model="japanese-base");
+
+%%
+% View the BERT model tokenizer. The tokenizer encodes text as sequences of
+% integers and holds the details of padding, start, separator and mask
+% tokens.
+tokenizer = mdl.Tokenizer
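+
+%%
+% For example, the tokenizer encodes a short report as a sequence of
+% integer token codes. The sentence below is illustrative; any string
+% array works here.
+exampleSeq = encode(tokenizer,"クーラントがソーターの下に溜まっています。");
+exampleSeq{1}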
+
+%% Load Data
+% Load the example data. The file |factoryReportsJP.csv| contains factory
+% reports, including a text description and categorical labels for each
+% event.
+% The table contains these variables:
+% Var1 — Description
+% Var2 — Category
+% Var3 — Urgency
+% Var4 — Resolution
+% Var5 — Cost
+
+filename = "factoryReportsJP.csv";
+data = readtable(filename,"TextType","string","ReadVariableNames",false);
+data.Properties.VariableNames = ["Description", "Category", ...
+    "Urgency", "Resolution", "Cost"];
+head(data)
+
+%%
+% The goal of this example is to classify events by the label in the
+% |Category| column. To divide the data into classes, convert these labels
+% to categorical.
+data.Category = categorical(data.Category);
+
+%%
+% View the number of classes.
+classes = categories(data.Category);
+numClasses = numel(classes)
+
+%%
+% View the distribution of the classes in the data using a histogram.
+figure
+histogram(data.Category);
+xlabel("Class")
+ylabel("Frequency")
+title("Class Distribution")
+
+%%
+% Encode the text data using the BERT model tokenizer with the |encode|
+% function and add the tokens to the data table.
+data.Tokens = encode(tokenizer, data.Description);
+
+%%
+% Partition the data into a training partition and a held-out partition for
+% validation and testing. Specify the holdout percentage to be 20%.
+cvp = cvpartition(data.Category,"Holdout",0.2);
+dataTrain = data(training(cvp),:);
+dataValidation = data(test(cvp),:);
+
+%%
+% View the number of training and validation observations.
+numObservationsTrain = size(dataTrain,1)
+numObservationsValidation = size(dataValidation,1)
+
+%%
+% Extract the training text data, labels, and encoded BERT tokens from the
+% partitioned tables.
+textDataTrain = dataTrain.Description;
+TTrain = dataTrain.Category;
+tokensTrain = dataTrain.Tokens;
+
+%%
+% To check that you have imported the data correctly, visualize the
+% training text data using a word cloud.
+
+figure
+wordcloud(textDataTrain);
+title("Training Data")
+
+%% Prepare Data for Training
+% Convert the documents to feature vectors using the BERT model as a
+% feature extractor.
+
+% To extract the features of the training data by iterating over
+% mini-batches, create a |minibatchqueue| object.
+
+% Mini-batch queues require a single datastore that outputs both the
+% predictors and responses. Create array datastores containing the training
+% BERT tokens and labels and combine them using the |combine| function.
+dsXTrain = arrayDatastore(tokensTrain,"OutputType","same");
+dsTTrain = arrayDatastore(TTrain);
+cdsTrain = combine(dsXTrain,dsTTrain);
+
+%% Initialize Model Parameters
+% Initialize the weights for the classifier to apply after the BERT
+% embedding.
+outputSize = mdl.Parameters.Hyperparameters.HiddenSize;
+mdl.Parameters.Weights.classifier.kernel = dlarray(randn(numClasses, outputSize));
+mdl.Parameters.Weights.classifier.bias = dlarray(zeros(numClasses, 1));
+
+%% Specify Training Options
+% Train for 4 epochs with a mini-batch size of 32. Train with a learning
+% rate of 0.00001.
+numEpochs = 4;
+miniBatchSize = 32;
+learnRate = 1e-5;
+
+%% Train Model
+% Fine-tune the model parameters using a custom training loop.
+
+%%
+% Create a mini-batch queue for the training data. Preprocess the
+% mini-batches using the |preprocessMiniBatch| function, listed at the end
+% of the example, and discard any partial mini-batches.
+paddingValue = mdl.Tokenizer.PaddingCode;
+maxSequenceLength = mdl.Parameters.Hyperparameters.NumContext;
+
+mbqTrain = minibatchqueue(cdsTrain,2,...
+    "MiniBatchSize",miniBatchSize, ...
+    "MiniBatchFcn",@(X,Y) preprocessMiniBatch(X,Y,paddingValue,maxSequenceLength), ...
+    "PartialMiniBatch","discard");
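+
+%%
+% Optionally, preview one mini-batch to check the preprocessing. The
+% variable names here are illustrative. Reset the queue afterwards so that
+% training starts from the full data set.
+[XPreview,TPreview] = next(mbqTrain);
+size(XPreview)
+size(TPreview)
+reset(mbqTrain)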
+
+%%
+% Initialize the training progress plot. In R2023a and later, you can use
+% a |trainingProgressMonitor| object instead.
+figure
+C = colororder;
+lineLossTrain = animatedline("Color",C(2,:));
+
+ylim([0 inf]);
+xlabel("Iteration");
+ylabel("Loss");
+
+%%
+% Initialize parameters for the Adam optimizer.
+trailingAvg = [];
+trailingAvgSq = [];
+
+%%
+% Extract the model parameters from the pretrained BERT model.
+parameters = mdl.Parameters;
+
+%%
+% Train the model using a custom training loop.
+%
+% For each epoch, shuffle the mini-batch queue and loop over mini-batches
+% of data. At the end of each iteration, update the training progress plot.
+%
+% For each iteration:
+% * Read a mini-batch of data from the mini-batch queue.
+% * Evaluate the model gradients and loss using the |dlfeval| and
+% |modelGradients| functions.
+% * Update the network parameters using the |adamupdate| function.
+% * Update the training plot.
+
+iteration = 0;
+start = tic;
+
+% Loop over epochs.
+for epoch = 1:numEpochs
+
+    % Shuffle data.
+    shuffle(mbqTrain);
+
+    % Loop over mini-batches
+    while hasdata(mbqTrain)
+        iteration = iteration + 1;
+
+        % Read mini-batch of data.
+        [X,T] = next(mbqTrain);
+
+        % Evaluate loss and gradients.
+        [loss,gradients] = dlfeval(@modelGradients,X,T,parameters);
+
+        % Update model parameters.
+        [parameters.Weights,trailingAvg,trailingAvgSq] = adamupdate(parameters.Weights,gradients, ...
+            trailingAvg,trailingAvgSq,iteration,learnRate);
+
+        % Update training plot.
+        loss = double(gather(extractdata(loss)));
+        addpoints(lineLossTrain,iteration,loss);
+
+        D = duration(0,0,toc(start),'Format','hh:mm:ss');
+        title("Epoch: " + epoch + ", Elapsed: " + string(D))
+        drawnow
+    end
+end
+
+%% Test Network
+% Test the network using the held-out validation data.
+
+%%
+% Extract the encoded tokens and labels from the validation data table.
+tokensValidation = dataValidation.Tokens;
+TValidation = dataValidation.Category;
+
+%%
+% Create an array datastore containing the encoded tokens.
+dsXValidation = arrayDatastore(tokensValidation,"OutputType","same");
+
+%%
+% Create a mini-batch queue for the validation data. Preprocess the
+% mini-batches using the |preprocessPredictors| function, listed at the end
+% of the example.
+mbqValidation = minibatchqueue(dsXValidation,1,...
+    "MiniBatchSize",miniBatchSize, ...
+    "MiniBatchFcn",@(X) preprocessPredictors(X,paddingValue,maxSequenceLength));
+
+%%
+% Make predictions using the |modelPredictions| function, listed at the end
+% of the example, and display the results in a confusion matrix.
+YPredValidation = modelPredictions(parameters,mbqValidation,classes);
+
+figure
+confusionchart(TValidation,YPredValidation)
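+
+%%
+% Optionally, compute the classification accuracy, that is, the fraction
+% of correctly predicted validation observations.
+accuracy = mean(YPredValidation == TValidation)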
+
+%% Predict Using New Data
+% Classify the event type of three new reports.
+
+%%
+% Create a string array containing the new reports.
+reportsNew = [
+    "クーラントがソーターの下に溜まっています。"
+    "ソーターは起動時にヒューズを飛ばします。"
+    "アセンブラから非常に大きなガタガタという音が聞こえます。"];
+
+%%
+% Encode the text data as sequences of tokens using the BERT model
+% tokenizer.
+tokensNew = encode(tokenizer, reportsNew);
+
+%%
+% Create a mini-batch queue for the new data. Preprocess the mini-batches
+% using the |preprocessPredictors| function, listed at the end of the
+% example.
+dsXNew = arrayDatastore(tokensNew,"OutputType","same");
+
+mbqNew = minibatchqueue(dsXNew,1,...
+    "MiniBatchSize",miniBatchSize, ...
+    "MiniBatchFcn",@(X) preprocessPredictors(X,paddingValue,maxSequenceLength));
+
+%%
+% Make predictions using the |modelPredictions| function, listed at the end
+% of the example.
+YPredNew = modelPredictions(parameters,mbqNew,classes)
+
+%% Supporting Functions
+
+%%% Mini-batch Preprocessing Function
+% The |preprocessMiniBatch| function preprocesses the predictors using the
+% |preprocessPredictors| function and then one-hot encodes the labels. Use
+% this preprocessing function to preprocess both predictors and labels.
+function [X,T] = preprocessMiniBatch(X,T,paddingValue,maxSequenceLength)
+
+X = preprocessPredictors(X,paddingValue,maxSequenceLength);
+T = cat(2,T{:});
+T = onehotencode(T,1);
+
+end
+
+%%% Predictors Preprocessing Function
+% The |preprocessPredictors| function truncates the mini-batches to the
+% specified maximum sequence length and then pads the sequences to have
+% the same length. Use this preprocessing function to preprocess the
+% predictors only.
+function X = preprocessPredictors(X,paddingValue,maxSeqLen)
+
+X = truncateSequences(X,maxSeqLen,SeparatorCode=4);
+X = padsequences(X,2,"PaddingValue",paddingValue);
+
+end
+
+%%% BERT Embedding Function
+% The |bertEmbed| function maps input data to embedding vectors and
+% optionally applies dropout using the "DropoutProbability" name-value
+% pair.
+function Y = bertEmbed(X,parameters,args)
+
+arguments
+    X
+    parameters
+    args.DropoutProbability = 0
+end
+
+dropoutProbability = args.DropoutProbability;
+
+Y = bert.model(X,parameters, ...
+    "DropoutProb",dropoutProbability, ...
+    "AttentionDropoutProb",dropoutProbability);
+
+% To return single feature vectors, return the embedding of the first
+% ([CLS]) token of each sequence.
+Y = Y(:,1,:);
+Y = squeeze(Y);
+
+end
+
+%%% Model Function
+% The function |model| performs a forward pass of the classification model.
+function Y = model(X,parameters,dropout)
+
+Y = bertEmbed(X,parameters,"DropoutProbability",dropout);
+
+weights = parameters.Weights.classifier.kernel;
+bias = parameters.Weights.classifier.bias;
+Y = fullyconnect(Y,weights,bias,"DataFormat","CB");
+
+end
+
+%%% Model Gradients Function
+% The |modelGradients| function performs a forward pass of the
+% classification model and returns the model loss and gradients of the loss
+% with respect to the learnable parameters.
+function [loss,gradients] = modelGradients(X,T,parameters)
+
+dropout = 0.1;
+Y = model(X,parameters,dropout);
+Y = softmax(Y,"DataFormat","CB");
+loss = crossentropy(Y,T,"DataFormat","CB");
+gradients = dlgradient(loss,parameters.Weights);
+
+end
+
+%%% Model Predictions Function
+% The |modelPredictions| function makes predictions by iterating over
+% mini-batches of data.
+function predictions = modelPredictions(parameters,mbq,classes)
+
+predictions = [];
+
+dropout = 0;
+
+reset(mbq);
+
+while hasdata(mbq)
+
+    dlX = next(mbq);
+    dlYPred = model(dlX,parameters,dropout);
+    dlYPred = softmax(dlYPred,"DataFormat","CB");
+
+    YPred = onehotdecode(dlYPred,classes,1)';
+
+    predictions = [predictions; YPred];
+end
+
+end
\ No newline at end of file
diff --git a/README.md b/README.md
index a4bd568..407b52a 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,8 @@ Download or [clone](https://www.mathworks.com/help/matlab/matlab_prog/use-source-
 - `"small"` - A 4 layer model with hidden size 512.
 - `"mini"` - A 4 layer model with hidden size 256.
 - `"tiny"` - A 2 layer model with hidden size 128.
+- `"japanese-base"` - A 12 layer model with hidden size 768, pretrained on texts in the Japanese language.
+- `"japanese-base-wwm"` - A 12 layer model with hidden size 768, pretrained on texts in the Japanese language. Additionally, the model is trained with whole word masking enabled for the masked language modeling (MLM) objective.

 ### bert.model
 `Z = bert.model(X,parameters)` performs inference with a BERT model on the input `1`-by-`numInputTokens`-by-`numObservations` array of encoded tokens with the specified parameters. The output `Z` is an array of size (`NumHeads*HeadSize`)-by-`numInputTokens`-by-`numObservations`. The element `Z(:,i,j)` corresponds to the BERT embedding of input token `X(1,i,j)`.
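+
+For example, a minimal sketch of computing embeddings for one piece of text (variable names are illustrative; see the repository examples for complete workflows):
+
+```matlab
+mdl = bert();
+seqs = mdl.Tokenizer.encode("Hello World!");
+Z = bert.model(seqs{1},mdl.Parameters);
+```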
diff --git a/bert.m b/bert.m
index b0b4956..2231ec4 100644
--- a/bert.m
+++ b/bert.m
@@ -5,28 +5,61 @@
 %
 % mdl = bert('Model', modelName) loads the BERT model specified by
 % modelName. Supported values for modelName are "base" (default),
-% "multilingual-cased","medium","small","mini", and "tiny".
+% "multilingual-cased","medium","small","mini", "tiny", "japanese-base",
+% and "japanese-base-wwm".

-% Copyright 2021 The MathWorks, Inc.
+% Copyright 2021-2023 The MathWorks, Inc.

 arguments
     nvp.Model (1,1) string {mustBeMember(nvp.Model,[
         "base"
         "multilingual-cased"
         "medium"
         "small"
-        "mini"
-        "tiny"])} = "base"
+        "mini"
+        "tiny"
+        "japanese-base"
+        "japanese-base-wwm"])} = "base"
 end
-% Download the license file
-bert.internal.getSupportFilePath(nvp.Model,"bert.RIGHTS");
-params = bert.load(nvp.Model);
-% Get the IgnoreCase hyperparameter, then remove it, downstream code
-% shouldn't need it.
-ignoreCase = params.Hyperparameters.IgnoreCase;
-% Get vocab file
-vocabFile = bert.internal.getSupportFilePath(nvp.Model,"vocab.txt");
-params.Hyperparameters = rmfield(params.Hyperparameters,'IgnoreCase');
+
+switch nvp.Model
+    case "japanese-base"
+        mdl = iJapaneseBERTModel("japanese-base", "bert-base-japanese.zip");
+    case "japanese-base-wwm"
+        mdl = iJapaneseBERTModel("japanese-base-wwm", "bert-base-japanese-whole-word-masking.zip");
+    otherwise
+        % Download the license file
+        bert.internal.getSupportFilePath(nvp.Model,"bert.RIGHTS");
+        params = bert.load(nvp.Model);
+        % Get the IgnoreCase hyperparameter, then remove it, downstream code
+        % shouldn't need it.
+        ignoreCase = params.Hyperparameters.IgnoreCase;
+        % Get vocab file
+        vocabFile = bert.internal.getSupportFilePath(nvp.Model,"vocab.txt");
+        params.Hyperparameters = rmfield(params.Hyperparameters,'IgnoreCase');
+        mdl = struct(...
+            'Tokenizer',bert.tokenizer.BERTTokenizer(vocabFile,'IgnoreCase',ignoreCase),...
+            'Parameters',params);
+end
+end
+
+function mdl = iJapaneseBERTModel(modelName, zipFileName)
+zipFilePath = bert.internal.getSupportFilePath(modelName, zipFileName);
+modelDir = fullfile(fileparts(zipFilePath), replace(zipFileName, ".zip", ""));
+unzip(zipFilePath, modelDir);
+% Build the tokenizer
+btok = bert.tokenizer.internal.TokenizedDocumentTokenizer("Language","ja","TokenizeMethod","mecab",IgnoreCase=false);
+vocabFile = fullfile(modelDir, "vocab.txt");
+ftok = bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=btok);
+tok = bert.tokenizer.BERTTokenizer(vocabFile,FullTokenizer=ftok);
+% Build the model
+params.Weights = load(fullfile(modelDir, "weights.mat"));
+params.Weights = dlupdate(@dlarray,params.Weights);
+params.Hyperparameters = struct(...
+    NumHeads=12,...
+    NumLayers=12,...
+    NumContext=512,...
+    HiddenSize=768);
 mdl = struct(...
-    'Tokenizer',bert.tokenizer.BERTTokenizer(vocabFile,'IgnoreCase',ignoreCase),...
-    'Parameters',params);
+    Tokenizer=tok,...
+    Parameters=params);
 end
\ No newline at end of file
diff --git a/test/bert/tokenizer/internal/tBasicTokenizer.m b/test/bert/tokenizer/internal/tBasicTokenizer.m
index b761974..2d9ff1e 100644
--- a/test/bert/tokenizer/internal/tBasicTokenizer.m
+++ b/test/bert/tokenizer/internal/tBasicTokenizer.m
@@ -1,7 +1,7 @@
 classdef tBasicTokenizer < matlab.unittest.TestCase
     % tBasicTokenizer Unit tests for the BasicTokenizer

-    % Copyright 2021 The MathWorks, Inc.
+    % Copyright 2021-2023 The MathWorks, Inc.
methods(Test) function canConstruct(test) @@ -13,10 +13,19 @@ function canConstruct(test) function canTokenize(test) tok = bert.tokenizer.internal.BasicTokenizer(); str = "foo bar baz"; - exp_out = ["foo","bar","baz"]; + exp_out = {["foo","bar","baz"]}; act_out = tok.tokenize(str); test.verifyEqual(act_out,exp_out); end + + function canTokenizeBatch(test) + tok = bert.tokenizer.internal.BasicTokenizer(); + manyStrs = repmat("foo bar baz",1,20); + act_out = tokenize(tok, manyStrs); + exp_out = arrayfun(@(str) tokenize(tok,str),manyStrs,UniformOutput=false); + exp_out = [exp_out{:}]; + test.verifyEqual(act_out,exp_out); + end function removesControlCharactersAndWhitespace(test) tok = bert.tokenizer.internal.BasicTokenizer(); @@ -26,7 +35,7 @@ function removesControlCharactersAndWhitespace(test) words = ["Testing","a","blah"]; str = strcat(words(1)," ",aFormatChar," ",... words(2),aFormatChar," ",aControlChar," ",words(3),aSpaceChar); - exp_out = [lower(words(1)),words(2),words(3)]; + exp_out = {[lower(words(1)),words(2),words(3)]}; act_out = tok.tokenize(str); test.verifyEqual(act_out,exp_out); end @@ -36,7 +45,7 @@ function splitsOnNewlines(test) tok = bert.tokenizer.internal.BasicTokenizer(); str = "hello"+newline+"world"; act_toks = tok.tokenize(str); - exp_toks = ["hello","world"]; + exp_toks = {["hello","world"]}; test.verifyEqual(act_toks,exp_toks); end @@ -45,8 +54,8 @@ function tokenizesCJK(test) str = strcat(... compose("Arbitrary \x4E01\x4E02 CJK chars \xD869\xDF00\xD86D\xDE3E"),... "more"); - exp_out = ["arbitrary",compose("\x4E01"),compose("\x4E02"),"cjk","chars",... - compose("\xD869\xDF00"),compose("\xD86D\xDE3E"),"more"]; + exp_out = {["arbitrary",compose("\x4E01"),compose("\x4E02"),"cjk","chars",... + compose("\xD869\xDF00"),compose("\xD86D\xDE3E"),"more"]}; act_out = tok.tokenize(str); test.verifyEqual(act_out,exp_out); end @@ -54,7 +63,7 @@ function tokenizesCJK(test) function splitsOnPunctuation(test) tok = bert.tokenizer.internal.BasicTokenizer(); str = "hello. hello, world? hello world! hello"; - exp_out = ["hello",".","hello",",","world","?","hello","world","!","hello"]; + exp_out = {["hello",".","hello",",","world","?","hello","world","!","hello"]}; act_out = tok.tokenize(str); test.verifyEqual(act_out,exp_out); end @@ -62,7 +71,7 @@ function splitsOnPunctuation(test) function stripsAccents(test) tok = bert.tokenizer.internal.BasicTokenizer(); str = compose("h\x00E9llo"); - exp_out = "hello"; + exp_out = {"hello"}; act_out = tok.tokenize(str); test.verifyEqual(act_out,exp_out); end @@ -70,9 +79,9 @@ function stripsAccents(test) function canBeCaseSensitive(test) tok = bert.tokenizer.internal.BasicTokenizer('IgnoreCase',false); str = "FOO bAr baz"; - exp_out = ["FOO","bAr","baz"]; + exp_out = {["FOO","bAr","baz"]}; act_out = tok.tokenize(str); test.verifyEqual(act_out,exp_out); end end -end \ No newline at end of file +end diff --git a/test/bert/tokenizer/internal/tFullTokenizer.m b/test/bert/tokenizer/internal/tFullTokenizer.m index f4d31a5..165f890 100644 --- a/test/bert/tokenizer/internal/tFullTokenizer.m +++ b/test/bert/tokenizer/internal/tFullTokenizer.m @@ -1,8 +1,8 @@ classdef(SharedTestFixtures = { - DownloadBERTFixture}) tFullTokenizer < matlab.unittest.TestCase + DownloadBERTFixture}) tFullTokenizer < matlab.mock.TestCase % tFullTokenizer Unit tests for the FullTokenizer. - % Copyright 2021 The MathWorks, Inc. + % Copyright 2021-2023 The MathWorks, Inc. 
 
     methods(Test)
         function matchesExpectedTokenization(test)
@@ -12,9 +12,24 @@
             % Create a string to tokenize.
             str = "UNwant"+compose("\x00E9")+"d,running.";

-            exp_toks = ["unwanted",",","running","."];
+            exp_toks = {["unwanted",",","running","."]};
             act_toks = tok.tokenize(str);
             test.verifyEqual(act_toks,exp_toks);
         end
+
+        function errorsIfBasicTokenizerIsNotTokenizer(test)
+            vocabFile = bert.internal.getSupportFilePath("base","vocab.txt");
+            makeTok = @() bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=vocabFile);
+            test.verifyError(makeTok,"MATLAB:validators:mustBeA");
+        end
+
+        function canSetBasicTokenizer(test)
+            [mock,behaviour] = test.createMock(?bert.tokenizer.internal.Tokenizer);
+            test.assignOutputsWhen(withAnyInputs(behaviour.tokenize),"hello");
+            vocabFile = bert.internal.getSupportFilePath("base","vocab.txt");
+            tok = bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=mock);
+            toks = tok.tokenize("anything");
+            test.verifyEqual(toks,{"hello"}); %#ok
+        end
     end
 end
\ No newline at end of file
diff --git a/test/bert/tokenizer/internal/tTokenizedDocumentTokenizer.m b/test/bert/tokenizer/internal/tTokenizedDocumentTokenizer.m
new file mode 100644
index 0000000..32595d2
--- /dev/null
+++ b/test/bert/tokenizer/internal/tTokenizedDocumentTokenizer.m
@@ -0,0 +1,30 @@
+classdef tTokenizedDocumentTokenizer < matlab.unittest.TestCase
+    % tTokenizedDocumentTokenizer Unit tests for TokenizedDocumentTokenizer.
+
+    % Copyright 2023 The MathWorks, Inc.
+
+    methods(Test)
+        function tokenizationMatchesTokenizedDocument(test)
+            % TokenizedDocumentTokenizer does what it says on the tin -
+            % uses tokenizedDocument.
+            tok = bert.tokenizer.internal.TokenizedDocumentTokenizer;
+            str = "a random string. doesn't matter.";
+            toks = tok.tokenize(str);
+            doc = tokenizedDocument(str);
+            toksExp = {string(doc)};
+            test.verifyEqual(toks,toksExp);
+        end
+
+        function canSetOptions(test)
+            % We can pass in tokenization options matching
+            % tokenizedDocument's NVPs.
+            customToken = "foo bar";
+            tok = bert.tokenizer.internal.TokenizedDocumentTokenizer(CustomTokens=customToken);
+            str = "in this case "+customToken+" is one token.";
+            toks = tok.tokenize(str);
+            import matlab.unittest.constraints.AnyElementOf
+            import matlab.unittest.constraints.IsEqualTo
+            test.verifyThat(AnyElementOf(toks{1}),IsEqualTo(customToken));
+        end
+    end
+end
\ No newline at end of file
diff --git a/test/bert/tokenizer/tBERTTokenizerForJP.m b/test/bert/tokenizer/tBERTTokenizerForJP.m
new file mode 100644
index 0000000..080e69b
--- /dev/null
+++ b/test/bert/tokenizer/tBERTTokenizerForJP.m
@@ -0,0 +1,64 @@
+classdef(SharedTestFixtures = {
+    DownloadJPBERTFixture}) tBERTTokenizerForJP < matlab.unittest.TestCase
+    % tBERTTokenizerForJP Unit tests for the BERTTokenizer using Japanese
+    % BERT models.
+
+    % Copyright 2023 The MathWorks, Inc.
+
+    properties(TestParameter)
+        VocabFiles = iVocabFiles()
+    end
+
+    properties(Constant)
+        Constructor = @iJapaneseTokenizerConstructor
+    end
+
+    methods(Test)
+
+        function hasExpectedProperties(test, VocabFiles)
+            tok = test.Constructor(VocabFiles);
+            test.verifyEqual(tok.PaddingToken, "[PAD]");
+            test.verifyEqual(tok.StartToken, "[CLS]");
+            test.verifyEqual(tok.SeparatorToken, "[SEP]");
+            test.verifyEqual(tok.PaddingCode, 1);
+            test.verifyEqual(tok.SeparatorCode, 4);
+            test.verifyEqual(tok.StartCode, 3);
+        end
+
+        function matchesExpectedEncoding(test, VocabFiles)
+            tok = test.Constructor(VocabFiles);
+            text = "月夜の想い、謎めく愛。君の謎。";
+            expectedEncoding = [3 38 29340 6 12385 7 5939 2088 28504 768 ...
+                9 2607 6 5939 9 4];
+            y = tok.encode(text);
+            test.verifyClass(y,'cell');
+            y1 = y{1};
+            test.verifyEqual(y1(1),tok.StartCode);
+            test.verifyEqual(y1(end),tok.SeparatorCode);
+            test.verifyEqual(y1,expectedEncoding);
+        end
+    end
+end
+
+function modelNames = iModelNames
+% Struct-friendly model names.
+modelNames = ["japanese_base", "japanese_base_wwm"];
+end
+
+function vocabFiles = iVocabFiles
+modelDir = ["bert-base-japanese", "bert-base-japanese-whole-word-masking"];
+modelNames = iModelNames;
+vocabFiles = struct();
+for i = 1:numel(modelNames)
+    versionName = modelDir(i);
+    vocabDir = fullfile("data", "networks", "ja_bert", versionName, "vocab.txt");
+    model = modelNames(i);
+    vocabFiles.(replace(model, "-", "_")) = fullfile(matlab.internal.examples.utils.getSupportFileDir(),"nnet",vocabDir);
+end
+end
+
+function japaneseBERTTokenizer = iJapaneseTokenizerConstructor(vocabLocation)
+btok = bert.tokenizer.internal.TokenizedDocumentTokenizer("Language","ja","TokenizeMethod","mecab",IgnoreCase=false);
+ftok = bert.tokenizer.internal.FullTokenizer(vocabLocation,BasicTokenizer=btok);
+japaneseBERTTokenizer = bert.tokenizer.BERTTokenizer(vocabLocation,FullTokenizer=ftok);
+end
\ No newline at end of file
diff --git a/test/tbert.m b/test/tbert.m
index bb53c10..40cea43 100644
--- a/test/tbert.m
+++ b/test/tbert.m
@@ -1,5 +1,5 @@
 classdef(SharedTestFixtures = {
-    DownloadBERTFixture}) tbert < matlab.unittest.TestCase
+    DownloadBERTFixture, DownloadJPBERTFixture}) tbert < matlab.unittest.TestCase
     % tbert System level tests for bert

     % Copyright 2021 The MathWorks, Inc.
@@ -7,13 +7,22 @@
     properties(TestParameter)
         UncasedVersion = {"base", ...
            "tiny"}
+        AllModels = {"base","multilingual-cased","medium",...
+            "small","mini","tiny","japanese-base",...
+            "japanese-base-wwm"}
     end

     methods(Test)
+
         function canConstructModelWithDefault(test)
             % Verify the default model can be constructed.
             test.verifyWarningFree(@() bert());
         end
+
+        function canConstructAllModels(test, AllModels)
+            % Verify that all of the available models can be constructed.
+            test.verifyWarningFree(@() bert('Model', AllModels));
+        end

         function canConstructModelWithNVPAndVerifyDefault(test)
             % Verify the default model matches the default model.
diff --git a/test/tools/DownloadJPBERTFixture.m b/test/tools/DownloadJPBERTFixture.m
new file mode 100644
index 0000000..deec417
--- /dev/null
+++ b/test/tools/DownloadJPBERTFixture.m
@@ -0,0 +1,48 @@
+classdef DownloadJPBERTFixture < matlab.unittest.fixtures.Fixture
+    % DownloadJPBERTFixture A fixture for downloading the Japanese BERT models and
+    % clearing them out after tests finish if they were not previously
+    % downloaded.
+
+    % Copyright 2023 The MathWorks, Inc.
+
+    properties(Constant)
+        Models = dictionary(["japanese-base", "japanese-base-wwm"], ...
+ ["bert-base-japanese", "bert-base-japanese-whole-word-masking"]); + end + + properties + DataDirExists + end + + methods + function setup(this) + dirs = this.pathToSupportFile(this.Models.values); + dataDirsExist = arrayfun(@(dir) exist(dir,'dir')==7, dirs); + this.DataDirExists = dictionary(this.Models.keys,dataDirsExist); + modelNames = this.Models.keys; + for i=1:numel(modelNames) + model = modelNames(i); + if ~this.DataDirExists(model) + bert('Model',model); + end + end + end + + function teardown(this) + modelNames = this.Models.keys; + for i=1:numel(modelNames) + model = modelNames(i); + if ~this.DataDirExists(model) + rmdir(this.pathToSupportFile(this.Models(model)),'s'); + end + end + end + end + + methods(Access=private) + function path = pathToSupportFile(~,model) + modelDir = fullfile("data", "networks", "ja_bert", model); + path = fullfile(matlab.internal.examples.utils.getSupportFileDir(),"nnet",modelDir); + end + end +end \ No newline at end of file