"""用训练好的 BERT-BMES 做 test 分词。""" import csv, torch import torch.nn as nn from transformers import BertTokenizerFast, BertModel CKPT = "ckpt/best.pt" TOK_DIR = "ckpt" MODEL_NAME = "hfl/chinese-bert-wwm-ext" TEST = "../test.csv" OUT = "../submission_bert.csv" MAX_LEN = 510 BATCH = 32 LABELS = ["B", "M", "E", "S"] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class BertTagger(nn.Module): def __init__(self, name, n): super().__init__() self.bert = BertModel.from_pretrained(name) self.drop = nn.Dropout(0.1) self.cls = nn.Linear(self.bert.config.hidden_size, n) def forward(self, ids, attn): return self.cls(self.drop(self.bert(ids, attention_mask=attn).last_hidden_state)) def chunk_chars(chars, max_len): return [chars[i:i + max_len] for i in range(0, len(chars), max_len)] def fix_tags(tags): out = []; prev = None for t in tags: if prev in ("B", "M") and t in ("B", "S"): if prev == "B": out[-1] = "S" else: out[-1] = "E" out.append(t); prev = t return out def tags_to_words(chars, tags): words, buf = [], "" for c, t in zip(chars, tags): if t == "B": if buf: words.append(buf) buf = c elif t == "M": buf += c elif t == "E": buf += c; words.append(buf); buf = "" else: if buf: words.append(buf); buf = "" words.append(c) if buf: words.append(buf) return words def transfer(words): cnt = 0; out = [] for w in words: idx = list(range(cnt, cnt + len(w))) out.append(str(idx).replace(" ", "")) cnt += len(w) return " ".join(out) def main(): tok = BertTokenizerFast.from_pretrained(TOK_DIR) model = BertTagger(MODEL_NAME, len(LABELS)).to(device) model.load_state_dict(torch.load(CKPT, map_location=device)) model.eval() with open(TEST, encoding="utf-8") as f: reader = csv.reader(f); next(reader) rows = [r for r in reader if len(r) >= 2] jobs = [] all_tags = {} for i, (_, sent) in enumerate(rows): chars = list(sent) for s, sub in enumerate(chunk_chars(chars, MAX_LEN)): jobs.append((i, s, sub)) jobs.sort(key=lambda x: len(x[2])) with torch.no_grad(): for b in range(0, len(jobs), BATCH): batch = jobs[b:b + BATCH] chars_list = [j[2] for j in batch] enc = tok(chars_list, is_split_into_words=True, truncation=True, max_length=MAX_LEN + 2, padding=True, return_tensors="pt") ids = enc["input_ids"].to(device) attn = enc["attention_mask"].to(device) with torch.amp.autocast("cuda", dtype=torch.float16): logits = model(ids, attn) preds = logits.argmax(-1).cpu().numpy() for k, (i, s, chars) in enumerate(batch): wids = enc.word_ids(batch_index=k) tags = [] for j, wid in enumerate(wids): if wid is None: continue if len(tags) == wid: tags.append(LABELS[preds[k][j]]) all_tags[(i, s)] = tags if (b // BATCH) % 20 == 0: print(f" {b}/{len(jobs)}") with open(OUT, "w", encoding="utf-8", newline="") as fo: w = csv.writer(fo) w.writerow(["ID", "expected"]) for i, (sid, sent) in enumerate(rows): chars = list(sent) full_tags = [] s = 0 while (i, s) in all_tags: full_tags.extend(all_tags[(i, s)]) s += 1 if len(full_tags) != len(chars): while len(full_tags) < len(chars): full_tags.append("S") full_tags = full_tags[:len(chars)] full_tags = fix_tags(full_tags) words = tags_to_words(chars, full_tags) w.writerow([sid, transfer(words)]) print(f"wrote {OUT}") if __name__ == "__main__": main()