add support for decoder models via the LLM2Vec package
Ingvarstep committed Aug 9, 2024
1 parent b2e8281 commit e291361
Showing 3 changed files with 94 additions and 17 deletions.
67 changes: 57 additions & 10 deletions gliner/modeling/encoder.py
@@ -1,23 +1,70 @@
import warnings
from pathlib import Path

import torch
from torch import nn
from transformers import AutoModel, AutoConfig

# just wrapping, to allow loading previously created models
from ..utils import is_module_available, MissedPackageException

IS_LLM2VEC = is_module_available('llm2vec')
IS_PEFT = is_module_available('peft')


if IS_LLM2VEC:
from llm2vec.models import MistralBiModel, LlamaBiModel, GemmaBiModel, Qwen2BiModel
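    # Map HF config class names to llm2vec's bidirectional ("Bi") variants, which
    # lift the causal attention mask so decoder-only LMs can attend in both directions.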
DECODER_MODEL_MAPPING = {
"MistralConfig": MistralBiModel,
"LlamaConfig": LlamaBiModel,
"GemmaConfig": GemmaBiModel,
"Qwen2Config": Qwen2BiModel
}
else:
DECODER_MODEL_MAPPING = {}

if IS_PEFT:
from peft import LoraConfig, get_peft_model

class Transformer(nn.Module):
def __init__(self, config, from_pretrained):
super().__init__()
if config.encoder_config is not None:
encoder_config = config.encoder_config
else:
            encoder_config = AutoConfig.from_pretrained(config.model_name)
            if config.vocab_size != -1:
                encoder_config.vocab_size = config.vocab_size

config_name = encoder_config.__class__.__name__
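        # e.g. "LlamaConfig" — the config class name decides below whether the
        # llm2vec decoder path or the standard AutoModel encoder path is taken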

if config_name in DECODER_MODEL_MAPPING:
if not IS_LLM2VEC:
raise MissedPackageException(f"The llm2vec package must be installed to use this decoder model: {config_name}")
else:
print('Loading decoder model using LLM2Vec...')
ModelClass = DECODER_MODEL_MAPPING[config_name]
decoder = True
else:
decoder = False
ModelClass = AutoModel

        if from_pretrained:
            # auth tokens should come from the environment (e.g. HF_TOKEN), never be hardcoded
            self.model = ModelClass.from_pretrained(config.model_name, trust_remote_code=True)
        else:
            if not decoder:
                self.model = ModelClass.from_config(encoder_config, trust_remote_code=True)
            else:
                self.model = ModelClass(encoder_config)

adapter_config_file = Path(config.model_name) / "adapter_config.json"
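        # if the checkpoint directory ships a LoRA adapter config, wrap the base model with it via PEFT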

if adapter_config_file.exists():
if not IS_PEFT:
                warnings.warn("Adapter configs were detected; to apply them, install the peft package.")
else:
adapter_config = LoraConfig.from_pretrained(config.model_name)
self.model = get_peft_model(self.model, adapter_config)

def forward(self, *args, **kwargs):
output = self.model(*args, **kwargs)
return output[0]
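For illustration, a minimal sketch of how this wrapper behaves for a decoder backbone. The SimpleNamespace stands in for the real GLiNER config object, the checkpoint id is a placeholder, and only the attributes read by Transformer.__init__ above are populated:

from types import SimpleNamespace

from gliner.modeling.encoder import Transformer

# Hypothetical minimal config: only the fields the wrapper reads are set.
config = SimpleNamespace(
    model_name="path/to/llama-checkpoint",  # placeholder Llama-family checkpoint
    encoder_config=None,                    # resolved via AutoConfig.from_pretrained
    vocab_size=-1,                          # -1 keeps the checkpoint's own vocab size
)

# "LlamaConfig" is in DECODER_MODEL_MAPPING, so with llm2vec installed this
# instantiates LlamaBiModel (bidirectional attention) instead of a causal AutoModel.
encoder = Transformer(config, from_pretrained=True)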
22 changes: 21 additions & 1 deletion gliner/utils.py
@@ -4,4 +4,24 @@
def load_config_as_namespace(config_file):
with open(config_file, "r") as f:
config_dict = yaml.safe_load(f)
    return argparse.Namespace(**config_dict)

def is_module_available(module_name):
"""
Checks whether the specified Python module is available.
Args:
module_name (str): The name of the module to check.
Returns:
bool: True if the module is available, False otherwise.
"""
try:
__import__(module_name)
return True
except ImportError:
return False

class MissedPackageException(Exception):
    """Raised when a package required for the requested model is not installed."""
    pass
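A short usage sketch for the new helper and exception (the probed package name is just an example):

from gliner.utils import is_module_available, MissedPackageException

# Probe an optional dependency before importing any of its symbols.
if is_module_available("llm2vec"):
    from llm2vec.models import LlamaBiModel  # safe to import now
else:
    raise MissedPackageException("Install llm2vec to use decoder backbones.")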
22 changes: 16 additions & 6 deletions train.py
@@ -9,7 +9,7 @@

from gliner import GLiNERConfig, GLiNER
from gliner.training import Trainer, TrainingArguments
from gliner.data_processing.collator import DataCollatorWithPadding, DataCollator
from gliner.utils import load_config_as_namespace
from gliner.data_processing import WordsSplitter, GLiNERDataset

@@ -19,6 +19,8 @@
    parser.add_argument('--config', type=str, default="config.yaml")
    parser.add_argument('--log_dir', type=str, default='models/')
    parser.add_argument('--compile_model', type=bool, default=False)
    parser.add_argument('--freeze_language_model', type=bool, default=False)
    parser.add_argument('--new_data_schema', type=bool, default=False)
args = parser.parse_args()

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
@@ -49,11 +51,6 @@

words_splitter = WordsSplitter(model_config.words_splitter_type)

model = GLiNER(model_config, tokenizer=tokenizer, words_splitter=words_splitter)
model.resize_token_embeddings([model_config.ent_token, model_config.sep_token],
set_class_token_index = False,
@@ -64,6 +61,19 @@
model.to(device)
model.compile_for_training()

if args.freeze_language_model:
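        # freeze the backbone so only GLiNER's task-specific layers are trained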
model.model.token_rep_layer.bert_layer.model.requires_grad_(False)
else:
model.model.token_rep_layer.bert_layer.model.requires_grad_(True)

if args.new_data_schema:
train_dataset = GLiNERDataset(train_data, model_config, tokenizer, words_splitter)
test_dataset = GLiNERDataset(test_data, model_config, tokenizer, words_splitter)
data_collator = DataCollatorWithPadding(model_config)
else:
train_dataset = train_data
test_dataset = test_data
data_collator = DataCollator(model.config, data_processor=model.data_processor, prepare_labels=True)

training_args = TrainingArguments(
output_dir=config.log_dir,
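One caveat for the new flags: argparse's type=bool applies bool() to the raw string, and any non-empty string (including "False") is truthy, so these flags effectively act as enable-only switches. A typical invocation might look like:

python train.py --config config.yaml --log_dir models/ --freeze_language_model True --new_data_schema True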
