nlp/exp1_fenci/predict.py
2026-04-29 18:34:27 +08:00

163 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""用 HMM + 最大匹配词典混合的方式分词。
策略:先用高频词典最大匹配,对未匹配部分用 HMM Viterbi 分。
"""
import csv, json, sys
from collections import defaultdict
# File locations (relative paths, so they resolve against the current working directory).
MODEL = "hmm.json"                   # JSON-encoded HMM parameters read by load_model()
TEST = "test.csv"                    # input sentences to segment
OUT = "submission.csv"               # predictions are written here by main()
FREQ = "chinese_word_freq_list.txt"  # word list used as the matching dictionary
# BMES tag set: Begin / Middle / End of a multi-character word, Single-character word.
STATES = tuple("BMES")
def load_model(path=None):
    """Load the HMM parameters from a JSON file.

    Args:
        path: JSON file to read. Defaults to the module-level MODEL path,
            so existing zero-argument callers behave exactly as before.

    Returns:
        The decoded JSON object. viterbi() indexes it with the keys
        "init", "trans", "emit" and "emit_default".
    """
    if path is None:
        path = MODEL
    with open(path, encoding="utf-8") as f:
        return json.load(f)
def load_dict(path=None):
    """Load the multi-character words from the word-frequency list.

    Each line is split on whitespace and the second field is taken as the
    word. The file layout is not visible here; presumably something like
    ``<rank> <word> ...``, so confirm against the data file. Lines with
    fewer than two fields are skipped. Single-character words are also
    skipped, because max_match never tries candidates shorter than two
    characters.

    Args:
        path: file to read. Defaults to the module-level FREQ path, so
            existing zero-argument callers behave exactly as before.

    Returns:
        set of word strings, each at least two characters long.
    """
    if path is None:
        path = FREQ
    words = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            # split() with no argument already ignores leading/trailing whitespace
            parts = line.split()
            if len(parts) >= 2 and len(parts[1]) >= 2:
                words.add(parts[1])
    return words
# Tags that may legally precede each tag in a B/M/E/S sequence. A word may
# start (B/S) only after the previous word has ended (E/S), and may continue
# (M/E) only after it has started (B/M). Keys follow the same order as STATES.
_PREV_STATES = {
    "B": ("E", "S"),
    "M": ("B", "M"),
    "E": ("B", "M"),
    "S": ("E", "S"),
}


def viterbi(chars, model):
    """Return the highest-scoring well-formed B/M/E/S tag sequence for *chars*.

    Scores are added together, so the model values are treated as
    log-probabilities (assumed; confirm against the training script).

    Only well-formed sequences are considered:
      * the first tag is B or S;
      * every transition appears in _PREV_STATES;
      * the last tag is E or S.
    These rules are enforced here rather than trusted to the model, so a
    smoothed transition table cannot produce sequences such as "B S" or
    "S M".

    Args:
        chars: sequence of single characters to tag.
        model: dict with scores model["init"][tag], model["trans"][prev][tag],
            model["emit"][tag][char] and model["emit_default"][tag]. The
            default emission score is used for characters missing from
            "emit".

    Returns:
        list of tags, one per character ([] for empty input).
    """
    if not chars:
        return []
    init = model["init"]
    trans = model["trans"]
    emit = model["emit"]
    emit_def = model["emit_default"]
    neg_inf = float("-inf")

    def emit_score(state, ch):
        return emit[state].get(ch, emit_def[state])

    # A sentence cannot begin in the middle of a word: only B or S may start it.
    scores = {
        s: (init[s] + emit_score(s, chars[0]) if s in ("B", "S") else neg_inf)
        for s in _PREV_STATES
    }
    # backptrs[t - 1][s] is the best predecessor of tag s at position t.
    # Back-pointers avoid copying a full path list at every step.
    backptrs = []
    for ch in chars[1:]:
        new_scores = {}
        pointers = {}
        for s, prevs in _PREV_STATES.items():
            best, best_prev = max((scores[p] + trans[p][s], p) for p in prevs)
            new_scores[s] = best + emit_score(s, ch)
            pointers[s] = best_prev
        scores = new_scores
        backptrs.append(pointers)

    # The last word must be closed, so the final tag must be E or S.
    _, state = max((scores[s], s) for s in ("E", "S"))
    tags = [state]
    for pointers in reversed(backptrs):
        state = pointers[state]
        tags.append(state)
    tags.reverse()
    return tags
def seg_by_tags(chars, tags):
    """Group characters into words according to their B/M/E/S tags.

    Tag handling:
      * M and E append the character to the word being built; E also closes
        that word.
      * B closes any open word and starts a new one.
      * Any other tag is treated as S: it closes any open word and emits the
        character as a word of its own.
    A word still open when the input runs out is emitted at the end.
    Pairs are consumed with zip(), so extra chars or tags are ignored.
    """
    pieces = []
    current = ""
    for ch, tag in zip(chars, tags):
        if tag in ("M", "E"):
            current += ch
            if tag == "E":
                pieces.append(current)
                current = ""
            continue
        # B or S: flush whatever word was still being built.
        if current:
            pieces.append(current)
        if tag == "B":
            current = ch
        else:
            current = ""
            pieces.append(ch)
    if current:
        pieces.append(current)
    return pieces
def max_match(chars, word_set, max_len=6):
    """Greedy forward maximum matching against *word_set*.

    At each position, try candidate words from the longest allowed length
    (at most *max_len* characters) down to 2 characters. The first candidate
    found in *word_set* is accepted and scanning resumes right after it. If
    nothing matches, advance by one character.

    Returns:
        list of half-open (start, end) index spans of the matched words, in
        order. Unmatched positions are not reported.
    """
    total = len(chars)
    spans = []
    pos = 0
    while pos < total:
        longest = min(max_len, total - pos)
        for length in range(longest, 1, -1):
            end = pos + length
            if "".join(chars[pos:end]) in word_set:
                spans.append((pos, end))
                pos = end
                break
        else:
            pos += 1
    return spans
def hybrid_segment(sentence, model, word_set):
    """Segment *sentence* with a dictionary-first, HMM-fallback strategy.

    Words found by forward maximum matching (max_match) are kept verbatim.
    Each stretch of characters between them, and any trailing stretch, is
    tagged with viterbi() and split with seg_by_tags().

    Returns:
        list of word strings whose concatenation is *sentence*
        ([] for an empty sentence).
    """
    chars = list(sentence)
    if not chars:
        return []

    def cut_gap(lo, hi):
        # HMM-segment the characters the dictionary did not cover.
        piece = chars[lo:hi]
        return seg_by_tags(piece, viterbi(piece, model))

    words = []
    pos = 0
    for start, stop in max_match(chars, word_set):
        if pos < start:
            words.extend(cut_gap(pos, start))
        words.append("".join(chars[start:stop]))
        pos = stop
    if pos < len(chars):
        words.extend(cut_gap(pos, len(chars)))
    return words
def transfer(words):
    """Convert segmented words into space-separated character-index groups.

    Each word becomes a bracketed, comma-separated list (no spaces) of the
    positions its characters occupy in the concatenated sentence, e.g.
    ['我', '爱', '自然', '语言'] -> '[0] [1] [2,3] [4,5]'.
    An empty word yields '[]'.
    """
    groups = []
    offset = 0
    for word in words:
        end = offset + len(word)
        groups.append("[" + ",".join(str(i) for i in range(offset, end)) + "]")
        offset = end
    return " ".join(groups)
def main():
    """Segment every sentence in TEST and write the submission CSV to OUT.

    TEST is read as CSV. Its first row is treated as a header (presumably
    ``id,sentence``) and skipped. Rows with fewer than two columns, such as
    blank lines, are ignored. OUT receives an ``id,expected`` header and
    then one row per sentence, formatted by transfer().
    """
    model = load_model()
    word_set = load_dict()
    print(f"dict size={len(word_set)}")
    # The csv module requires newline="" on both the input and output file
    # objects so that quoted fields containing newlines round-trip correctly.
    with open(TEST, encoding="utf-8", newline="") as f, \
            open(OUT, "w", encoding="utf-8", newline="") as out:
        reader = csv.reader(f)
        writer = csv.writer(out)
        # Skip the header row (id,sentence). The default keeps an empty TEST
        # file from crashing with a bare StopIteration.
        next(reader, None)
        writer.writerow(["id", "expected"])
        for row in reader:
            if len(row) < 2:
                continue
            sid, sent = row[0], row[1]
            words = hybrid_segment(sent, model, word_set)
            writer.writerow([sid, transfer(words)])
    print(f"wrote {OUT}")


if __name__ == "__main__":
    main()