Use Python 3.9 syntax in examples (#37279)

Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
cyyever
2025-04-07 19:52:21 +08:00
committed by GitHub
parent 08f36771b3
commit 0fb8d49e88
123 changed files with 358 additions and 451 deletions

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,7 +26,7 @@ from dataclasses import asdict, dataclass, field
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Optional
import datasets
import evaluate
@@ -651,8 +650,8 @@ def main():
# define step functions
def train_step(
state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey
) -> Tuple[train_state.TrainState, float]:
state: train_state.TrainState, batch: dict[str, Array], dropout_rng: PRNGKey
) -> tuple[train_state.TrainState, float]:
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
targets = batch.pop("labels")