Let's train embedding vectors using the gensim library.
Data load
from gensim.models import Word2Vec, KeyedVectors
from torchtext import datasets
train_iter = datasets.IMDB(split='train')
train_text = []
for label, text in train_iter:
    # lowercase, strip HTML line breaks, and split into whitespace tokens
    train_text.append(text.lower().replace('<br />', '').split())
The model was trained on the IMDB data from the torchtext datasets.
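For a quick sanity check of the prepared corpus, the number of reviews and the first few tokens can be printed (a minimal sketch; the exact count depends on the torchtext version):
print(len(train_text))     # number of reviews in the training split
print(train_text[0][:10])  # first ten whitespace tokens of the first review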
Train
- sg=1 trains a skip-gram model, sg=0 trains a CBOW model
model = Word2Vec(sentences=train_text, vector_size=10, window=5, min_count=1, workers=4, sg=1)
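After training, the vocabulary size and the shape of the vector matrix can be checked (a minimal sketch using the model variable above):
print(len(model.wv))           # vocabulary size
print(model.wv.vectors.shape)  # (vocab_size, 10), matching vector_size=10 above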
Save / Load
model.wv.save_word2vec_format('eng_w2v') # save the word vectors (KeyedVectors only)
word2vec_model = KeyedVectors.load_word2vec_format("eng_w2v") # load the word vectors
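save_word2vec_format stores only the word vectors. If further training is planned, the full Word2Vec model can be saved and re-loaded instead (a minimal sketch; the file name is arbitrary):
model.save('eng_w2v.model')             # full model, including training state
model = Word2Vec.load('eng_w2v.model')  # can continue training with model.train(...)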
Check the results
word2vec_model['hello']
'''
array([ 0.5766802 , -0.54731876,  0.87979925, -0.65630066, -0.26262784,
       -0.2619762 ,  0.9157015 ,  0.5002052 , -0.42661113, -0.10241051],
      dtype=float32)
'''
word2vec_model.most_similar('i')
'''
[('personally', 0.995750904083252),
('myself,', 0.9852892756462097),
('it:', 0.9845582246780396),
('honestly', 0.9828792810440063),
('amazed', 0.9797829985618591),
('reviewing', 0.9775915145874023),
('it;', 0.9750117659568787),
('it.i', 0.9744855761528015),
('personally,', 0.9740839600563049),
('"this', 0.9732935428619385)]
'''
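The cosine similarity between a pair of words can also be queried directly (a minimal sketch; the example words are illustrative and assumed to be in the vocabulary):
word2vec_model.similarity('good', 'great')  # cosine similarity between the two word vectors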
Using it as an embedding layer
- When using it in place of PyTorch nn.Embedding
import torch
from torch import nn
# frozen embedding layer initialized with the trained Word2Vec vectors
embedding_layer = nn.Embedding.from_pretrained(torch.from_numpy(word2vec_model.vectors), freeze=True)
input = torch.LongTensor([[word2vec_model.get_index('hello')]])  # vocabulary index of 'hello'
embedding_layer(input)
'''
tensor([[ 0.5767, -0.5473,  0.8798, -0.6563, -0.2626, -0.2620,  0.9157,  0.5002,
         -0.4266, -0.1024]])
'''
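In practice the layer receives a batch of token indices; a minimal sketch that mirrors the whitespace tokenization used above (the example words are assumed to be in the vocabulary):
tokens = 'this movie was great'.split()
ids = torch.LongTensor([[word2vec_model.get_index(t) for t in tokens]])
embedding_layer(ids).shape  # torch.Size([1, 4, 10])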
- When using it in place of PyTorch nn.EmbeddingBag
embedding_layer_2 = nn.EmbeddingBag.from_pretrained(torch.from_numpy(word2vec_model.vectors), freeze=True)
input = torch.LongTensor([[word2vec_model.get_index('hello')]])
embedding_layer_2(input)
'''
tensor([[ 0.5767, -0.5473,  0.8798, -0.6563, -0.2626, -0.2620,  0.9157,  0.5002,
         -0.4266, -0.1024]])
'''
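nn.EmbeddingBag reduces (by default, averages) the embeddings inside each bag, so with a single index per bag its output matches the plain nn.Embedding lookup above. A minimal sketch with two tokens in one bag (the example words are assumed to be in the vocabulary):
ids = torch.LongTensor([[word2vec_model.get_index('hello'), word2vec_model.get_index('movie')]])
embedding_layer_2(ids)  # mean of the two word vectors, shape [1, 10]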