nlp/exp1_fenci/bert/train_bert.py
2026-04-29 18:34:27 +08:00

178 lines
6.0 KiB
Python

"""BERT 字符级 BMES 分词训练。"""
import csv, os, random, json, time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
# --- Paths and pretrained checkpoint ---------------------------------------
TRAIN = "../train.csv"                   # training CSV (relative to the working directory)
MODEL_NAME = "hfl/chinese-bert-wwm-ext"  # identifier handed to from_pretrained()
SAVE_DIR = "ckpt"                        # directory that receives checkpoints

# --- Hyper-parameters -------------------------------------------------------
MAX_LEN = 128     # tokenizer max_length (special tokens included)
BATCH = 32        # training batch size; validation uses twice this
EPOCHS = 3
LR = 3e-5
SEED = 42
VAL_RATIO = 0.02  # fraction of the shuffled pairs held out for validation

# --- Label space ------------------------------------------------------------
LABELS = ["B", "M", "E", "S"]
L2I = dict(zip(LABELS, range(len(LABELS))))  # tag -> class index
# Label assigned to positions that carry no tag; also passed to
# CrossEntropyLoss as ignore_index so those positions do not affect the loss.
PAD_ID = -100

# --- Reproducibility and device selection -----------------------------------
random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def tag_word(w):
    """Return the BMES tags for one word, one tag per character.

    A single-character word is tagged ``S``.  A longer word is tagged ``B``
    for its first character, ``M`` for every interior character and ``E``
    for its last character.  An empty word yields an empty list, so the
    number of returned tags always equals ``len(w)``.

    Args:
        w: The word (any sized sequence of characters).

    Returns:
        A list of tag strings drawn from ``"B"``, ``"M"``, ``"E"``, ``"S"``.
    """
    n = len(w)
    if n == 0:
        # Without this guard the general formula would emit ["B", "E"] for
        # zero characters and break the tag/character alignment.
        return []
    if n == 1:
        return ["S"]
    return ["B"] + ["M"] * (n - 2) + ["E"]
def load_pairs(path=None):
    """Load the segmented training corpus as aligned (chars, tags) pairs.

    The first row of the CSV is treated as a header and skipped.  For every
    following row, the first column is read as a sentence whose words are
    separated by single spaces; each word is expanded into its characters and
    the matching BMES tags from ``tag_word``.  Empty rows and rows without
    any word are skipped.

    Args:
        path: CSV file to read.  Defaults to the module-level ``TRAIN`` path.

    Returns:
        A list of ``(chars, tags)`` tuples, where ``chars`` and ``tags`` are
        lists of equal length.  An empty file yields an empty list.
    """
    if path is None:
        path = TRAIN
    pairs = []
    # newline="" is what the csv module expects: it lets the reader handle
    # line endings itself, including newlines embedded in quoted fields.
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        for row in reader:
            if not row:
                continue
            words = [w for w in row[0].strip().split(" ") if w]
            if not words:
                continue
            chars, tags = [], []
            for w in words:
                chars.extend(w)
                tags.extend(tag_word(w))
            pairs.append((chars, tags))
    return pairs
