Vanishing & Exploding Gradients Problems in Deep Neural Networks (Part 1)
Introduction Trong quá trình training Deep Neural Networks (DNNs), trong một số trường hợp thời gian training có thể kéo dài hay kết quả trả về có độ chính xác không như mong muốn. Một trong những nguyên nhân gây nên những hiện tượng trên có liên quan đến Gradients , hay cụ thể hơn là Vanishing ...
Introduction
Trong quá trình training Deep Neural Networks (DNNs), trong một số trường hợp thời gian training có thể kéo dài hay kết quả trả về có độ chính xác không như mong muốn. Một trong những nguyên nhân gây nên những hiện tượng trên có liên quan đến Gradients, hay cụ thể hơn là Vanishing / Exploding Gradients. Trong bài viết này, chúng ta sẽ tìm hiểu về hiện tượng trên cũng như đưa ra một số phương pháp giải quyết cụ thể để loại bỏ hiện tượng đó. Bài viết sẽ được chia thành hai phần tương ứng với hai mục đích trên.
Gradients
Trước khi tìm hiểu về vấn đề trên, chúng ta sẽ dành một chút thời gian để nhắc lại một số kiến thức cơ bản về Gradient (bạn có thể bỏ qua phần này nếu đã nắm chắc kiến thức hoặc đã quá bá đạo về Calculus). Gradient thực chất chính là đạo hàm của một hàm số được dùng để biểu diễn tỉ lệ thay đổi (sự biến thiên) của một hàm số tại một điểm nào đó. Nó là một vector với hai tính chất chính sau:
- Hướng của vector sẽ hướng theo chiều tăng của hàm.
- Giá trị của Gradient sẽ là 0 tại các điểm cực tiểu (local minimum) hay cực đại (local maximum).
Gradient thường được dùng với các hàm số có nhiều biến (multivariable functions), với các hàm có một biến ta thường dùng khá niệm Derivative. Bạn có thể nghe đến gradients như một cách để biểu thị độ dốc (slope) của hàm số, tuy nhiên trên thực tế nó cũng không hoàn toàn chính xác.
Xét hàm số một biến f(x)f(x)f(x), đạo hàm của hàm số đó d(f)/d(x)d(f) / d(x)d(f)/d(x) biểu diễn sự thay đổi của hàm fff khi xxx thay đổi. Đối với hàm nhiều biến, giả sử f(x,y)f(x, y)f(x,y) chúng ta sẽ có hai derivative cho nó d(f)/d(x)d(f) / d(x)d(f)/d(x) và d(f)/d(y) d(f) / d(y)d(f)/d(y), biểu diễn sự thay đổi của fff khi xxx hoặc yyy thay đổi. Sự biến thiên nhiều chiều nói trên có thể được biểu diễn bằng một vector với mỗi component tương ứng với một derivative:
[d(f)d(x)d(f)d(y)]egin{bmatrix} dfrac{d(f)}{d(x)} & dfrac{d(f)}{d(y)} end{bmatrix} [d(x)d(f)d(y)d(f)]
Đối với hàm số có nhiều biến thì vector trên sẽ có nhiều component tương ứng với các derivative của hàm đó. Để dễ hình dung, đối với hàm một biến Gradient sẽ giống như di chuyển backward và forward trên trục xxx. Với hàm số 2 biến thì chúng ta sẽ di chuyển trên một mặt phẳng, hàm số 3 biến thì sẽ di chuyển trong không gian 3 chiều,...
Gradient của một hàm số là đạo hàm của hàm số đó tương ứng với mỗi biến của hàm. Đối với hàm số đơn biến, chúng ta sẽ sử dụng khái niệm Derivative thay cho Gradient (trên thực tế chúng là một tuy nhiên chúng ta cần nhớ điều này khi đọc các tài liệu tiếng Anh).
Xét một hàm số 3 biến f(x,y,z)f(x, y, z)f(x,y,z), gradient của hàm số này sẽ được biểu diễn như sau:
grad f(x,y,z)=∇f(x,y,z)=[d(f)d(x)d(f)d(y)d(f)d(z)]gradspace f(x, y, z) = abla f(x, y, z) = egin{bmatrix} dfrac{d(f)}{d(x)} & dfrac{d(f)}{d(y)} & dfrac{d(f)}{d(z)} end{bmatrix} grad f(x,y,z)=∇f(x,y,z)=[d(x)d(f)d(y)d(f)d(z)d(f)]
Gradient của hàm số trên là một vector với 3 thành phần. Mỗi thành phần được gọi là đạo hàm riêng (đạo hàm từng phần hay Partial Derivative) tương ứng với một biến nào đó. Đạo hàm riêng thường được kí hiệu δf/δxdelta f / delta xδf/δx. Như đã đề cập ở trên, gradient cho chúng ta biết hướng di chuyển theo chiều tăng của hàm số tại một điểm nào đó. Hãy cùng xét một ví dụ cụ thể:
f(x,y,z)=x2+y3+z4∇f(x,y,z)=(2x,3y2,4z3)egin{aligned} &f(x, y, z) = x^2 + y^3 + z^4 & abla f(x, y, z) = (2x, 3y^2, 4z^3) end{aligned} f(x,y,z)=x2+y3+z4∇f(x,y,z)=(2x,3y2,4z3)
Giả sử chúng ta đang ở một điểm có tọa độ (1,4,5)(1, 4, 5)(1,4,5) và muốn tìm hướng di chuyển theo chiều tăng của hàm số trên. Công việc của chúng ta là thay tọa độ của điểm trên vào gradient vector ở trên:
direction=∇f(1,4,5)=(2∗1,3∗42,4∗53)=(2,48,500) ext{direction} = abla f(1, 4, 5) = (2*1, 3*4^2, 4*5^3) = (2, 48, 500) direction=∇f(1,4,5)=(2∗1,3∗42,4∗53)=(2,48,500)
Gradient sẽ hướng theo chiều tăng của một hàm số, đi theo hướng của gradient chúng ta sẽ tìm được một local maximum
Ứng dụng phổ biến của gradient là tìm các điểm cực đại hoặc cực tiểu của hàm số (có thể có ràng buộc)
Trong một số phương pháp tối ưu như Gradient Descent chúng ta sẽ thường phải tính toán các partial derivatives. Có khá nhiều cách để tính toán các giá trị trên, tuy nhiên một trong những cách phổ biến và hiệu quả nhất là Reverse-Mode Autodiff (được sử dụng trong TensorFlow và một số Deep Learning framework khác). Nội dung của phương pháp này vượt quá giới hạn của bài viết nên mình sẽ không trình bày chi tiết ở đây. Ý tưởng chung là chúng ta sẽ xây dựng Computation Graph cho hàm số và tính toán trên graph đó. Trong bước duyệt forward từ inputs đến outputs, chúng ta sẽ tính toán giá trị của các nodes trong graph. Sau đó chúng ta sẽ duyệt backward từ outputs đến các inputs và tính toán các partial derivatives. Reverse-Mode Autodiff sử dụng Chain Rule để tính toán partial derivative dựa vào các node liền kề nhau, cho đến khi chúng ta gặp các variable nodes.
δ(f)δ(x)=δ(f)δ(nodei)∗δ(nodei)δ(x)oxed { dfrac{delta (f)}{delta (x)} = dfrac{delta (f)}{delta (node_i)} * dfrac{delta (node_i)}{delta (x)} } δ(x)δ(f)=δ(nodei)δ(f)∗δ(x)δ(nodei)
Đây là một phương pháp khá hiệu quả và chính xác khi tính toán partial derivative đặc biệt khi chúng ta có nhiều inputs và ít outputs do nó yêu cầu duy nhất một lần duyệt forward để tính toán giá trị và nnn lần duyệt backward để tính toán partial derivative cho các outputs (nnn là số lượng outputs). Một điểm nữa là phương pháp này có thể được sử dụng để tính toán trên các hàm số có cấu trúc linh hoạt hoặc không có đạo hàm toàn phần.
Vanishing / Exploding Gradient Problems
Backpropagation Algorithm (thuật toán lan truyền ngược) là một kĩ thuật thường được sử dụng trong trong quá trình training DNNs. Ý tưởng chung của thuật toán là sẽ đi từ output layer đến input layer và tính toán gradient của cost function tương ứng cho từng parameter (weight) của network. Gradient Descent, sau đó, sẽ được sử dụng để cập nhật các parameter đó.
Quá trình trên sẽ được lặp lại cho tới khi các parameter của network hội tụ. Thông thường chúng ta sẽ có một hyperparameter định nghĩa cho số lượng vòng lặp để thực hiện quá trình trên. Hyperparameter đó thường được gọi là số Epoch (hay số lần mà training set được duyệt qua một lần và weights được cập nhật). Nếu số lượng vòng lặp quá nhỏ, DNN có thể sẽ không cho ra kết quả tốt, và ngược lại thì thời gian training sẽ quá dài nếu số lượng vòng lặp quá lớn. Ở đây ta có một tradeoff giữa độ chính xác và thời gian training.
Tuy nhiên trên thực tế gradients thường sẽ có giá trị nhỏ dần khi đi xuống các layer thấp hơn. Kết quả là các cập nhật thực hiện bởi Gradient Descent không làm thay đổi nhiều weights của các layer đó, khiến chúng không thể hội tụ và DNN sẽ không thu được kết quả tốt. Hiện tượng này được gọi là Vanishing Gradients.
Trong hình vẽ minh họa trên, cost function có dạng đường cong dẹt, chúng ta sẽ cần khá nhiều lần cập nhật (Gradient Descent step) để tìm được điểm global minimum.
Trong nhiều trường hợp khác, gradients có thể có giá trị lớn hơn trong quá trình backpropagation, khiến một số layers có giá trị cập nhật cho weights quá lớn khiến chúng phân kỳ (phân rã), tất nhiên DNN cũng sẽ không có kết quả như mong muốn. Hiện tượng này được gọi là Exploding Gradients, và thường gặp khi sử dụng Recurrent Neural Networks (RNNs).
Chung quy lại, trong quá trình training DNN chúng ta có thể gặp phải các vấn đề liên quan đến việc gradients không ổn định khiến cho tốc độ học của các layer khác nhau chênh lệch khá nhiều.
Hai hiện tượng trên là một trong những nguyên nhân khiến neural networks không nhận được sự quan tâm trong một thời gian khá dài. Tuy nhiên trong một nghiên cứu được thực hiện bởi Xavier Glorot và Yoshua Bengio năm 2010 (tham khảo trong references), các tác giả đã đưa ra một số nguyên nhân dẫn đến hiện tượng trên. Trong đó việc lựa chọn activation function và kỹ thuật weight initialization là hai nguyên nhân chính.
Một trong những hàm kích hoạt phi tuyến khá phổ biến trong những giai đoạn đầu của neural networks là logistic sigmoid activation function; tuy nhiên hàm này có một số nhược điểm khiến quá trình training neural networks gặp nhiều khó khăn (chúng ta sẽ tìm hiểu trong phần sau của bài viết).
Về kỹ thuật weight initialization, random initialization sử dụng phân phối chuẩn (normal distribution) với kỳ vọng (mean) là 0 và độ lệch chuẩn (standard deviation) là 1. Chung quy lại, các tác giả cho thấy việc sử dụng sigmoid activation function cùng với random initialization khiến cho phương sai của các outputs của mỗi layer lớn hơn khá nhiều so với phương sai của inputs cho layer đó. Trong chiều đi xuôi của networks (forward), các giá trị phương sai sẽ tăng dần và hàm kích hoạt sẽ trở nên bão hòa ở những layer phía trên.
References:
- Understanding the difficulty of training deep feedforward neural networks
Sigmoid Activation Function
Trong phần này chúng ta sẽ cùng tìm hiểu về Sigmoid function, một hàm kích hoạt phi tuyến đã từng khá phổ biến trong DNN. Chúng ta sẽ bắt đầu bằng định nghĩa:
Sigmoid(z)=f(z)=11+e−zf′(z)=f(z)(1−f(z))=11+e−z[1−11+e−z]oxed { egin{aligned} &Sigmoid(z) = f(z) = dfrac{1}{1 + e^{-z}} &f'(z) = f(z)(1 - f(z)) = {dfrac{1}{1 + e^{-z}}} egin{bmatrix}{1 - {dfrac{1}{1 + e^{-z}}}} end{bmatrix} end{aligned} } Sigmoid(z)=f(z)=1+e−z1f′(z)=f(z)(1−f(z))=1+e−z1[1−1+e−z1]
OK, chúng ta sẽ plot sigmoid function và derivative của nó:
def sigmoid(z): return 1 / (1 + np.exp(-z)) z = np.linspace(-5, 5, 200) plt.figure(figsize=(6, 4)) plt.plot([-5, 5], [0, 0], 'k-') plt.plot([-5, 5], [1, 1], 'k--') plt.plot([0, 0], [-0.2, 1.2], 'k-') plt.plot([-5, 5], [-3 / 4, 7 / 4], 'g--') plt.plot(z, sigmoid(z), "b-", lineawidth=2) props = dict(facecolor='black', shrink=0.1) plt.annotate('Saturating Point', xytext=(3.5, 0.7), xy=(5, 1), arrowprops=props, fontsize=14, ha="center") plt.annotate('Saturating Point', xytext=(-3.5, 0.3), xy=(-5, 0), arrowprops=props, fontsize=14, ha="center") plt.annotate('Linear', xytext=(2, 0.2), xy=(0, 0.5), arrowprops=props, fontsize=14, ha="center") plt.grid(True) plt.title("Sigmoid activation function", fontsize=14) plt.axis([-5, 5, -0.2, 1.2]) plt.show()
def sigmoid_derivative(z): return sigmoid(z) * (1 - sigmoid(z)) z = np.linspace(-10, 10, 200) plt.figure(figsize=(8, 4)) plt.plot([-10, 10],