nlp/exp1_fenci/bert/predict_bert.py
2026-04-29 18:34:27 +08:00

130 lines
4.0 KiB
Python

"""用训练好的 BERT-BMES 做 test 分词。"""
import csv, torch
import torch.nn as nn
from transformers import BertTokenizerFast, BertModel
# --- Configuration ---------------------------------------------------------
CKPT = "ckpt/best.pt"                    # fine-tuned state_dict loaded by main()
TOK_DIR = "ckpt"                         # directory the tokenizer is loaded from
MODEL_NAME = "hfl/chinese-bert-wwm-ext"  # backbone name given to BertModel.from_pretrained
TEST = "../test.csv"                     # input CSV: header row, then (ID, sentence) rows
OUT = "../submission_bert.csv"           # output CSV with columns ID, expected
# Characters per chunk. The tokenizer call allows MAX_LEN + 2 tokens,
# presumably leaving room for the [CLS]/[SEP] special tokens.
MAX_LEN = 510
BATCH = 32                               # chunks per forward pass
# Tag id -> tag: B(egin) / M(iddle) / E(nd) of a multi-char word, S(ingle)-char word.
LABELS = list("BMES")
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))
class BertTagger(nn.Module):
    """Per-token classifier: BERT encoder -> dropout(0.1) -> linear head.

    Args:
        name: pretrained model name/path passed to ``BertModel.from_pretrained``.
        n: number of output classes (one logit per tag).

    The submodule attribute names ``bert``, ``drop`` and ``cls`` are part of the
    checkpoint's state_dict keys and must not be renamed.
    """

    def __init__(self, name, n):
        super().__init__()
        self.bert = BertModel.from_pretrained(name)
        self.drop = nn.Dropout(0.1)
        width = self.bert.config.hidden_size
        self.cls = nn.Linear(width, n)

    def forward(self, ids, attn):
        """Return logits for every token position given ids and attention mask."""
        encoded = self.bert(ids, attention_mask=attn)
        features = self.drop(encoded.last_hidden_state)
        return self.cls(features)
def chunk_chars(chars, max_len):
    """Cut *chars* into consecutive slices of at most *max_len* items.

    Every slice except possibly the last has exactly *max_len* items; an empty
    input yields an empty list.
    """
    pieces = []
    for start in range(0, len(chars), max_len):
        pieces.append(chars[start:start + max_len])
    return pieces
def fix_tags(tags):
    """Repair a predicted tag sequence into a well-formed BMES sequence.

    "B" begins a multi-character word, "M" continues it, "E" ends it and "S"
    is a single-character word. Independent per-character predictions can
    break this grammar, so:

    * a word still open (last tag B/M) when a new word starts, or when the
      sequence ends, is closed in place: B -> S, M -> E;
    * a continuation tag with no open word becomes a start: M -> B, E -> S.

    Args:
        tags: sequence of tag strings.

    Returns:
        A new list of the same length; *tags* itself is not modified.
    """
    out = []
    open_word = False  # True while the last emitted tag is B or M
    for t in tags:
        if t in ("M", "E"):
            if not open_word:
                # Continuation without a start: M opens a word, E is a lone char.
                t = "B" if t == "M" else "S"
        elif open_word:
            # A new word starts before the previous one was closed.
            out[-1] = "S" if out[-1] == "B" else "E"
        out.append(t)
        open_word = t in ("B", "M")
    if open_word:
        # Sequence ended in the middle of a word.
        out[-1] = "S" if out[-1] == "B" else "E"
    return out
def tags_to_words(chars, tags):
    """Group characters into words according to their BMES tags.

    "B" flushes any pending word and starts a new one, "M" extends the pending
    word, "E" extends and closes it, and any other tag ("S") flushes the pending
    word and emits the character on its own. A trailing unfinished word is
    flushed at the end. Pairing stops at the shorter of *chars* / *tags*.
    """
    words = []
    pending = ""
    for ch, tag in zip(chars, tags):
        if tag == "M":
            pending += ch
            continue
        if tag == "E":
            words.append(pending + ch)
            pending = ""
            continue
        # "B" or "S": whatever is pending ends here.
        if pending:
            words.append(pending)
        if tag == "B":
            pending = ch
        else:
            pending = ""
            words.append(ch)
    if pending:
        words.append(pending)
    return words
