178 lines
6.0 KiB
Python
178 lines
6.0 KiB
Python
"""BERT 字符级 BMES 分词训练。"""
|
|
import csv, os, random, json, time
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from transformers import BertTokenizerFast, BertModel, get_linear_schedule_with_warmup
|
|
from torch.optim import AdamW
|
|
|
|
# --- Data / model locations ---
TRAIN = "../train.csv"  # CSV read by load_pairs(): header row, first column = space-separated words
MODEL_NAME = "hfl/chinese-bert-wwm-ext"  # name handed to the tokenizer/model from_pretrained calls
SAVE_DIR = "ckpt"  # directory receiving best.pt, the tokenizer files and labels.json

# --- Training hyper-parameters ---
MAX_LEN = 128  # padded token length per example; characters are cut to MAX_LEN - 2
BATCH = 32  # training batch size (validation uses twice this)
EPOCHS = 3
LR = 3e-5
SEED = 42
VAL_RATIO = 0.02  # fraction of the shuffled pairs held out for validation

# --- Tag inventory ---
LABELS = ["B", "M", "E", "S"]
L2I = dict(zip(LABELS, range(len(LABELS))))  # tag -> class index
PAD_ID = -100  # label for positions without a tag; passed as ignore_index to the loss

# Seed both RNGs used by this script (Python's for the shuffle, torch's for the model).
random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
|
|
def tag_word(w):
    """Return the BMES tags for one word, one tag per character.

    A single-character word is tagged ``["S"]``; a longer word is tagged
    ``B``, then ``M`` for every interior character, then ``E``.  An empty
    word yields an empty list, so the number of tags always equals
    ``len(w)``.

    Args:
        w: The word (any sized sequence of characters).

    Returns:
        A list of tag strings drawn from ``"B"``, ``"M"``, ``"E"``, ``"S"``.
    """
    n = len(w)
    if n == 0:
        # Without this guard the general formula would emit ["B", "E"]
        # for zero characters and desynchronise tags from characters.
        return []
    if n == 1:
        return ["S"]
    return ["B"] + ["M"] * (n - 2) + ["E"]
