14/08/2019, 21:28

Giới thiệu về Connectionist Temporal Classification (CTC)

Trong nhận diện giọng nói hay nhận diện chữ viết tay , đầu ra sẽ là một câu, nhưng chưa hoàn chỉnh vì có các ký tự lặp lại như "heelllo", "toooo", ... hay các chữ có những khoảng trống (blank) như "he l l oo", "t o o", ... . Nguyên nhân dẫn tới những hiện ...

Trong nhận diện giọng nói hay nhận diện chữ viết tay, đầu ra sẽ là một câu, nhưng chưa hoàn chỉnh vì có các ký tự lặp lại như "heelllo", "toooo", ... hay các chữ có những khoảng trống (blank) như "he l l oo", "t o o", ... . Nguyên nhân dẫn tới những hiện tượng này là giọng nói dài (các đoạn ngân nga trong các bài hát, ...), giọng bị ngắt quãng, kích thước chữ viết tay lớn, nhỏ, ...

Do đó, để cho ra được một câu hoàn chỉnh thì ta cần phải căn chỉnh lại đầu ra ấy, loại bỏ các ký tự lặp lại và khoảng trống. Vấn đề này được gọi là alignment problem và nó được giải quyết bằng CTC.

Đầu tiên ta sẽ nói về temporal classification.

Gọi SSS là tập huấn luyện gồm các mẫu từ một phân phối cố định DX×ZD_{X imes Z}DX×Z, trong đó:

  • X=(Rm)∗X = (mathbb{R}^m)^*X=(Rm) là một tập gồm tất cả các chuỗi (sequence) của các vectors số thực có độ dài là mmm.

  • Z=L∗Z = L^*Z=L là một tập gồm tất cả các chuỗi chỉ có các ký tự alphabet có hạn LLL của các labels. Các phần tử của L∗L^*L còn được gọi là label sequences hay labellings.

Mỗi mẫu trong SSS bao gồm một cặp chuỗi (x,z)(x,z)(x,z). Chuỗi mục tiêu (target sequence) z=(z1,z2,...,zU)z = (z_1, z_2, ..., z_U)z=(z1,z2,...,zU) có độ dài gần như chuỗi đầu vào (input sequence) x=(x1,x2,...,xT)x = (x_1, x_2, ..., x_T)x=(x1,x2,...,xT), tức là U≤TU le TUT.

Mục tiêu của ta là dùng SSS để huấn luyện bộ temporal classifier h:X↦Zh: X mapsto Zh:XZ để phân loại các chuỗi đầu vào chưa thấy trước đó (previously unseen input sequences) sao cho tối thiểu hóa một lượng mất mát (error measure) nào đó.

Một trong các error measure là label error rate (LER) được tính bằng trung bình của edit distance của đầu ra dự đoán h(x)h(x)h(x) và nhãn zzz trong tập S′S'S lấy từ SSS:

