130 lines
4.0 KiB
Python
130 lines
4.0 KiB
Python
"""Run the trained BERT-BMES tagger over the test set to produce a word segmentation."""
|
|
import csv, torch
|
|
import torch.nn as nn
|
|
from transformers import BertTokenizerFast, BertModel
|
|
|
|
# --- paths ------------------------------------------------------------------
CKPT = "ckpt/best.pt"                    # fine-tuned tagger weights (loaded via load_state_dict)
TOK_DIR = "ckpt"                         # directory the tokenizer is loaded from
MODEL_NAME = "hfl/chinese-bert-wwm-ext"  # pretrained encoder the tagger is built on
TEST = "../test.csv"                     # input CSV: id, sentence (header row skipped)
OUT = "../submission_bert.csv"           # output CSV: ID, expected

# --- inference settings -------------------------------------------------------
MAX_LEN = 510   # characters per chunk; tokenizer max_length is MAX_LEN + 2 (special tokens)
BATCH = 32      # chunks per forward pass
LABELS = list("BMES")  # class index -> tag; presumably must match the training label order

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
class BertTagger(nn.Module):
    """Linear token-classification head on top of a pretrained BERT encoder.

    Every input token position receives ``n`` logits (one per label).
    """

    def __init__(self, name, n):
        super().__init__()
        # NOTE: ``bert``, ``drop`` and ``cls`` are the prefixes of the saved
        # state_dict keys -- renaming them breaks checkpoint loading.
        self.bert = BertModel.from_pretrained(name)
        self.drop = nn.Dropout(0.1)
        hidden = self.bert.config.hidden_size
        self.cls = nn.Linear(hidden, n)

    def forward(self, ids, attn):
        """Return per-token logits for token ids ``ids`` under mask ``attn``."""
        encoded = self.bert(ids, attention_mask=attn)
        features = self.drop(encoded.last_hidden_state)
        return self.cls(features)
|
|
|
|
|
|
def chunk_chars(chars, max_len):
    """Split ``chars`` into consecutive slices of at most ``max_len`` items.

    The final slice may be shorter; an empty input yields an empty list.
    """
    starts = range(0, len(chars), max_len)
    return [chars[lo:lo + max_len] for lo in starts]
|
|
|
|
|
|
def fix_tags(tags):
|
|
out = []; prev = None
|
|
for t in tags:
|
|
if prev in ("B", "M") and t in ("B", "S"):
|
|
if prev == "B": out[-1] = "S"
|
|
else: out[-1] = "E"
|
|
out.append(t); prev = t
|
|
return out
|
|
|
|
|
|
def tags_to_words(chars, tags):
    """Group characters into words according to their BMES tags.

    ``M`` extends the word in progress and ``E`` closes it. ``B`` flushes any
    unfinished word and starts a new one. ``S`` (or any other tag) flushes
    any unfinished word and emits the character as a word of its own. A word
    still open at the end is emitted as-is. Pairs beyond the shorter of
    ``chars``/``tags`` are ignored.
    """
    words = []
    current = ""
    for ch, tag in zip(chars, tags):
        if tag == "M":
            current += ch
            continue
        if tag == "E":
            words.append(current + ch)
            current = ""
            continue
        # "B", "S" and any unknown tag all terminate the word in progress.
        if current:
            words.append(current)
            current = ""
        if tag == "B":
            current = ch
        else:
            words.append(ch)
    if current:
        words.append(current)
    return words
|
|
|
|
|
|
def transfer(words):
    """Encode a segmentation as space-separated lists of character offsets.

    Each word becomes ``[i,j,...]`` (no spaces) listing the global positions
    of its characters, e.g. ``["ab", "c"]`` -> ``"[0,1] [2]"``.
    """
    pieces = []
    start = 0
    for word in words:
        end = start + len(word)
        offsets = ",".join(str(pos) for pos in range(start, end))
        pieces.append("[" + offsets + "]")
        start = end
    return " ".join(pieces)
