
Boostcamp AI Tech 4th cohort

Pretraining T5 with JAX


Installing JAX with GPU (CUDA) support

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.6 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

https://github.com/google/jax/#pip-installation-gpu-cuda
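After installing, it is worth confirming that JAX actually picked up the CUDA wheel instead of silently falling back to the CPU. The sketch below relies only on `jax.default_backend()` and `jax.devices()`. The helper name `jax_backend_report` is hypothetical, and the script is written so that it also reports cleanly when jax is not installed at all.

```python
import importlib.util


def jax_backend_report():
    """Return (installed, backend, device_platforms) for the local JAX install.

    jax_backend_report is a hypothetical helper, not part of JAX.
    """
    if importlib.util.find_spec("jax") is None:
        return False, None, []
    import jax  # imported lazily so the check still runs when jax is absent

    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    platforms = sorted({d.platform for d in jax.devices()})
    return True, backend, platforms


if __name__ == "__main__":
    installed, backend, platforms = jax_backend_report()
    if not installed:
        print("jax is not installed")
    elif backend == "cpu":
        print("JAX is running on CPU only; the CUDA wheel was not picked up")
    else:
        print(f"JAX default backend: {backend}, device platforms: {platforms}")
```

If the report says CPU on a machine with a GPU, the usual suspects are a CUDA/cuDNN version that does not match the installed wheel or a non-Linux platform.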

 


Pretraining T5 with JAX from a PyTorch PLM (pretrained language model)

Reference code: https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling

 

Modify tokenize_function to match the input format of the dataset (here the text is stored in the "Q" column)

def tokenize_function(examples):
    return tokenizer(examples[text_column_name], return_attention_mask=False)

↓

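def tokenize_function(examples):
    # Read the raw text from the dataset's "Q" column and pad/truncate
    # every example to exactly max_seq_length tokens.
    inputs = tokenizer(
        examples["Q"],
        padding="max_length",
        truncation=True,
        max_length=max_seq_length,
        return_token_type_ids=False,
        return_attention_mask=False,
    )
    # datasets.map() stores plain Python lists, so return_tensors is not needed
    return dict(inputs)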
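In the reference script, the output of tokenize_function is not fed to the model directly: a following step (group_texts in run_t5_mlm_flax.py) concatenates all tokenized examples and cuts them into fixed-size blocks before span corruption. The simplified re-implementation below uses hypothetical token ids (0 as padding and 1 as end-of-sequence, assumed from the original T5 vocabulary; the KoT5 tokenizer may differ) and a toy block size of 4. It shows that the padding added by padding="max_length" is carried into those blocks.

```python
from itertools import chain


def group_texts(examples, block_size):
    """Concatenate every column and split it into blocks of block_size tokens.

    Simplified sketch of the grouping step; the trailing remainder that does
    not fill a whole block is dropped.
    """
    concatenated = {k: list(chain(*examples[k])) for k in examples}
    total_length = len(next(iter(concatenated.values())))
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }


# Two examples already padded to length 5 (assumed ids: 1 = </s>, 0 = <pad>)
batch = {"input_ids": [[5, 6, 1, 0, 0], [7, 8, 9, 1, 0]]}
print(group_texts(batch, block_size=4))
# {'input_ids': [[5, 6, 1, 0], [0, 7, 8, 9]]}
```

The unmodified tokenize_function does not pad, so its blocks contain only real tokens; with the modified version, some positions in each pretraining block are padding tokens.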

 

Load a model that was trained with PyTorch

model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    seed=training_args.seed,
    dtype=getattr(jnp, model_args.dtype),
    use_auth_token=True if model_args.use_auth_token else None,
)

↓

model = FlaxT5ForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    seed=training_args.seed,
    dtype=getattr(jnp, model_args.dtype),
    use_auth_token=True if model_args.use_auth_token else None,
    from_pt=True,  # added: load a checkpoint trained with PyTorch and convert it to Flax
)

 

Run

python run_t5_mlm_flax.py \
	--num_train_epochs=3.0 \
	--output_dir="saved_models/psyche/KoT5" \
	--model_type="t5" \
	--model_name_or_path="psyche/KoT5" \
	--tokenizer_name="psyche/KoT5" \
	--dataset_name="" \
	--max_seq_length="256" \
	--per_device_train_batch_size="32" \
	--per_device_eval_batch_size="32" \
	--weight_decay="0.001" \
	--warmup_steps="2000" \
	--logging_steps="2500" \
	--save_steps="50" \
	--eval_steps="2500"
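In the reference script, `--max_seq_length` is the length of the encoder input after span corruption, not the length of the raw token blocks, so the script cuts slightly longer blocks and derives the target length from the masking settings. The sketch below re-implements that arithmetic, assuming the script's defaults of a 15% noise density (`--mlm_probability=0.15`) and a mean noise span length of 3 tokens. It approximates the script's `compute_input_and_target_lengths` and omits its special case for a noise density of 0.5.

```python
def span_corruption_lengths(tokens_length, noise_density, mean_noise_span_length):
    """Encoder-input and target lengths for one block of raw tokens."""
    num_noise = int(round(tokens_length * noise_density))
    num_nonnoise = tokens_length - num_noise
    num_spans = int(round(num_noise / mean_noise_span_length))
    # input: kept tokens + one sentinel per masked span + EOS
    # target: masked tokens + one sentinel per masked span + EOS
    return num_nonnoise + num_spans + 1, num_noise + num_spans + 1


def compute_input_and_target_lengths(inputs_length, noise_density=0.15, mean_noise_span_length=3.0):
    """Largest raw block length whose corrupted input still fits in inputs_length."""
    tokens_length = inputs_length
    while span_corruption_lengths(tokens_length + 1, noise_density, mean_noise_span_length)[0] <= inputs_length:
        tokens_length += 1
    _, targets_length = span_corruption_lengths(tokens_length, noise_density, mean_noise_span_length)
    return tokens_length, targets_length


print(compute_input_and_target_lengths(256))  # (284, 58)
```

With `--max_seq_length="256"`, each raw block is therefore 284 tokens long, the corrupted encoder input is exactly 256 tokens, and the decoder target is 58 tokens.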

 

 

Load Model

With JAX/Flax, the saved model weights use the .msgpack extension (flax_model.msgpack). They can be loaded into a PyTorch model by passing from_flax=True.

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(model_path, from_flax=True)  # model_path: the --output_dir used for training

 

 

 
