fix issues with missed tokens
Ingvarstep committed Sep 1, 2024
1 parent fac1852 commit 569632a
Showing 3 changed files with 19 additions and 6 deletions.
7 changes: 6 additions & 1 deletion gliner/model.py
@@ -725,5 +725,10 @@ def _from_pretrained(
            config.class_token_index == -1 or config.vocab_size == -1
        ) and resize_token_embeddings:
            gliner.data_processor.transformer_tokenizer.add_tokens(add_tokens)

        if len(tokenizer)!=len(gliner.token_rep_layer.get_input_embeddings()):
            new_num_tokens = len(model.data_processor.transformer_tokenizer)
            model_embeds = gliner.model.token_rep_layer.resize_token_embeddings(
                new_num_tokens, None
            )
        return gliner
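
The added block keeps the tokenizer and the encoder's embedding matrix in sync after extra tokens are introduced. A minimal standalone sketch of the same idea against a plain Hugging Face tokenizer and model (the checkpoint name and the <<ENT>>/<<SEP>> markers below are illustrative, not taken from this commit):

    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
    model = AutoModel.from_pretrained("microsoft/deberta-v3-small")

    # Register extra prompt markers, the way GLiNER-style models add
    # entity/separator tokens after pretraining.
    tokenizer.add_tokens(["<<ENT>>", "<<SEP>>"])

    # If the tokenizer now has more entries than the embedding matrix has rows,
    # grow the embeddings so the new token ids can actually be looked up.
    if len(tokenizer) != model.get_input_embeddings().weight.shape[0]:
        model.resize_token_embeddings(len(tokenizer))
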
15 changes: 10 additions & 5 deletions gliner/modeling/base.py
@@ -44,7 +44,7 @@ def extract_word_embeddings(token_embeds, words_mask, attention_mask,


def extract_prompt_features_and_word_embeddings(config, token_embeds, input_ids, attention_mask,
                                                text_lengths, words_mask, **kwargs):
                                                text_lengths, words_mask, embed_ent_token = True, **kwargs):
    # getting prompt embeddings
    batch_size, sequence_length, embed_dim = token_embeds.shape
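
The new embed_ent_token flag defaults to True and is not consumed inside this hunk. As a hedged illustration of the kind of switch such a flag usually provides (a guess written as standalone code, not this function's actual body), it could choose between using the entity-marker token itself or the token right after it as the prompt representation:

    import torch

    def select_prompt_vectors(token_embeds, ent_token_mask, embed_ent_token=True):
        # token_embeds: (batch, seq_len, hidden); ent_token_mask: (batch, seq_len) bool,
        # True at positions holding the entity-marker token.
        positions = ent_token_mask.nonzero(as_tuple=False)      # (num_prompts, 2)
        if not embed_ent_token:
            positions = positions.clone()
            positions[:, 1] += 1                                 # token right after the marker
        return token_embeds[positions[:, 0], positions[:, 1]]   # (num_prompts, hidden)
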

@@ -89,10 +89,15 @@ def __init__(self, config, from_pretrained = False):
        self.rnn = LstmSeq2SeqEncoder(config)

        if config.post_fusion_schema:
            self.config.num_post_fusion_layers = 3
            print('Initializing cross fuser...')
            print('Post fusion layer:', config.post_fusion_schema)
            print('Number of post fusion layers:', config.num_post_fusion_layers)

            self.cross_fuser = CrossFuser(self.config.hidden_size,
                                          self.config.hidden_size,
                                          num_heads=self.token_rep_layer.bert_layer.model.config.num_attention_heads,
                                          num_layers=1,
                                          num_layers=self.config.num_post_fusion_layers,
                                          dropout=config.dropout,
                                          schema=config.post_fusion_schema)
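
For context on what num_post_fusion_layers controls: the cross fuser stacks several cross-attention blocks so word and prompt representations can exchange information more than once. A rough self-contained sketch of such a stack, assuming torch.nn.MultiheadAttention as the building block (this is not the repository's CrossFuser, whose schema handling is richer):

    import torch
    import torch.nn as nn

    class SimpleCrossFuser(nn.Module):
        """Stack of cross-attention blocks: queries attend to keys, num_layers times."""

        def __init__(self, hidden_size, num_heads, num_layers=3, dropout=0.1):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.MultiheadAttention(hidden_size, num_heads,
                                      dropout=dropout, batch_first=True)
                for _ in range(num_layers)
            )

        def forward(self, queries, keys):
            out = queries
            for attn in self.layers:
                out, _ = attn(out, keys, keys)   # e.g. words attending to entity prompts
            return out

    # Example shapes: 2 sentences, 20 word positions, 5 prompt positions, hidden size 768.
    fuser = SimpleCrossFuser(hidden_size=768, num_heads=12, num_layers=3)
    fused = fuser(torch.randn(2, 20, 768), torch.randn(2, 5, 768))
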

@@ -107,7 +112,8 @@ def _extract_prompt_features_and_word_embeddings(self, token_embeds, input_ids,
            input_ids,
            attention_mask,
            text_lengths,
            words_mask)
            words_mask,
            self.config.embed_ent_token)
        return prompts_embedding, prompts_embedding_mask, words_embedding, mask

def get_uni_representations(self,
@@ -194,8 +200,7 @@ def _loss(self, logits: torch.Tensor, labels: torch.Tensor,
    @abstractmethod
    def loss(self, x):
        pass


class SpanModel(BaseModel):
    def __init__(self, config, encoder_from_pretrained):
        super(SpanModel, self).__init__(config, encoder_from_pretrained)
3 changes: 3 additions & 0 deletions gliner/modeling/encoder.py
@@ -123,6 +123,9 @@ def resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        return self.bert_layer.model.resize_token_embeddings(new_num_tokens,
                                                             pad_to_multiple_of)

    def get_input_embeddings(self):
        return self.bert_layer.model.get_input_embeddings()

    def encode_text(self, input_ids, attention_mask, *args, **kwargs):
        token_embeddings = self.bert_layer(input_ids, attention_mask, *args, **kwargs)
        if hasattr(self, "projection"):
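
The new get_input_embeddings wrapper simply forwards to the underlying transformer, so callers can inspect the embedding table without reaching through bert_layer.model directly. A hedged usage sketch, assuming a loaded GLiNER-style object named gliner with the attribute paths that appear in the diffs above:

    # The wrapper returns a torch.nn.Embedding from the wrapped transformer.
    embedding_layer = gliner.model.token_rep_layer.get_input_embeddings()

    # Grow the table if the tokenizer has picked up extra tokens since the
    # embeddings were saved (mirrors the check added in gliner/model.py).
    tokenizer = gliner.data_processor.transformer_tokenizer
    if len(tokenizer) != embedding_layer.num_embeddings:
        gliner.model.token_rep_layer.resize_token_embeddings(len(tokenizer))
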
