Flax Masked Language Modeling training example (#8728)
* Remove "Model" suffix from Flax models to look more 🤗 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Initial working (forward + backward) for Flax MLM training example. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Simply code Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Addressing comments, using module and moving to LM task. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Restore parameter name "module" wrongly renamed model. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Restore correct output ordering... Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Actually commit the example 😅 Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Add FlaxBertModelForMaskedLM after rebasing. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make it possible to initialize the training from scratch Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Reuse flax linen example of cross entropy loss Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added specific data collator for flax Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Remove todo for data collator Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added evaluation step Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added ability to provide dtype to support bfloat16 on TPU Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable flax tensorboard output Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable jax.pmap support. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Ensure batches are correctly sized to be dispatched with jax.pmap Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable bfloat16 with --fp16 cmdline args Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Correctly export metrics to tensorboard Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added dropout and ability to use it. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Effectively enable & disable during training and evaluation steps. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Oops. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Enable specifying kernel initializer scale Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Style. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Added warmup step to the learning rate scheduler. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix typo. Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Print training loss Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make style Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fix linter issue (flake8) Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix model matching Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix dummies Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Fix non default dtype on Flax models Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Use the same create_position_ids_from_input_ids for FlaxRoberta Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Make Roberta attention as Bert Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * fix copy Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com> * Wording. Co-authored-by: Marc van Zee <marcvanzee@gmail.com> Co-authored-by: Marc van Zee <marcvanzee@gmail.com>
This commit is contained in:
@@ -65,13 +65,12 @@ class FlaxPreTrainedModel(ABC):
|
||||
base_model_prefix = ""
|
||||
model_class = None
|
||||
|
||||
def __init__(self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0):
|
||||
def __init__(
|
||||
self, config: PretrainedConfig, module: nn.Module, params: Dict, seed: int = 0, dtype: jnp.dtype = jnp.float32
|
||||
):
|
||||
if config is None:
|
||||
raise ValueError("config cannot be None")
|
||||
|
||||
if module is None:
|
||||
raise ValueError("module cannot be None")
|
||||
|
||||
if params is None:
|
||||
raise ValueError("state cannot be None")
|
||||
|
||||
@@ -82,19 +81,23 @@ class FlaxPreTrainedModel(ABC):
|
||||
# Those are public as their type is generic to every derived classes.
|
||||
self.key = PRNGKey(seed)
|
||||
self.params = params
|
||||
self.model = module
|
||||
self.dtype = dtype
|
||||
|
||||
@property
|
||||
def config(self) -> PretrainedConfig:
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def module(self) -> nn.Module:
|
||||
return self._module
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def convert_from_pytorch(pt_state: Dict, config: PretrainedConfig) -> Dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, dtype: jnp.dtype = jnp.float32, *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained Flax model from a pre-trained model configuration.
|
||||
"""
|
||||
@@ -127,6 +130,9 @@ class FlaxPreTrainedModel(ABC):
|
||||
else:
|
||||
model_kwargs = kwargs
|
||||
|
||||
# Add the dtype to model_kwargs
|
||||
model_kwargs["dtype"] = dtype
|
||||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
|
||||
Reference in New Issue
Block a user