

Commit

improve batch processing for inference
Ingvarstep committed Sep 6, 2024
1 parent 7495d0a commit 168910f
Showing 1 changed file with 48 additions and 51 deletions.
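
For orientation, a minimal usage sketch of the batched inference path this commit introduces (the checkpoint name and label set below are illustrative assumptions, not part of the change):

from gliner import GLiNER

# Illustrative checkpoint; any GLiNER checkpoint should behave the same way.
model = GLiNER.from_pretrained("urchade/gliner_small")

texts = ["Steve Jobs co-founded Apple in 1976."] * 32
labels = ["person", "organization", "date"]

# batch_size (new in this commit, default 8) caps how many texts go through
# the model per forward pass, so a long input list no longer has to fit in
# one giant batch in memory.
predictions = model.batch_predict_entities(texts, labels, batch_size=8)
print(predictions[0])  # predicted entities for the first text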
gliner/model.py: 48 additions & 51 deletions
@@ -145,7 +145,7 @@ def resize_token_embeddings(
 
         return model_embeds
 
-    def prepare_model_inputs(self, texts: List[str], labels: List[str], prepare_entities: bool = True):
+    def prepare_texts(self, texts: List[str]):
         """
         Prepare inputs for the model.
@@ -156,11 +156,6 @@ def prepare_model_inputs(self, texts: List[str], labels: List[str], prepare_entities: bool = True):
         all_tokens = []
         all_start_token_idx_to_text_idx = []
         all_end_token_idx_to_text_idx = []
-        # preserving the order of labels
-        labels = list(dict.fromkeys(labels))
-
-        class_to_ids = {k: v for v, k in enumerate(labels, start=1)}
-        id_to_classes = {k: v for v, k in class_to_ids.items()}
 
         for text in texts:
             tokens = []
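
For reference, the label bookkeeping deleted above moves into the collator used by batch prediction; its behavior, sketched with made-up labels:

labels = ["person", "org", "person"]
labels = list(dict.fromkeys(labels))  # dedupe, order preserved -> ["person", "org"]
class_to_ids = {k: v for v, k in enumerate(labels, start=1)}  # {"person": 1, "org": 2}
id_to_classes = {k: v for v, k in class_to_ids.items()}       # {1: "person", 2: "org"}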
@@ -175,33 +170,8 @@ def prepare_model_inputs(self, texts: List[str], labels: List[str], prepare_entities: bool = True):
             all_end_token_idx_to_text_idx.append(end_token_idx_to_text_idx)
 
         input_x = [{"tokenized_text": tk, "ner": None} for tk in all_tokens]
-        raw_batch = self.data_processor.collate_raw_batch(input_x, labels,
-                                                          class_to_ids = class_to_ids,
-                                                          id_to_classes = id_to_classes)
-        raw_batch["all_start_token_idx_to_text_idx"] = all_start_token_idx_to_text_idx
-        raw_batch["all_end_token_idx_to_text_idx"] = all_end_token_idx_to_text_idx
-
-        model_input = self.data_processor.collate_fn(raw_batch, prepare_labels=False,
-                                                     prepare_entities=prepare_entities)
-        model_input.update(
-            {
-                "span_idx": raw_batch["span_idx"] if "span_idx" in raw_batch else None,
-                "span_mask": raw_batch["span_mask"]
-                if "span_mask" in raw_batch
-                else None,
-                "text_lengths": raw_batch["seq_length"],
-            }
-        )
-
-        if not self.onnx_model:
-            device = self.device
-            for key in model_input:
-                if model_input[key] is not None and isinstance(
-                    model_input[key], torch.Tensor
-                ):
-                    model_input[key] = model_input[key].to(device)
-
-        return model_input, raw_batch
+        return input_x, all_start_token_idx_to_text_idx, all_end_token_idx_to_text_idx
 
     def predict_entities(
         self, text, labels, flat_ner=True, threshold=0.5, multi_label=False
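
After this refactor, prepare_texts only tokenizes and records character offsets; a sketch of what the three return values hold (variable names here are illustrative):

# One dict per input text, plus two parallel lists of offset maps.
input_x, starts, ends = model.prepare_texts(["Bill Gates founded Microsoft."])
# input_x[0] -> {"tokenized_text": ["Bill", "Gates", ...], "ner": None}
# starts[0][i] / ends[0][i] give the character span of token i in the raw
# string, which later maps token-level predictions back onto the text.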
@@ -229,7 +199,7 @@ def predict_entities(
 
     @torch.no_grad()
     def batch_predict_entities(
-        self, texts, labels, flat_ner=True, threshold=0.5, multi_label=False
+        self, texts, labels, flat_ner=True, threshold=0.5, multi_label=False, batch_size=8
     ):
         """
         Predict entities for a batch of texts.
@@ -244,29 +214,53 @@ def batch_predict_entities(
         Returns:
             The list of lists with predicted entities.
         """
+        self.eval()
+        # raw input preparation
+        input_x, all_start_token_idx_to_text_idx, all_end_token_idx_to_text_idx = self.prepare_texts(texts)
 
-        model_input, raw_batch = self.prepare_model_inputs(texts, labels)
+        labels = list(dict.fromkeys(labels))
+
+        collator = DataCollator(
+            self.config,
+            data_processor=self.data_processor,
+            return_tokens=True,
+            return_entities=True,
+            return_id_to_classes=True,
+            prepare_labels=False,
+            entity_types=labels,
+        )
+        data_loader = torch.utils.data.DataLoader(
+            input_x, batch_size=batch_size, shuffle=False, collate_fn=collator
+        )
 
-        model_output = self.model(**model_input)[0]
+        outputs = []
+        # Iterate over data batches
+        for batch in data_loader:
+            # Move the batch to the appropriate device
+            for key in batch:
+                if isinstance(batch[key], torch.Tensor):
+                    batch[key] = batch[key].to(self.device)
 
-        if not isinstance(model_output, torch.Tensor):
-            model_output = torch.from_numpy(model_output)
+            # Perform predictions
+            model_output = self.model(**batch)[0]
 
-        outputs = self.decoder.decode(
-            raw_batch["tokens"],
-            raw_batch["id_to_classes"],
-            model_output,
-            flat_ner=flat_ner,
-            threshold=threshold,
-            multi_label=multi_label,
-        )
+            if not isinstance(model_output, torch.Tensor):
+                model_output = torch.from_numpy(model_output)
+
+            decoded_outputs = self.decoder.decode(
+                batch["tokens"],
+                batch["id_to_classes"],
+                model_output,
+                flat_ner=flat_ner,
+                threshold=threshold,
+                multi_label=multi_label,
+            )
+            outputs.extend(decoded_outputs)
 
         all_entities = []
         for i, output in enumerate(outputs):
-            start_token_idx_to_text_idx = raw_batch["all_start_token_idx_to_text_idx"][
-                i
-            ]
-            end_token_idx_to_text_idx = raw_batch["all_end_token_idx_to_text_idx"][i]
+            start_token_idx_to_text_idx = all_start_token_idx_to_text_idx[i]
+            end_token_idx_to_text_idx = all_end_token_idx_to_text_idx[i]
             entities = []
             for start_token_idx, end_token_idx, ent_type, ent_score in output:
                 start_text_idx = start_token_idx_to_text_idx[start_token_idx]
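
The tail of this loop (collapsed below) converts token spans back to character spans via the offset maps recorded by prepare_texts; conceptually, a simplified sketch rather than the exact library code:

# Decoder output per text i is a list of (start_token_idx, end_token_idx,
# ent_type, ent_score) tuples; the offset maps recover the surface form.
start_text_idx = start_token_idx_to_text_idx[start_token_idx]
end_text_idx = end_token_idx_to_text_idx[end_token_idx]
span_text = texts[i][start_text_idx:end_text_idx]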
@@ -437,7 +431,6 @@ def evaluate(
             )
             all_preds.extend(decoded_outputs)
             all_trues.extend(batch["entities"])
-
         # Evaluate the predictions
         evaluator = Evaluator(all_trues, all_preds)
         out, f1 = evaluator.evaluate()
@@ -623,6 +616,7 @@ def _from_pretrained(
         _attn_implementation: Optional[str] = None,
         max_length: Optional[int] = None,
         max_width: Optional[int] = None,
+        post_fusion_schema: Optional[str] = None,
         **model_kwargs,
     ):
         """
Expand Down Expand Up @@ -682,7 +676,10 @@ def _from_pretrained(
config.max_len = max_length
if max_width is not None:
config.max_width = max_width

if post_fusion_schema is not None:
config.post_fusion_schema = post_fusion_schema
print('Post fusion is set.')

add_tokens = ["[FLERT]", config.ent_token, config.sep_token]

if not load_onnx_model:
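A hedged example of the new loading option (the schema value below is an assumed illustration; accepted values are whatever GLiNER's config defines for post_fusion_schema):

from gliner import GLiNER

# post_fusion_schema is forwarded verbatim into the model config.
model = GLiNER.from_pretrained(
    "urchade/gliner_small",
    post_fusion_schema="l2l-l2t-t2t",
)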
