Tag: jax

Tutorial: Demystifying Deep Learning for Data Scientists

Tutorial: Demystifying Deep Learning for Data Scientists

In this great tutorial for PyCon 2020, Eric Ma proposes a very simple framework for machine learning, consisting of only three elements:

  1. Model
  2. Loss function
  3. Optimizer

By adjusting the three elements in this simple framework, you can build any type of machine learning program.

In the tutorial, Eric shows you how to implement this same framework in Python (using jax) and implement linear regression, logistic regression, and artificial neural networks all in the same way (using gradient descent).

I can’t even begin to explain it as well as Eric does himself, so I highly recommend you watch and code along with the Youtube tutorial (~1 hour):

If you want to code along, here’s the github repository: github.com/ericmjl/dl-workshop

Have you ever wondered what goes on behind the scenes of a deep learning framework? Or what is going on behind that pre-trained model that you took from Kaggle? Then this tutorial is for you! In this tutorial, we will demystify the internals of deep learning frameworks – in the process equipping us with foundational knowledge that lets us understand what is going on when we train and fit a deep learning model. By learning the foundations without a deep learning framework as a pedagogical crutch, you will walk away with foundational knowledge that will give you the confidence to implement any model you want in any framework you choose.

https://www.youtube.com/watch?v=gGu3pPC_fBM