163 lines
4.1 KiB
Python
163 lines
4.1 KiB
Python
"""用 HMM + 最大匹配词典混合的方式分词。
|
||
策略:先用高频词典最大匹配,对未匹配部分用 HMM Viterbi 分。
|
||
"""
|
||
import csv, json, sys
|
||
from collections import defaultdict
|
||
|
||
# Input/output file locations (relative paths, resolved against the working
# directory at run time).
MODEL = "hmm.json"  # HMM parameters, read by load_model()
TEST = "test.csv"  # sentences to segment, read by main()
OUT = "submission.csv"  # result file, written by main()
FREQ = "chinese_word_freq_list.txt"  # word list, read by load_dict()

# Tag set shared by viterbi() and seg_by_tags(): B/M/E mark the begin, middle
# and end character of a multi-character word; S marks a single-character word.
STATES = ("B", "M", "E", "S")
|
||
|
||
|
||
def load_model(path=None):
    """Load the HMM parameters from a UTF-8 JSON file.

    Args:
        path: File to read. When omitted, the module-level ``MODEL`` path is
            used, which is what existing callers rely on.

    Returns:
        The decoded JSON document. ``viterbi`` reads the keys ``"init"``,
        ``"trans"``, ``"emit"`` and ``"emit_default"`` from it.
    """
    with open(MODEL if path is None else path, encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
def load_dict(path=None):
    """Build the set of dictionary words used for maximum matching.

    Each line is split on whitespace and its second field is taken as the
    word; lines with fewer than two fields are ignored.
    NOTE(review): the first field is presumably a rank or index and later
    fields a frequency -- confirm against the actual word-list file.

    Args:
        path: File to read. When omitted, the module-level ``FREQ`` path is
            used, which is what existing callers rely on.

    Returns:
        A set of words that are at least two characters long. Single
        characters are dropped because ``max_match`` never tries them.
    """
    words = set()
    with open(FREQ if path is None else path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) >= 2 and len(parts[1]) >= 2:
                words.add(parts[1])
    return words
|
||
|
||
|
||
def viterbi(chars, model):
    """Return the highest-scoring B/M/E/S tag sequence for ``chars``.

    Scores are combined by addition, so the model values are presumably
    log-probabilities -- confirm against the script that produced the model.

    Args:
        chars: Sequence of characters to tag (a list or a string).
        model: Mapping with the keys ``"init"`` (state -> score),
            ``"trans"`` (previous state -> state -> score), ``"emit"``
            (state -> character -> score) and ``"emit_default"``
            (state -> score used for characters missing from ``"emit"``).

    Returns:
        A list with one tag from ``STATES`` per character; empty when
        ``chars`` is empty. The last tag is always ``"E"`` or ``"S"``.

    Raises:
        KeyError: If the model lacks an entry for a state or transition.
    """
    if not chars:
        return []
    init = model["init"]
    trans = model["trans"]
    emit = model["emit"]
    emit_def = model["emit_default"]

    # Best score of any path ending in each state at the current position.
    scores = {
        s: init[s] + emit[s].get(chars[0], emit_def[s]) for s in STATES
    }
    # backptrs[k][s] is the best predecessor of state s at position k + 1.
    # Storing predecessors instead of copying whole paths keeps decoding
    # linear in the sentence length.
    backptrs = []
    for ch in chars[1:]:
        new_scores = {}
        ptrs = {}
        for s in STATES:
            e = emit[s].get(ch, emit_def[s])
            # Ties on score are broken by the larger state name, because the
            # (score, state) tuples are compared lexicographically.
            new_scores[s], ptrs[s] = max(
                (scores[p] + trans[p][s] + e, p) for p in STATES
            )
        backptrs.append(ptrs)
        scores = new_scores

    # Structural constraints (B/S first; B->M/E, M->M/E, E->B/S, S->B/S) are
    # NOT checked here. Presumably the model assigns very low scores to the
    # forbidden cases -- TODO confirm. Only the rule that a sequence must end
    # in E or S is enforced explicitly.
    _, state = max((scores[s], s) for s in ("E", "S"))

    tags = [state]
    for ptrs in reversed(backptrs):
        state = ptrs[state]
        tags.append(state)
    tags.reverse()
    return tags
|
||
|
||
|
||
def seg_by_tags(chars, tags):
    """Group characters into words according to their B/M/E/S tags.

    Malformed tag sequences are tolerated rather than rejected: a word left
    open by ``B``/``M`` is flushed when the next ``B`` or ``S`` arrives or
    when the input ends, and any tag other than ``B``/``M``/``E`` is treated
    like ``S``.

    Args:
        chars: Sequence of characters.
        tags: Sequence of tags aligned with ``chars``. Iteration stops at the
            shorter of the two sequences.

    Returns:
        The list of words, in input order.
    """
    words = []
    buf = ""  # characters of the word currently being assembled
    for c, t in zip(chars, tags):
        if t == "B":
            # A new word starts; emit whatever was left unterminated.
            if buf:
                words.append(buf)
            buf = c
        elif t == "M":
            buf += c
        elif t == "E":
            buf += c
            words.append(buf)
            buf = ""
        else:  # "S" (and any unknown tag): a single-character word
            if buf:
                words.append(buf)
                buf = ""
            words.append(c)
    if buf:
        words.append(buf)
    return words
|
||
|
||
|
||
def max_match(chars, word_set, max_len=6):
    """Locate dictionary words by forward maximum matching.

    Scans left to right. At each position the longest candidate (up to
    ``max_len`` characters) is tried first; on a hit the scan jumps past the
    matched word, otherwise it advances by one character.

    Args:
        chars: Sequence of characters of the sentence.
        word_set: Container of known words, tested with ``in``.
        max_len: Longest candidate length to try.

    Returns:
        A list of ``(start, end)`` index pairs, half-open, in ascending order
        and non-overlapping. Positions not covered by any pair were not
        matched. Every pair spans at least two characters.
    """
    n = len(chars)
    segs = []
    i = 0
    while i < n:
        # Stop at length 2: single characters are never locked by the
        # dictionary and are left for the caller to handle.
        for size in range(min(max_len, n - i), 1, -1):
            if "".join(chars[i:i + size]) in word_set:
                segs.append((i, i + size))
                i += size
                break
        else:
            # No candidate matched at this position.
            i += 1
    return segs
|
||
|
||
|
||
def hybrid_segment(sentence, model, word_set):
    """Segment ``sentence`` with dictionary matching plus HMM decoding.

    Words found by ``max_match`` are kept as they are. Every run of
    characters between (or around) those words is tagged with ``viterbi`` and
    split with ``seg_by_tags``.

    Args:
        sentence: Text to segment.
        model: HMM parameters, passed through to ``viterbi``.
        word_set: Dictionary words, passed through to ``max_match``.

    Returns:
        The list of words in sentence order; empty for an empty sentence.
    """
    chars = list(sentence)
    n = len(chars)
    if n == 0:
        return []

    # An empty sentinel span at the end lets the loop also handle the
    # unmatched tail after the last dictionary word.
    spans = max_match(chars, word_set) + [(n, n)]

    result = []
    cur = 0  # index of the first character not yet emitted
    for a, b in spans:
        if cur < a:
            # Gap not covered by the dictionary: let the HMM segment it.
            sub = chars[cur:a]
            result.extend(seg_by_tags(sub, viterbi(sub, model)))
        if a < b:  # false only for the sentinel
            result.append("".join(chars[a:b]))
        cur = b
    return result
|
||
|
||
|
||
def transfer(words):
    """Encode a segmentation as groups of character indices.

    Characters are numbered consecutively from 0 across the whole sentence.
    Each word becomes a bracketed, comma-separated list of its character
    indices, and the groups are joined with single spaces.

    Example:
        ['我', '爱', '自然', '语言'] -> '[0] [1] [2,3] [4,5]'

    Args:
        words: Iterable of words (anything supporting ``len``).

    Returns:
        The encoded string; ``""`` when ``words`` is empty. An empty word
        yields ``"[]"``.
    """
    out = []
    count = 0  # index of the next character
    for w in words:
        end = count + len(w)
        out.append("[" + ",".join(map(str, range(count, end))) + "]")
        count = end
    return " ".join(out)
|
||
|
||
|
||
def main():
    """Segment every sentence in ``TEST`` and write the result to ``OUT``.

    The input CSV is read with its first row treated as a header and skipped.
    Rows with fewer than two columns are ignored; the first two columns are
    used as the id and the sentence. The output CSV has the header
    ``id,expected`` followed by one row per sentence, holding the id and the
    ``transfer`` encoding of its segmentation.
    """
    model = load_model()
    word_set = load_dict()
    print(f"dict size={len(word_set)}")

    with open(TEST, encoding="utf-8") as f, open(OUT, "w", encoding="utf-8", newline="") as out:
        reader = csv.reader(f)
        writer = csv.writer(out)
        # Skip the header row (id,sentence). The default keeps an empty input
        # file from raising StopIteration; it then yields a header-only output.
        next(reader, None)
        writer.writerow(["id", "expected"])
        for row in reader:
            if len(row) < 2:
                continue
            sid, sent = row[0], row[1]
            words = hybrid_segment(sent, model, word_set)
            writer.writerow([sid, transfer(words)])
    print(f"wrote {OUT}")
|
||
|
||
|
||
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|