|
|
|
|
|
|
def _load_model():
    """Load the tokenizer and the fine-tuned tagger (in eval mode)."""
    tok = BertTokenizerFast.from_pretrained(TOK_DIR)
    model = BertTagger(MODEL_NAME, len(LABELS)).to(device)
    model.load_state_dict(torch.load(CKPT, map_location=device))
    model.eval()
    return tok, model


def _read_rows(path):
    """Return the data rows of the CSV at ``path`` (header skipped).

    Rows with fewer than two fields are dropped. Only the first two fields
    (id, sentence) are used downstream, so extra columns are tolerated.
    """
    with open(path, encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header; tolerate an empty file
        return [r for r in reader if len(r) >= 2]


def _tags_for_chars(word_ids, pred_row, n_chars):
    """Map token-level predictions back to exactly one tag per character.

    Each character takes the label of its first sub-token. Some characters
    produce no token at all (e.g. whitespace/control chars, which BERT's
    pre-tokenizer/normalizer typically drops), and truncation can cut others
    off. Both cases default to "S", so later characters stay aligned.
    """
    tags = [None] * n_chars
    for j, wid in enumerate(word_ids):
        if wid is not None and tags[wid] is None:
            tags[wid] = LABELS[pred_row[j]]
    return [t if t is not None else "S" for t in tags]


def _predict_chunk_tags(tok, model, rows):
    """Tag every MAX_LEN-character chunk of every sentence.

    Returns a dict mapping (row index, chunk index) -> list of tags, with
    one tag per character of that chunk.
    """
    jobs = []
    for i, row in enumerate(rows):
        for s, sub in enumerate(chunk_chars(list(row[1]), MAX_LEN)):
            jobs.append((i, s, sub))
    # Sorting by length groups similar-length chunks, which reduces padding.
    jobs.sort(key=lambda job: len(job[2]))

    all_tags = {}
    # fp16 autocast is only meaningful on CUDA; on CPU run in full precision.
    use_amp = device.type == "cuda"
    with torch.no_grad():
        for b in range(0, len(jobs), BATCH):
            batch = jobs[b:b + BATCH]
            enc = tok([job[2] for job in batch], is_split_into_words=True,
                      truncation=True, max_length=MAX_LEN + 2,
                      padding=True, return_tensors="pt")
            ids = enc["input_ids"].to(device)
            attn = enc["attention_mask"].to(device)
            with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
                logits = model(ids, attn)
            preds = logits.argmax(-1).cpu().numpy()
            for k, (i, s, chars) in enumerate(batch):
                all_tags[(i, s)] = _tags_for_chars(
                    enc.word_ids(batch_index=k), preds[k], len(chars))
            if (b // BATCH) % 20 == 0:
                print(f" {b}/{len(jobs)}")
    return all_tags


def _write_submission(path, rows, all_tags):
    """Write the ``ID,expected`` header plus one segmented line per row."""
    with open(path, "w", encoding="utf-8", newline="") as fo:
        w = csv.writer(fo)
        w.writerow(["ID", "expected"])
        for i, row in enumerate(rows):
            sid, sent = row[0], row[1]
            chars = list(sent)
            full_tags = []
            s = 0
            while (i, s) in all_tags:
                full_tags.extend(all_tags[(i, s)])
                s += 1
            # Defensive: force exactly one tag per character (pad with "S").
            full_tags = (full_tags + ["S"] * len(chars))[:len(chars)]
            words = tags_to_words(chars, fix_tags(full_tags))
            w.writerow([sid, transfer(words)])


def main():
    """Segment every sentence in TEST and write the submission CSV to OUT."""
    tok, model = _load_model()
    rows = _read_rows(TEST)
    all_tags = _predict_chunk_tags(tok, model, rows)
    _write_submission(OUT, rows, all_tags)
    print(f"wrote {OUT}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run inference only when executed as a script, not on import.
    main()
|