금융 데이터 분석가/text-analysis

[NLP] MBart fine-tuning을 위한 데이터 생성

리치즈 2023. 10. 28. 07:29
728x90

huggingFace의 facebook/mbart-large-50-many-to-many-mmt 모델을 fine-tuning 하기 위한 데이터 생성하는 법을 소개합니다.

 

이 모델은 huggingface 모델 허브에서 가장 많은 다운로드를 받은 모델 중 하나입니다.

 

데이터 다운로드

필요한 데이터를 다운로드 준비합니다.

AI-Hub에서 한국어-영어 번역 말뭉치(사회과학) 데이터를 사용했습니다.

 

AI-Hub

※ 내국인만 데이터 신청이 가능합니다. 데이터 개요 데이터 변경이력 데이터 변경이력 버전 일자 변경내용 비고 1.3 2022-01-24 데이터 추가 개방 1.2 2021-08-02 데이터 추가 개방 1.1 2021-07-20 데이터

www.aihub.or.kr

다운로드 받으면 train 데이터와 validation 데이터가 각각 csv 형태로 있습니다.

  • 1113_social_train_set_1210529.csv
  • 1113_social_valid_set_151316.csv

 

데이터셋 불러오기

datasets 라이브러리로 데이터셋 형태로 불러옵니다.

from datasets import Dataset
train = Dataset.from_csv(f'{data_dir}/1113_social_train_set_1210529.csv')
valid = Dataset.from_csv(f'{data_dir}/1113_social_valid_set_151316.csv')

print(train, valid)

Dataset 형태로 잘 불러온 것을 확인할 수 있습니다.

 

translation 생성

번역 모델에 쉽게 넣기 위해서는 translation 피처를 생성해주는 것이 좋습니다.

 

예를 들면, 번역에 많이 사용되는 bible_para 데이터셋은 id와 translation 피처로 구성되어 있습니다.

huggingface dataset: bible_para

add_column 메서드를 사용해서 피처를 생성해줍니다.

# translation 컬럼 생성
train = train.add_column('translation', [{'en': x, 'ko': y} for x, y in zip(train['en'], train['ko'])])
valid = valid.add_column('translation', [{'en': x, 'ko': y} for x, y in zip(valid['en'], valid['ko'])])

# translation 컬럼만 남겨줌
train = train.select_columns(['translation'])
valid = valid.select_columns(['translation'])

print(train.feature)

 

features 메서드로 특성을 확인해보면 "en"과 "ko" 각각 Value로 되어 있는걸 볼 수 있습니다.

 

 

translation 형식으로 특성 변경

transformer에서는 translation feature를 제공합니다.

번역 전 언어코드와 문장, 번역 후 언어 코드와 문장이 딕셔너리 형태로 저장되어있습니다.

 

cast_column으로 형식을 변경해 줄 수 있습니다.

변경후에 "en"과 "ko" 각각이 아닌 Translation 형식이 된 것을 확인할 수 있습니다.

train = train.cast_column('translation', Translation(languages=['en', 'ko']))
valid = valid.cast_column('translation', Translation(languages=['en', 'ko']))

print(train.features)

 

 


전체 코드

from datasets import Dataset

train = Dataset.from_csv(f'{data_dir}/1113_social_train_set_1210529.csv')
valid = Dataset.from_csv(f'{data_dir}/1113_social_valid_set_151316.csv')

train = train.add_column('translation', [{'en': x, 'ko': y} for x, y in zip(train['en'], train['ko'])])
valid = valid.add_column('translation', [{'en': x, 'ko': y} for x, y in zip(valid['en'], valid['ko'])])

train = train.select_columns(['translation'])
valid = valid.select_columns(['translation'])

train = train.cast_column('translation', Translation(languages=['en', 'ko']))
valid = valid.cast_column('translation', Translation(languages=['en', 'ko']))

 

 

 

 

 

 

 

728x90
LIST