73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
"""HMM word segmenter: BMES tags + Laplace smoothing, model saved as JSON.

Reads a pre-segmented corpus from ``TRAIN`` (a CSV with a header row whose
first column holds one sentence of space-separated words), estimates
add-one smoothed log-probabilities for the initial, transition and emission
distributions over the BMES states, and writes them to ``MODEL``.
"""
|
||
import csv, json, math
|
||
from collections import defaultdict
|
||
|
||
# Training corpus: CSV with a header row; column 0 of each data row is one
# sentence whose words are separated by spaces.
TRAIN = "train.csv"

# Destination of the trained model (JSON of log-probability tables).
MODEL = "hmm.json"

# BMES tag set: Begin / Middle / End of a multi-character word, and
# Single-character word.
STATES = tuple("BMES")
|
||
|
||
|
||
def tag_word(w):
    """Return the BMES tag sequence for a single word.

    A one-character word is tagged ``S``. A longer word is tagged ``B``,
    then ``M`` for each interior character, then ``E``.

    Args:
        w: The word (any sized sequence of characters).

    Returns:
        A list with exactly ``len(w)`` tags; an empty word yields ``[]``.
    """
    n = len(w)
    if n == 0:
        # Without this guard the general formula below would produce
        # ["B", "E"] -- two tags for zero characters.
        return []
    if n == 1:
        return ["S"]
    return ["B"] + ["M"] * (n - 2) + ["E"]
|
||
|
||
|
||
def _count_corpus(path):
    """Count initial-state, transition and emission events in a training CSV.

    Args:
        path: CSV file with a header row; column 0 of each data row is one
            sentence whose words are separated by whitespace.

    Returns:
        ``(init, trans, emit, total_lines)``: ``init[s]`` counts sentences
        starting in state ``s``; ``trans[s][t]`` counts ``s -> t`` steps;
        ``emit[s][c]`` counts character ``c`` observed under state ``s``;
        ``total_lines`` is the number of sentences used.
    """
    init = defaultdict(float)
    trans = {s: defaultdict(float) for s in STATES}
    emit = {s: defaultdict(float) for s in STATES}
    total_lines = 0

    # The csv module requires newline="" so quoted fields containing line
    # breaks are parsed correctly.
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        # Skip the header row; the default keeps an empty file from raising
        # StopIteration out of this function.
        next(reader, None)
        for row in reader:
            if not row:
                continue
            # split() with no argument splits on any whitespace run (tabs,
            # U+3000 ideographic space, ...) and drops empty pieces, so
            # whitespace never ends up tagged as part of a word.
            words = row[0].split()
            if not words:
                continue
            tags = [t for w in words for t in tag_word(w)]
            chars = "".join(words)  # aligned 1:1 with tags

            init[tags[0]] += 1
            for c, t in zip(chars, tags):
                emit[t][c] += 1
            for prev, cur in zip(tags, tags[1:]):
                trans[prev][cur] += 1
            total_lines += 1

    return init, trans, emit, total_lines


def _estimate(init, trans, emit):
    """Convert raw counts into add-one (Laplace) smoothed log-probabilities.

    Returns:
        ``(params, vocab_size)`` where ``params`` is the JSON-ready model with
        keys ``init``, ``trans``, ``emit`` and ``emit_default``, and
        ``vocab_size`` is the number of distinct training characters.
    """
    n_states = len(STATES)

    init_total = sum(init.values())
    init_log = {
        s: math.log((init[s] + 1) / (init_total + n_states)) for s in STATES
    }

    trans_log = {}
    for s in STATES:
        tot = sum(trans[s].values())
        trans_log[s] = {
            t: math.log((trans[s][t] + 1) / (tot + n_states)) for t in STATES
        }

    vocab = set().union(*(emit[s].keys() for s in STATES))
    # The extra +1 adds one type to the smoothing denominator, covering
    # characters never seen in training.
    V = len(vocab) + 1

    emit_log, emit_default = {}, {}
    for s in STATES:
        tot = sum(emit[s].values())
        emit_log[s] = {c: math.log((n + 1) / (tot + V)) for c, n in emit[s].items()}
        # Smoothed log-probability of a zero-count character under state s;
        # such characters are absent from emit_log[s].
        emit_default[s] = math.log(1 / (tot + V))

    params = {
        "init": init_log,
        "trans": trans_log,
        "emit": emit_log,
        "emit_default": emit_default,
    }
    return params, len(vocab)


def main(train=None, model=None):
    """Train the BMES HMM on a segmented corpus and save it as JSON.

    Args:
        train: Path of the training CSV; ``None`` (default) means ``TRAIN``.
        model: Output JSON path; ``None`` (default) means ``MODEL``.
    """
    # Resolved at call time so the defaults track the module constants.
    train = TRAIN if train is None else train
    model = MODEL if model is None else model

    init, trans, emit, total_lines = _count_corpus(train)
    params, vocab_size = _estimate(init, trans, emit)

    with open(model, "w", encoding="utf-8") as f:
        json.dump(params, f, ensure_ascii=False)
    print(f"trained on {total_lines} sentences, vocab={vocab_size}, saved {model}")
|
||
|
||
|
||
# Train and save the model only when run as a script, not on import.
if __name__ == "__main__":
    main()
|