class SegDS(Dataset):
    """Dataset that encodes (chars, tags) pairs for token classification.

    Each item is a tuple ``(input_ids, attention_mask, labels)``.  The
    tokenizer is called with ``padding="max_length"`` and
    ``max_length=MAX_LEN``, and the batch dimension is squeezed away, so the
    tensors are expected to be 1-D with ``MAX_LEN`` entries.  ``labels`` holds
    the class index of each character's tag on the character's first
    sub-token and ``PAD_ID`` everywhere else.

    Args:
        pairs: Sequence of ``(chars, tags)`` tuples of equal-length lists.
        tok: A fast tokenizer whose encoding exposes ``word_ids()``.
    """

    def __init__(self, pairs, tok):
        self.pairs, self.tok = pairs, tok

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        chars, tags = self.pairs[i]
        # Leave room for the two special tokens the tokenizer adds.
        chars = chars[:MAX_LEN - 2]
        tags = tags[:MAX_LEN - 2]
        enc = self.tok(
            chars,
            is_split_into_words=True,
            truncation=True,
            max_length=MAX_LEN,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attn = enc["attention_mask"].squeeze(0)

        labels = []
        prev_wid = None
        for wid in enc.word_ids(batch_index=0):
            if wid is None or wid == prev_wid:
                # Special/padding positions have no character.  Continuation
                # sub-tokens of a character that was split into several
                # pieces are masked too: repeating the tag would insert
                # duplicate B/S tags and distort both the loss and the
                # segment-level F1.
                labels.append(PAD_ID)
            else:
                labels.append(L2I[tags[wid]])
            prev_wid = wid
        return input_ids, attn, torch.tensor(labels, dtype=torch.long)
class BertTagger(nn.Module):
    """Pretrained BERT encoder with dropout and a per-token linear classifier.

    Args:
        name: Checkpoint identifier passed to ``BertModel.from_pretrained``.
        n_labels: Number of output classes per token.
    """

    def __init__(self, name, n_labels):
        super().__init__()
        encoder = BertModel.from_pretrained(name)
        width = encoder.config.hidden_size
        # Attribute names are part of the saved state_dict keys; keep them.
        self.bert = encoder
        self.drop = nn.Dropout(p=0.1)
        self.cls = nn.Linear(width, n_labels)

    def forward(self, ids, attn):
        """Return per-token logits for input ids ``ids`` and attention mask ``attn``."""
        encoded = self.bert(ids, attention_mask=attn)
        hidden = self.drop(encoded.last_hidden_state)
        return self.cls(hidden)
def seg_f1(gold_tags, pred_tags):
    """Return the segment-level F1 between two BMES tag sequences.

    Both sequences are decoded into sets of inclusive ``(start, end)`` spans.
    Decoding is lenient towards malformed sequences: a ``B`` or ``S`` closes
    any span that is still open, an ``M`` or ``E`` without a preceding ``B``
    opens a span at its own position, and a span still open at the end of the
    sequence is closed at the last index.  Tags other than B/M/E/S are
    ignored.

    Returns:
        The harmonic mean of span precision and recall, or ``0.0`` when no
        span matches exactly.
    """

    def spans(seq):
        found = set()
        open_at = None
        for pos, tag in enumerate(seq):
            if tag in ("B", "S"):
                # Both tags terminate a span that was left open.
                if open_at is not None:
                    found.add((open_at, pos - 1))
                if tag == "B":
                    open_at = pos
                else:
                    found.add((pos, pos))
                    open_at = None
            elif tag in ("M", "E"):
                # Tolerate a missing "B" by starting the span here.
                if open_at is None:
                    open_at = pos
                if tag == "E":
                    found.add((open_at, pos))
                    open_at = None
        if open_at is not None:
            found.add((open_at, len(seq) - 1))
        return found

    gold_spans = spans(gold_tags)
    pred_spans = spans(pred_tags)
    hits = len(gold_spans & pred_spans)
    if not hits:
        return 0.0
    precision = hits / len(pred_spans)
    recall = hits / len(gold_spans)
    return 2 * precision * recall / (precision + recall)
def run_val(model, loader):
    """Return the mean per-sentence segment F1 of ``model`` over ``loader``.

    The model is switched to eval mode and run without gradient tracking.
    For every sentence, only positions whose gold label differs from
    ``PAD_ID`` are compared; sentences with no such position are skipped.
    Returns ``0.0`` when no sentence was scored.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for ids, attn, labels in loader:
            logits = model(ids.to(device), attn.to(device))
            batch_pred = logits.argmax(-1).cpu().numpy()
            batch_gold = labels.numpy()
            for pred_row, gold_row in zip(batch_pred, batch_gold):
                keep = gold_row != PAD_ID  # boolean mask of labelled positions
                if not keep.any():
                    continue
                gold = [LABELS[k] for k in gold_row[keep]]
                pred = [LABELS[k] for k in pred_row[keep]]
                scores.append(seg_f1(gold, pred))
    total = sum(scores)
    return total / max(len(scores), 1)
def main():
    """Fine-tune the BERT tagger and keep the best-scoring checkpoint.

    Loads and shuffles the (chars, tags) pairs, holds out the first
    ``VAL_RATIO`` fraction for validation, then trains for ``EPOCHS`` epochs
    with AdamW and a linear warmup/decay schedule, evaluating with
    ``run_val`` after every epoch.  Whenever validation F1 improves (and
    always after the first epoch), the model weights, tokenizer files and
    label list are written to ``SAVE_DIR``.

    Mixed precision (fp16 autocast plus gradient scaling) is enabled only
    when the selected device is CUDA.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    tok = BertTokenizerFast.from_pretrained(MODEL_NAME)

    pairs = load_pairs()
    random.shuffle(pairs)
    n_val = int(len(pairs) * VAL_RATIO)
    val, train = pairs[:n_val], pairs[n_val:]
    print(f"train={len(train)} val={len(val)}")

    tr_loader = DataLoader(SegDS(train, tok), batch_size=BATCH, shuffle=True,
                           num_workers=2, pin_memory=True)
    va_loader = DataLoader(SegDS(val, tok), batch_size=BATCH * 2,
                           num_workers=2, pin_memory=True)

    model = BertTagger(MODEL_NAME, len(LABELS)).to(device)
    opt = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    total_steps = len(tr_loader) * EPOCHS
    sched = get_linear_schedule_with_warmup(opt, int(0.1 * total_steps), total_steps)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    # Only use AMP when actually on a GPU; on CPU both the autocast context
    # and the scaler become explicit no-ops instead of warning and
    # self-disabling.
    use_amp = device.type == "cuda"
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

    log_every = 200
    best = 0.0
    saved = False
    step = 0
    for epoch in range(EPOCHS):
        model.train()  # run_val leaves the model in eval mode
        t0 = time.time()
        # `seen` counts the batches accumulated in `running`.  It can be
        # smaller than `log_every` because `running` restarts every epoch
        # while `step` keeps counting across epochs.
        running, seen = 0.0, 0
        for ids, attn, labels in tr_loader:
            ids, attn, labels = ids.to(device), attn.to(device), labels.to(device)
            opt.zero_grad()
            with torch.amp.autocast("cuda", dtype=torch.float16, enabled=use_amp):
                logits = model(ids, attn)
                loss = loss_fn(logits.view(-1, len(LABELS)), labels.view(-1))
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scale_before = scaler.get_scale()
            scaler.step(opt)
            scaler.update()
            # When the scaler finds inf/NaN gradients it skips opt.step() and
            # lowers the scale; do not advance the LR schedule in that case.
            if scaler.get_scale() >= scale_before:
                sched.step()
            running += loss.item()
            seen += 1
            step += 1
            if step % log_every == 0:
                print(f" ep{epoch+1} step{step} loss={running/seen:.4f} lr={sched.get_last_lr()[0]:.2e}")
                running, seen = 0.0, 0

        f1 = run_val(model, va_loader)
        dt = time.time() - t0
        print(f"[epoch {epoch+1}] val F1={f1:.4f} time={dt:.1f}s")
        # `not saved` guarantees a checkpoint exists even when F1 never rises
        # above 0.0 (e.g. the validation split is empty).
        if f1 > best or not saved:
            best = f1
            torch.save(model.state_dict(), f"{SAVE_DIR}/best.pt")
            tok.save_pretrained(SAVE_DIR)
            with open(f"{SAVE_DIR}/labels.json", "w", encoding="utf-8") as f:
                json.dump(LABELS, f)
            saved = True
            print(f" saved best -> {SAVE_DIR}/best.pt")
    print(f"best val F1 = {best:.4f}")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()