LER(h,S′)=1∣S′∣∑(x,z)∈S′ED(h(x),z)∣z∣ ext{LER}(h, S') = frac{1}{lvert S' lvert} sum_{(x,z) in S'} frac{ ext{ED}(h(x), z)}{lvert z lvert} LER(h,S)=S1(x,z)SzED(h(x),z)

Edit distance ED(p,q) ext{ED}(p,q)ED(p,q) là số nhỏ nhất của số thêm vào (insertions), số thay thế (substitutions) và số xóa đi (deletions) để chuyển từ ppp sang qqq.

Vậy, temporal classifier hhh là một mô hình phân loại nào đó mang tính thời gian (temporal). Và cũng vì thế, ở đây ta sẽ lấy RNN làm một temporal classifier cho gần gũi (thật ra ta có thể lấy bất kỳ model nào cho ra output theo thời gian).

Thế thì, connectionist ở đâu ra? Để trả lời cho điều đó, ta hãy qua cách thức hoạt động của CTC.

Trước tiên, ta cần phải xác định đầu ra của RNN sao cho RNN là một temporal classifier.

Đầu ra của CTC Network: RNN

Như ta đã nói ở trên, LLL bao gồm các ký tự trong bảng chữ cái của ngôn ngữ ta đang sử dụng, ZZZ là tất cả các labellings có thể có.

Đầu ra của một CTC Network là kết quả của lớp softmax có số unit bằng số ký tự trong LLL cộng thêm một ký tự trống (blank), nghĩa là xác suất phân loại của các ký tự trong ∣L∣+1|L| + 1L+1 tại một thời điểm nhất định.

Như vậy, tổ hợp của tất cả các đầu ra theo thời gian là ZZZ. Dưới đây là hình minh họa.

Xác suất của một labelling

Xác suất của một labelling là tổng xác suất của tất cả các alignment cho ra labelling đó.

Gọi yyy là một chuỗi của đầu ra của network, trong đó ykty_k^tykt là kết quả của unit (ký tự - label) kkk tại thời điểm ttt. Khi đó ykty_k^tykt thuộc phân phối trên tập L′TL'^TLT của các chuỗi có độ dài TTT trên tập chữ cái L′=L∪{blank}L' = L cup { ext{blank}}L=L{blank}. Và xác suất của một alignment như sau:

p(π∣x)=∏t=1Tyπtt,∀x∈L′Tp(pi lvert x) = prod_{t=1}^T y_{pi_t}^t, forall x in L'^T p(πx)=t=1Tyπtt,xLT

Trong đó πpiπpath hay còn gọi là alignment. Và công thức ở trên giả định rằng đầu ra của network tại các thời điểm là độc lập với nhau.

Tiếp theo, ta định nghĩa một many-to-one map B:L′T↦L≤TB: L'^T mapsto L^{le T}B:LTLT bằng cách loại bỏ các ký tự blank và ký tự lặp lại (B(a−b−b)=B(aa−bb−b)=abb)(B(a-b-b) = B(aa-bb-b) = abb)(B(abb)=B(aabbb)=abb) và dùng BBB để tính xác suất của một labelling lll trong ZZZ:

p(l∣x)=∑π∈B−1(l)p(π∣x)p(l lvert x) = sum_{pi in B^{-1}(l)} p(pi lvert x) p(lx)=πB1(l)p(πx)

Xây dựng bộ phân loại

Từ công thức ở trên, đầu ra của bộ phân loại sẽ là labelling có vẻ đúng nhất.

h(x)=argmax⁡l∈L≤Tp(l∣x)h(x) = arg max_{l in L^{le T}} p(l lvert x) h(x)=arglLTmaxp(lx)

Vậy làm sao để tìm h(x)h(x)h(x)?

Trong bài báo này, họ đã đưa ra hai phương pháp:

  • Best path decoding: h(x)=B(π⋆)h(x) = B(pi^star)h(x)=B(π) trong đó π⋆=argmax⁡π∈Ntp(π∣x)pi^star = arg max_{pi in N^t} p(pi lvert x)π=argmaxπNtp(πx) là sự kết hợp của các unit có xác suất cao nhất của mỗi time-step, do đó không đảm bảo sẽ tìm thấy labelling đúng nhất.
  • Prefix search decoding (PSD): phương pháp này dựa trên forward-backward algorithm, nếu có đủ thời gian, PSD sẽ luôn tìm thấy labelling phù hợp nhất, nhưng số lượng prefix tối đa sẽ tăng theo hàm mũ, phức tạp nên phải áp dụng heuristic.

Xây dựng hàm mất mát (CTC Loss function)

Ta sẽ xây dựng hàm mất mát để có thể train bằng gradient descent. Hàm mất mát được lấy theo maximum likelihood. Nghĩa là khi tối thiểu hóa nó thì sẽ cực đại hóa log likelihood.

Như vậy hàm mất mát (hàm mục tiêu) sẽ là negative log likelihood:

OML(S,Nw)=−∑(x,z)∈Sln(p(l∣x))O^{ML}(S, N_w) = - sum_{(x,z) in S} ln(p(l lvert x)) OML(S,Nw)=(x,z)Sln(p(lx))

Như vậy, một CTC Network chẳng qua là một network phân loại thông thường có output theo thời gian (temporal classifier), ta tính toán xác suất của các alignments bằng cách connect các xác suất của output của các thời điểm lại với nhau và chọn alignment phù hợp nhất, tính sai số của nó và cho network học lại. Vì thế họ gọi là Connectionist Temporal Classification.

Cách áp dụng các giải thuật decoding để tìm h(x)h(x)h(x) và tính hàm mất mát thì mình sẽ nói trong bài viết sau nhé.

  1. https://www.cs.toronto.edu/~graves/icml_2006.pdf
  2. https://towardsdatascience.com/intuitively-understanding-connectionist-temporal-classification-3797e43a86c
0