|
|
|
|
|
|
def load_pairs():
    """Read the training CSV and build one (chars, tags) pair per usable row.

    The file at ``TRAIN`` is parsed with ``csv.reader``; its first row is
    treated as a header and skipped.  For every following row, the first
    column is stripped and split on single spaces into words; each word
    contributes its characters and the matching tags from ``tag_word``.
    Empty rows and rows without any word are ignored.

    Returns:
        A list of ``(chars, tags)`` tuples where ``chars`` is a list of
        single-character strings and ``tags`` is the equally long list of
        BMES tag strings.
    """
    pairs = []
    # newline="" hands newline handling to the csv module, as its
    # documentation requires; otherwise quoted fields containing line
    # breaks can be split incorrectly.
    with open(TRAIN, encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        # Skip the header; the default keeps an empty file from raising
        # StopIteration and simply yields no pairs instead.
        next(reader, None)
        for row in reader:
            if not row:
                continue
            words = [w for w in row[0].strip().split(" ") if w]
            if not words:
                continue
            chars, tags = [], []
            for w in words:
                chars.extend(w)
                tags.extend(tag_word(w))
            pairs.append((chars, tags))
    return pairs
|
|
|
|
|
|
class SegDS(Dataset):
    """Dataset that encodes (chars, tags) pairs for token classification.

    Each item is tokenized character by character (``is_split_into_words``)
    and padded/truncated to ``MAX_LEN`` tokens.  Labels are aligned to the
    tokens through the tokenizer's ``word_ids``.

    Args:
        pairs: Sequence of ``(chars, tags)`` tuples with equal-length lists.
        tok: A fast tokenizer; the returned encoding must support
            ``word_ids(batch_index=0)``.
    """

    def __init__(self, pairs, tok):
        self.pairs, self.tok = pairs, tok

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        """Return ``(input_ids, attention_mask, labels)`` tensors for item ``i``.

        Positions that carry no character (special and padding tokens) and
        continuation sub-tokens of a character are labelled ``PAD_ID`` so
        the loss and the evaluation ignore them.
        """
        chars, tags = self.pairs[i]
        # Pre-truncate so characters and tags stay in step; the two spare
        # slots presumably leave room for the special tokens.
        chars = chars[:MAX_LEN - 2]
        tags = tags[:MAX_LEN - 2]
        enc = self.tok(
            chars,
            is_split_into_words=True,
            truncation=True,
            max_length=MAX_LEN,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)
        attn = enc["attention_mask"].squeeze(0)

        labels = []
        prev_wid = None
        for wid in enc.word_ids(batch_index=0):
            if wid is None or wid == prev_wid:
                # No character here, or a further sub-token of the same
                # character: labelling it again would duplicate the tag
                # and shift span indices during evaluation.
                labels.append(PAD_ID)
            else:
                labels.append(L2I[tags[wid]])
            prev_wid = wid
        labels = torch.tensor(labels, dtype=torch.long)
        return input_ids, attn, labels
|
|
|
|
|
|
class BertTagger(nn.Module):
    """Per-token classifier: a BERT encoder followed by dropout and a linear head.

    Args:
        name: Identifier passed to ``BertModel.from_pretrained``.
        n_labels: Number of output classes per token.
    """

    def __init__(self, name, n_labels):
        super().__init__()
        encoder = BertModel.from_pretrained(name)
        # The attribute names below double as state_dict key prefixes.
        self.bert = encoder
        self.drop = nn.Dropout(0.1)
        self.cls = nn.Linear(encoder.config.hidden_size, n_labels)

    def forward(self, ids, attn):
        """Return the classification logits for every token position."""
        encoded = self.bert(ids, attention_mask=attn)
        hidden = self.drop(encoded.last_hidden_state)
        return self.cls(hidden)
|
|
|
|
|
|
def seg_f1(gold_tags, pred_tags):
    """Return the span-level F1 between a gold and a predicted BMES sequence.

    Both tag sequences are decoded into sets of inclusive ``(start, end)``
    index spans; a predicted span counts as correct only when the same span
    occurs in the gold set.  Decoding is lenient with ill-formed sequences:
    a span left open is closed by the next ``B``/``S`` or by the end of the
    sequence, and an ``M``/``E`` without an open span starts one itself.

    Returns:
        ``2PR / (P + R)`` as a float, or ``0.0`` when no span matches.
    """

    def to_segs(tags):
        # Decode one tag sequence into its set of (start, end) spans.
        spans = set()
        open_at = None
        for pos, tag in enumerate(tags):
            if tag in ("B", "S"):
                # Both tags terminate whatever span is still open.
                if open_at is not None:
                    spans.add((open_at, pos - 1))
                if tag == "S":
                    spans.add((pos, pos))
                    open_at = None
                else:
                    open_at = pos
            elif tag in ("M", "E"):
                # Tolerate a missing "B" by opening the span here.
                if open_at is None:
                    open_at = pos
                if tag == "E":
                    spans.add((open_at, pos))
                    open_at = None
        if open_at is not None:
            spans.add((open_at, len(tags) - 1))
        return spans

    gold_spans = to_segs(gold_tags)
    pred_spans = to_segs(pred_tags)
    hits = len(gold_spans & pred_spans)
    if hits == 0:
        return 0.0
    precision = hits / len(pred_spans)
    recall = hits / len(gold_spans)
    return 2 * precision * recall / (precision + recall)
|
|
|
|
|
|
def run_val(model, loader):
    """Evaluate ``model`` on ``loader`` and return the mean per-sequence F1.

    The model is switched to eval mode and run without gradients.  For every
    sequence, only positions whose label differs from ``PAD_ID`` are kept;
    the gold and arg-max predicted tags at those positions are scored with
    ``seg_f1``.  Sequences with no labelled position are skipped.

    Returns:
        The average of the per-sequence scores (0 when nothing was scored).
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for ids, attn, labels in loader:
            logits = model(ids.to(device), attn.to(device))
            pred_rows = logits.argmax(-1).cpu().numpy()
            gold_rows = labels.numpy()
            for pred_row, gold_row in zip(pred_rows, gold_rows):
                keep = gold_row != PAD_ID  # positions that carry a real tag
                gold = [LABELS[idx] for idx in gold_row[keep]]
                if not gold:
                    continue
                pred = [LABELS[idx] for idx in pred_row[keep]]
                scores.append(seg_f1(gold, pred))
    return sum(scores) / max(len(scores), 1)
|
|
|
|
|
|
def main():
    """Fine-tune the tagger and keep the checkpoint with the best validation F1.

    Steps: load and shuffle the pairs, hold out ``VAL_RATIO`` of them for
    validation, train for ``EPOCHS`` epochs with AdamW, linear warm-up/decay
    and gradient clipping, and validate after every epoch.  Whenever the
    validation F1 improves, the model weights, the tokenizer and the label
    list are written to ``SAVE_DIR``.

    Mixed precision (fp16 autocast plus gradient scaling) is enabled only
    when running on CUDA; on CPU the same loop runs in full precision.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    tok = BertTokenizerFast.from_pretrained(MODEL_NAME)

    pairs = load_pairs()
    random.shuffle(pairs)
    n_val = int(len(pairs) * VAL_RATIO)
    val, train = pairs[:n_val], pairs[n_val:]
    print(f"train={len(train)} val={len(val)}")

    # AMP and pinned memory are only requested when CUDA is actually in use.
    on_cuda = device.type == "cuda"

    tr_loader = DataLoader(SegDS(train, tok), batch_size=BATCH, shuffle=True,
                           num_workers=2, pin_memory=on_cuda)
    va_loader = DataLoader(SegDS(val, tok), batch_size=BATCH * 2,
                           num_workers=2, pin_memory=on_cuda)

    model = BertTagger(MODEL_NAME, len(LABELS)).to(device)
    opt = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    total_steps = len(tr_loader) * EPOCHS
    # The first 10% of all steps are warm-up.
    sched = get_linear_schedule_with_warmup(opt, int(0.1 * total_steps), total_steps)
    loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
    scaler = torch.amp.GradScaler("cuda", enabled=on_cuda)

    best = 0.0
    step = 0
    for epoch in range(EPOCHS):
        model.train()
        t0 = time.time()
        # Loss sum and batch count since the last log line.  Counting the
        # batches keeps the printed average correct when a logging window
        # is cut short by an epoch boundary.
        running, n_running = 0.0, 0
        for ids, attn, labels in tr_loader:
            ids, attn, labels = ids.to(device), attn.to(device), labels.to(device)
            opt.zero_grad()
            with torch.amp.autocast("cuda", dtype=torch.float16, enabled=on_cuda):
                logits = model(ids, attn)
                loss = loss_fn(logits.view(-1, len(LABELS)), labels.view(-1))
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            scale_before = scaler.get_scale()
            scaler.step(opt)
            scaler.update()
            # The scaler lowers its scale exactly when it skipped the
            # optimizer step (inf/nan gradients); only advance the LR
            # schedule when the optimizer really stepped.
            if scaler.get_scale() >= scale_before:
                sched.step()

            running += loss.item()
            n_running += 1
            step += 1
            if step % 200 == 0:
                print(f" ep{epoch+1} step{step} loss={running/n_running:.4f} lr={sched.get_last_lr()[0]:.2e}")
                running, n_running = 0.0, 0

        f1 = run_val(model, va_loader)
        dt = time.time() - t0
        print(f"[epoch {epoch+1}] val F1={f1:.4f} time={dt:.1f}s")
        if f1 > best:
            best = f1
            torch.save(model.state_dict(), f"{SAVE_DIR}/best.pt")
            tok.save_pretrained(SAVE_DIR)
            with open(f"{SAVE_DIR}/labels.json", "w", encoding="utf-8") as f:
                json.dump(LABELS, f)
            print(f" saved best -> {SAVE_DIR}/best.pt")
    print(f"best val F1 = {best:.4f}")
|
|
|
|
|
|
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|