JAX GPU (CUDA) Installation
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
https://github.com/google/jax/#pip-installation-gpu-cuda
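Once the wheel is installed, a quick way to confirm that JAX actually picked up the GPU is to ask for its default backend and device list. If the CUDA wheel is not usable, JAX typically falls back to the CPU with a warning, so the backend name is worth checking explicitly. A minimal sketch; the helper names `describe_backend` and `smoke_test` are my own, not part of JAX:

```python
import jax
import jax.numpy as jnp


def describe_backend():
    """Return JAX's default backend name ("cpu", "gpu" or "tpu") and its devices."""
    return jax.default_backend(), jax.devices()


def smoke_test():
    """Run a tiny computation on the default device and return it as a float."""
    x = jnp.arange(4.0)          # [0., 1., 2., 3.]
    return float((x * 2).sum())  # 0 + 2 + 4 + 6


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("backend:", backend)  # expect "gpu" if the CUDA wheel is working
    print("devices:", devices)
    print("smoke test:", smoke_test())
```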
Pretraining T5 with JAX from a PyTorch PLM
Reference code: https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling
Modify the tokenize function to match the input format
def tokenize_function(examples):
    return tokenizer(examples[text_column_name], return_attention_mask=False)
↓
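def tokenize_function(examples):
    # "Q" is the text column of the dataset used here; change it to match your data
    inputs = tokenizer(
        examples["Q"],
        padding="max_length",  # pad every example to max_seq_length
        truncation=True,  # cut longer examples down to max_seq_length
        max_length=max_seq_length,
        return_tensors="np",  # NumPy arrays; the Flax script does not need PyTorch tensors
        return_token_type_ids=False,
        return_attention_mask=False,
    )
    return {**inputs}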
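To see what the modified function hands to the rest of the script, here is a rough, framework-free sketch of padding to a fixed length with right-side truncation on already-tokenized ids. The helper name `pad_or_truncate` and the pad id 0 are assumptions for illustration; the real tokenizer also reserves room for special tokens such as `</s>`, which this sketch ignores.

```python
import numpy as np


def pad_or_truncate(batch_ids, max_length, pad_token_id=0):
    """Force every token-id list in a batch to exactly max_length entries.

    Roughly what padding="max_length" plus truncation=True does; special
    tokens are not handled, and pad_token_id=0 is an assumed T5-style value.
    """
    out = np.full((len(batch_ids), max_length), pad_token_id, dtype=np.int32)
    for row, ids in enumerate(batch_ids):
        ids = ids[:max_length]      # truncation drops the tail
        out[row, : len(ids)] = ids  # remaining positions keep the pad id
    return out


if __name__ == "__main__":
    batch = pad_or_truncate([[5, 6, 7], [1, 2, 3, 4, 5, 6]], max_length=4)
    print(batch.shape)  # (2, 4)
    print(batch.tolist())
```

Because every row ends up the same length, the batch is a rectangular array, which is what makes a fixed `return_tensors` format possible in the first place.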
Load a model trained with the PyTorch framework
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    seed=training_args.seed,
    dtype=getattr(jnp, model_args.dtype),
    use_auth_token=True if model_args.use_auth_token else None,
)
↓
model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    seed=training_args.seed,
    dtype=getattr(jnp, model_args.dtype),
    use_auth_token=True if model_args.use_auth_token else None,
    from_pt=True,  # added: load a checkpoint trained with PyTorch (converted to Flax on load; needs torch installed)
)
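Both calls pass `dtype=getattr(jnp, model_args.dtype)`, which turns the script's dtype string argument into a JAX dtype by attribute lookup on `jax.numpy`. The sketch below does the same lookup but rejects names outside a short allow-list; the helper name `resolve_dtype` and the list (float32, float16, bfloat16) are assumptions for illustration, not necessarily exactly what the script accepts.

```python
import jax.numpy as jnp

# Hypothetical allow-list of dtype names for the --dtype argument.
SUPPORTED_DTYPES = ("float32", "float16", "bfloat16")


def resolve_dtype(name: str):
    """Map a dtype name to the jnp dtype, like getattr(jnp, name), with validation."""
    if name not in SUPPORTED_DTYPES:
        raise ValueError(f"unsupported dtype {name!r}; choose one of {SUPPORTED_DTYPES}")
    return getattr(jnp, name)


if __name__ == "__main__":
    dtype = resolve_dtype("bfloat16")
    print(jnp.zeros(2, dtype=dtype).dtype)  # bfloat16
```

Failing early on a typo like `--dtype=bf16` gives a clearer error than the `AttributeError` a bare `getattr` would raise.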
Run
python run_t5_mlm_flax.py \
--num_train_epochs=3.0 \
--output_dir="saved_models/psyche/KoT5" \
--model_type="t5" \
--model_name_or_path="psyche/KoT5" \
--tokenizer_name="psyche/KoT5" \
--dataset_name="" \
--max_seq_length="256" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--weight_decay="0.001" \
--warmup_steps="2000" \
--logging_steps="2500" \
--save_steps="50" \
--eval_steps="2500"
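Two of these flags feed into values the reference example derives rather than uses directly. As I understand the example, it reads longer raw token chunks so that, after T5 span corruption, the encoder input is exactly `--max_seq_length` tokens, and it uses a linear warmup over `--warmup_steps` followed by linear decay to zero. The sketch below reimplements both calculations in plain Python (simplified); the function names, the noise settings (`noise_density=0.15`, `mean_noise_span_length=3.0`, assumed defaults), and the peak learning rate and total step count in the demo are assumptions, not values from the command above.

```python
# Sketch: quantities derived from the Run flags (plain Python, no JAX needed).


def span_corruption_lengths(tokens_length, noise_density, mean_noise_span_length):
    """Encoder-input and target lengths after masking `tokens_length` raw tokens."""
    num_noise = int(round(tokens_length * noise_density))
    num_nonnoise = tokens_length - num_noise
    num_spans = int(round(num_noise / mean_noise_span_length))
    # Input: kept tokens + one sentinel per masked span + EOS.
    # Target: masked tokens + one sentinel per masked span + EOS.
    return num_nonnoise + num_spans + 1, num_noise + num_spans + 1


def expanded_inputs_length(max_seq_length, noise_density=0.15, mean_noise_span_length=3.0):
    """Largest raw chunk whose corrupted input still fits in max_seq_length."""
    tokens_length = max_seq_length
    while span_corruption_lengths(
        tokens_length + 1, noise_density, mean_noise_span_length
    )[0] <= max_seq_length:
        tokens_length += 1
    _, targets_length = span_corruption_lengths(
        tokens_length, noise_density, mean_noise_span_length
    )
    return tokens_length, targets_length


def linear_warmup_decay(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup 0 -> peak_lr, then linear decay peak_lr -> 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    decay_steps = max(total_steps - warmup_steps, 1)
    return peak_lr * max(total_steps - step, 0) / decay_steps


if __name__ == "__main__":
    raw_len, target_len = expanded_inputs_length(256)  # --max_seq_length="256"
    print(raw_len, target_len)  # 284 58
    # Hypothetical peak LR and total steps; warmup_steps matches --warmup_steps="2000".
    for step in (0, 1000, 2000, 11000, 20000):
        print(step, linear_warmup_decay(step, peak_lr=1e-3, warmup_steps=2000, total_steps=20000))
```

Under these assumed noise settings, a 256-token encoder input corresponds to 284 raw tokens per chunk and 58-token targets.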
Load Model
With JAX (Flax), the saved model weights use the .msgpack extension (flax_model.msgpack). They can be loaded into a PyTorch model by passing from_flax=True.
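from transformers import AutoModelForSeq2SeqLM
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, from_flax=True)  # model_path: e.g. the --output_dir used in the Run step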
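If the same loading code has to handle both Flax-trained and PyTorch-trained checkpoints, one option is to look at which weight file the directory contains before choosing the flag. A minimal sketch; the helper names are hypothetical, and the PyTorch file names listed (`pytorch_model.bin` for older transformers versions, `model.safetensors` for newer ones) are assumptions about what a given version writes:

```python
import os

FLAX_WEIGHTS_NAME = "flax_model.msgpack"
# Assumed PyTorch weight file names; older versions write the .bin file.
PT_WEIGHTS_NAMES = ("pytorch_model.bin", "model.safetensors")


def load_kwargs_for(filenames):
    """Pick extra from_pretrained kwargs based on the files in a model directory."""
    names = set(filenames)
    if names & set(PT_WEIGHTS_NAMES):
        return {}  # native PyTorch weights: no conversion needed
    if FLAX_WEIGHTS_NAME in names:
        return {"from_flax": True}  # only Flax weights: convert on load
    raise FileNotFoundError("no known weight file among: " + ", ".join(sorted(names)))


def load_kwargs_for_dir(model_dir):
    """Same as load_kwargs_for, reading the file list from disk."""
    return load_kwargs_for(os.listdir(model_dir))


# Usage (not executed here):
# model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **load_kwargs_for_dir(model_path))
```

Calling `model.save_pretrained(...)` on the converted PyTorch model afterwards writes PyTorch weights, so later loads from that directory no longer need `from_flax=True`.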