"""BERT 字符级 BMES 分词训练。""" import csv, os, random, json, time import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup from torch.optim import AdamW TRAIN = "../train.csv" MODEL_NAME = "hfl/chinese-bert-wwm-ext" SAVE_DIR = "ckpt" MAX_LEN = 128 BATCH = 32 EPOCHS = 3 LR = 3e-5 SEED = 42 VAL_RATIO = 0.02 LABELS = ["B", "M", "E", "S"] L2I = {l: i for i, l in enumerate(LABELS)} PAD_ID = -100 random.seed(SEED); torch.manual_seed(SEED) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def tag_word(w): if len(w) == 1: return ["S"] return ["B"] + ["M"] * (len(w) - 2) + ["E"] def load_pairs(): pairs = [] with open(TRAIN, encoding="utf-8") as f: r = csv.reader(f); next(r) for row in r: if not row: continue words = [w for w in row[0].strip().split(" ") if w] if not words: continue tags, chars = [], [] for w in words: tags.extend(tag_word(w)); chars.extend(list(w)) pairs.append((chars, tags)) return pairs class SegDS(Dataset): def __init__(self, pairs, tok): self.pairs, self.tok = pairs, tok def __len__(self): return len(self.pairs) def __getitem__(self, i): chars, tags = self.pairs[i] chars = chars[:MAX_LEN - 2] tags = tags[:MAX_LEN - 2] enc = self.tok(chars, is_split_into_words=True, truncation=True, max_length=MAX_LEN, padding="max_length", return_tensors="pt") input_ids = enc["input_ids"].squeeze(0) attn = enc["attention_mask"].squeeze(0) word_ids = enc.word_ids(batch_index=0) labels = [] for wid in word_ids: if wid is None: labels.append(PAD_ID) else: labels.append(L2I[tags[wid]]) labels = torch.tensor(labels, dtype=torch.long) return input_ids, attn, labels class BertTagger(nn.Module): def __init__(self, name, n_labels): super().__init__() self.bert = BertModel.from_pretrained(name) self.drop = nn.Dropout(0.1) self.cls = nn.Linear(self.bert.config.hidden_size, n_labels) def forward(self, ids, attn): out = self.bert(ids, attention_mask=attn).last_hidden_state return self.cls(self.drop(out)) def seg_f1(gold_tags, pred_tags): def to_segs(tags): r = []; start = None for i, t in enumerate(tags): if t == "B": if start is not None: r.append((start, i - 1)) start = i elif t == "M": if start is None: start = i elif t == "E": if start is None: start = i r.append((start, i)); start = None elif t == "S": if start is not None: r.append((start, i - 1)) r.append((i, i)); start = None if start is not None: r.append((start, len(tags) - 1)) return set(r) g, p = to_segs(gold_tags), to_segs(pred_tags) tp = len(g & p) if tp == 0: return 0.0 P, R = tp / len(p), tp / len(g) return 2 * P * R / (P + R) def run_val(model, loader): model.eval() f1s = [] with torch.no_grad(): for ids, attn, labels in loader: ids, attn = ids.to(device), attn.to(device) logits = model(ids, attn) preds = logits.argmax(-1).cpu().numpy() labels = labels.numpy() for p, l in zip(preds, labels): gold = [LABELS[x] for x in l if x != PAD_ID] mask = l != PAD_ID pr = [LABELS[x] for x, m in zip(p, mask) if m] if gold: f1s.append(seg_f1(gold, pr)) return sum(f1s) / max(len(f1s), 1) def main(): os.makedirs(SAVE_DIR, exist_ok=True) tok = BertTokenizerFast.from_pretrained(MODEL_NAME) pairs = load_pairs() random.shuffle(pairs) n_val = int(len(pairs) * VAL_RATIO) val, train = pairs[:n_val], pairs[n_val:] print(f"train={len(train)} val={len(val)}") tr_loader = DataLoader(SegDS(train, tok), batch_size=BATCH, shuffle=True, num_workers=2, pin_memory=True) va_loader = DataLoader(SegDS(val, tok), batch_size=BATCH * 2, num_workers=2, pin_memory=True) model = BertTagger(MODEL_NAME, len(LABELS)).to(device) opt = AdamW(model.parameters(), lr=LR, weight_decay=0.01) total_steps = len(tr_loader) * EPOCHS sched = get_linear_schedule_with_warmup(opt, int(0.1 * total_steps), total_steps) loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID) scaler = torch.amp.GradScaler("cuda") best = 0.0 step = 0 for epoch in range(EPOCHS): model.train() t0 = time.time() running = 0.0 for ids, attn, labels in tr_loader: ids, attn, labels = ids.to(device), attn.to(device), labels.to(device) opt.zero_grad() with torch.amp.autocast("cuda", dtype=torch.float16): logits = model(ids, attn) loss = loss_fn(logits.view(-1, len(LABELS)), labels.view(-1)) scaler.scale(loss).backward() scaler.unscale_(opt) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) scaler.step(opt); scaler.update(); sched.step() running += loss.item(); step += 1 if step % 200 == 0: print(f" ep{epoch+1} step{step} loss={running/200:.4f} lr={sched.get_last_lr()[0]:.2e}") running = 0.0 f1 = run_val(model, va_loader) dt = time.time() - t0 print(f"[epoch {epoch+1}] val F1={f1:.4f} time={dt:.1f}s") if f1 > best: best = f1 torch.save(model.state_dict(), f"{SAVE_DIR}/best.pt") tok.save_pretrained(SAVE_DIR) with open(f"{SAVE_DIR}/labels.json", "w") as f: json.dump(LABELS, f) print(f" saved best -> {SAVE_DIR}/best.pt") print(f"best val F1 = {best:.4f}") if __name__ == "__main__": main()