새소식

부스트캠프 AI Tech 4기

How to resize_embedding size with JAX

  • -

tokenizer에 tokens를 추가하는 과정에서 jax에는 pytorch의 resize_token_embeddings 같은 method가 구현되어있지 않은 것 같아 보여서 embedding size를 어떻게 바꿔주어야 할지 고민해보았다.

 

model의 embedding 파라미터에 random initalize한 추가해줘야하는 만큼의 shape을 만들어주어서 concat해주는 방식으로 shape을 맞춰주었다.

 

if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
            from_pt=True,  # pytorch 가중치 파일 사용
        )

        # add new tokens
        # fmt: off
        new_tokens = ['BTS', 'DM', 'RJ', 'RM', 'bts', 'rm', 'ㅂ6탄', 'ㅂㅁㄱ', 'ㅂㅋ', 'ㅇㄱㄹ', '강아띠', '개가티', '게티', '공계', '공굿', '구기', '구오즈', '국민', '국프', '그므시라꼬', '김남준', '김석진', '김태형', '깐진', '꾸꾸', '남듀나', '남듀니', '남주니', '남준', '남준이', '뉸기', '늉기', '단밤', '달방', '덕계못', '덕메', '덕밍아웃', '덕질', '덕통사고', '뎀', '디민', '디밍', '디엠', '딤읭이', '딤인', '랜포', '마플', '막콘', '망', '머글', '무나', '미늉기', '민', '민군', '민윤기', '민피디', '바이브', '바찌미', '박디민', '박라이', '박지미', '박지민', '박짐', '반', '방나잇', '방모닝', '병먹금', '병크', '보라해', '본보', '본보1', '본보2', '부계', '뷔', '브마', '비공굿', '비티엣스', '빅히트', '빛티에스', '뾰아리', '쀠', '사랑', '새꾸', '서수', '서폿', '석지니', '석진이', '석찌나', '성덕', '슈가', '슈짐', '슈키', '슉아', '스밍', '스얼라', '슥찌', '실검', '실트', '싸빙', '써방', '아미', '악개', '안방수니', '앙콘', '애깅', '양도계', '어그로', '어글호', '어덕행덕', '어태범', '연검', '연타니', '연탄', '옛둥이', '오프', '올출', '올콘', '와꾸', '우떠', '우래기', '위버스', '윤기', '융긔', '이선좌', '인별', '인죵', '일코', '입덕', '자마니', '잡덕', '전봉장', '전정국', '전졍국', '전증구기', '정구기', '정국', '정궁이', '정꾸', '정꾸기', '정호석', '제이홉', '졍큑', '조공', '존버', '주니', '줴멘', '지니', '지니애깅', '지미나', '지미니', '지민', '진', '진희', '짐니', '짐몽', '짐순이', '짐쨩', '짐프', '집콘', '짜마니', '쨔마니', '쨔만', '쮀멘', '쮸니', '찌미나', '찜니', '찜프', '차애', '총공', '총공계', '최애', '취케팅', '치미', '코야', '쿠키', '크트', '타래', '타타', '탄이', '탈덕', '탐라', '탑시드', '태깅', '태태', '태형이', '태효이', '택포', '텽이', '톡희', '틧', '티롱이', '티횽이', '평아', '평짐', '포도알', '포카', '플미', '피케팅', '핑몬', '하라메', '하이브', '햄찌', '햅뷔', '햅삐짐데이', '호바', '호비', '호서기', '호서긱', '호석이', '호시기', '홉', '휴덕', '정구', '방탄소년단', '방탄',  'love', 'yourself', '다이너마이트',]
        #fmt: off
        
        added_token_num = tokenizer.add_tokens(new_tokens)
        
        # change FlaxT5ForConditionalGeneration token_embedding resize
        model.config.vocab_size = len(tokenizer)

        initializer = jax.nn.initializers.kaiming_normal()
        # add new token embedding
        # new_token_embedding = jnp.zeros((added_token_num, model.config.d_model), dtype=jnp.float32) # (217, 768)
        new_token_embedding = initializer(rng, (added_token_num, model.config.d_model), dtype=jnp.float32)
        model.params["shared"]["embedding"] = jnp.concatenate([model.params["shared"]["embedding"], new_token_embedding], axis=0)

        # add lm_head kernel for new tokens
        # new_lm_head_kernel = jnp.zeros((model.config.d_model, added_token_num), dtype=jnp.float32) # (768, 217)
        new_lm_head_kernel = initializer(rng, (model.config.d_model, added_token_num), dtype=jnp.float32)  # (768, 217)
        model.params["lm_head"]["kernel"] = jnp.concatenate([model.params["lm_head"]["kernel"], new_lm_head_kernel], axis=1) # (768, 50575)

 

model의 params의 key값들을 아래에서 확인할 수 있었다.

print(model.params.keys()) # dict_keys(['shared', 'encoder', 'decoder', 'lm_head'])
print(model.params["shared"].keys()) # dict_keys(['embedding'])
print(model.params["encoder"].keys()) # dict_keys(['block', 'final_layer_norm'])
print(model.params["decoder"].keys()) # dict_keys(['block', 'final_layer_norm'])
print(model.params["lm_head"].keys()) # dict_keys(["kernel"])
728x90
Contents