def transfer(words):
    """Encode a segmentation as space-separated character-index lists.

    Each word becomes the bracketed, comma-separated (no spaces) list of the
    character offsets it covers in the concatenated sentence, e.g.
    ``["ab", "c"] -> "[0,1] [2]"``. An empty *words* list yields "".
    """
    encoded = []
    offset = 0
    for word in words:
        span = range(offset, offset + len(word))
        encoded.append("[" + ",".join(map(str, span)) + "]")
        offset += len(word)
    return " ".join(encoded)
def _read_rows(path):
    """Read ``(ID, sentence)`` pairs from a CSV file whose first row is a header.

    Rows with fewer than two columns are skipped and columns beyond the second
    are ignored. An empty file yields an empty list.
    """
    with open(path, encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        return [(r[0], r[1]) for r in reader if len(r) >= 2]


def _word_ids_to_tags(word_ids, pred_row, n_chars):
    """Map per-token predictions back to exactly one tag per input character.

    Each character takes the label predicted for its first sub-token. A
    character that produced no token at all gets "S". This covers characters
    the tokenizer drops (presumably whitespace such as U+3000 or stripped
    control characters) and characters cut off by truncation. Such gaps no
    longer shift or discard every later tag in the chunk.
    """
    tags = ["S"] * n_chars
    seen = set()
    for j, wid in enumerate(word_ids):
        if wid is None or wid in seen or wid >= n_chars:
            continue  # special/padding token, non-first sub-token, or out of range
        seen.add(wid)
        tags[wid] = LABELS[int(pred_row[j])]
    return tags


def _predict_chunks(model, tok, jobs):
    """Tag every ``(row_idx, chunk_idx, chars)`` job, BATCH jobs at a time.

    Returns ``{(row_idx, chunk_idx): tags}`` where ``len(tags) == len(chars)``.
    """
    all_tags = {}
    # fp16 autocast only when running on a GPU; on a CPU-only machine
    # autocast("cuda") would only emit a warning.
    use_amp = device.type == "cuda"
    with torch.no_grad():
        for b in range(0, len(jobs), BATCH):
            batch = jobs[b:b + BATCH]
            enc = tok([j[2] for j in batch], is_split_into_words=True,
                      truncation=True, max_length=MAX_LEN + 2,
                      padding=True, return_tensors="pt")
            ids = enc["input_ids"].to(device)
            attn = enc["attention_mask"].to(device)
            with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
                logits = model(ids, attn)
            preds = logits.argmax(-1).cpu().numpy()
            for k, (i, s, chars) in enumerate(batch):
                all_tags[(i, s)] = _word_ids_to_tags(
                    enc.word_ids(batch_index=k), preds[k], len(chars))
            if (b // BATCH) % 20 == 0:
                print(f" {b}/{len(jobs)}")
    return all_tags


def main():
    """Segment every sentence in TEST with the fine-tuned tagger and write OUT.

    Sentences are cut into chunks of at most MAX_LEN characters. The chunks
    are sorted by length so that padded batches hold similar-length inputs.
    Per-chunk tags are then stitched back together per sentence, repaired with
    fix_tags, grouped into words and written in the index-list format produced
    by transfer.
    """
    tok = BertTokenizerFast.from_pretrained(TOK_DIR)
    model = BertTagger(MODEL_NAME, len(LABELS)).to(device)
    model.load_state_dict(torch.load(CKPT, map_location=device))
    model.eval()

    rows = _read_rows(TEST)
    jobs = [(i, s, sub)
            for i, (_, sent) in enumerate(rows)
            for s, sub in enumerate(chunk_chars(list(sent), MAX_LEN))]
    jobs.sort(key=lambda x: len(x[2]))  # similar lengths -> less padding per batch
    all_tags = _predict_chunks(model, tok, jobs)

    with open(OUT, "w", encoding="utf-8", newline="") as fo:
        w = csv.writer(fo)
        w.writerow(["ID", "expected"])
        for i, (sid, sent) in enumerate(rows):
            chars = list(sent)
            full_tags = []
            s = 0
            while (i, s) in all_tags:  # chunks of row i, in original order
                full_tags.extend(all_tags[(i, s)])
                s += 1
            # Safety net: force exactly one tag per character.
            full_tags = (full_tags + ["S"] * len(chars))[:len(chars)]
            words = tags_to_words(chars, fix_tags(full_tags))
            w.writerow([sid, transfer(words)])
    print(f"wrote {OUT}")


if __name__ == "__main__":
    main()