def linear_backward(dZ, cache):
    """Compute the gradients of a linear step ``Z = W . A_prev + b``.

    Args:
        dZ: Gradient with respect to the linear output Z. It is multiplied
            as ``dZ . A_prev^T`` and ``W^T . dZ``, so its rows must line up
            with the rows of ``W`` and its columns with the columns of
            ``A_prev``.
        cache: Tuple ``(A_prev, W, b)`` saved for this layer.

    Returns:
        Tuple ``(dA_prev, dW, db)`` whose shapes equal those of ``A_prev``,
        ``W`` and ``b`` respectively.

    Raises:
        AssertionError: If a computed gradient does not have the same shape
            as the quantity it is the gradient of.
    """
    A_prev, W, b = cache

    # Axis 1 of A_prev is the averaging dimension (presumably the number of
    # examples in the batch -- confirm against the caller).
    m = A_prev.shape[1]
    # 1.0 keeps this a true division even under Python 2 integer semantics.
    scale = 1.0 / m

    dW = scale * np.dot(dZ, A_prev.T)
    # keepdims=True keeps db two-dimensional so it can match b's shape.
    db = scale * np.sum(dZ, axis=1, keepdims=True)
    # Propagate the gradient to the previous layer through W^T (not averaged).
    dA_prev = np.dot(W.T, dZ)

    # Sanity checks: each gradient must be shaped like its parameter.
    assert dA_prev.shape == A_prev.shape
    assert dW.shape == W.shape
    assert db.shape == b.shape

    return dA_prev, dW, db