# Introduction

In this exercise, we implement the Hebbian Learning method proposed in [Meta-Learning through Hebbian Plasticity in Random Networks](https://arxiv.org/abs/2007.02686).

In [None]:
# @title Import libraries

# Numpy is all you need ;)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from IPython.display import clear_output, HTML
import imageio
from base64 import b64encode

import warnings

try:
    import gym
except ModuleNotFoundError:
    !pip install gym
    clear_output()
    import gym

from multiprocessing import Pool

In [None]:
# @title Know the task
# @markdown We will be using the CartPole setting as our environment.

# @markdown Here is a reminder of the task. The episode ends if any one of the following occurs:
# @markdown 1. Termination: Pole Angle is greater than ±12°
# @markdown 2. Termination: Cart Position is greater than ±2.4 (center of the cart reaches the edge of the display)
# @markdown 3. Truncation: Episode length is greater than 500 (200 for v0)

# Ignore DeprecationWarning messages so they do not clutter the cell outputs.
warnings.filterwarnings('ignore', category=DeprecationWarning)


def play_video(image_list, fps=30):
    """Encode frames as an MP4 file and return an inline HTML video player.

    Arguments:
      image_list    - Iterable of frames. Each frame is handed to the imageio
                      writer's append_data().
      fps           - Frames per second of the encoded video.

    Returns:
      An IPython HTML object that embeds the video as a base64 data URL.
    """
    output_video = '/tmp/temp_video.mp4'
    with imageio.get_writer(output_video, fps=fps) as writer:
        for img in image_list:
            writer.append_data(img)
    # Read the encoded file back inside a context manager so the file handle
    # is closed right away instead of being left to the garbage collector.
    with open(output_video, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
    return HTML("""
    <video width=400 controls loop>
        <source src="%s" type="video/mp4">
    </video>
    """ % data_url)


def rand_control(seed):
    """Run one CartPole-v1 episode with uniformly random actions.

    Arguments:
      seed  - Seed forwarded to env.reset(). The actions themselves are drawn
              from the global NumPy random state, which is not seeded here.

    Returns:
      List of rendered frames, one for the initial state plus one per step.
    """
    env = gym.make('CartPole-v1')
    try:
        env.reset(seed=seed)
        total_reward = 0
        done = False
        imgs = [env.render('rgb_array')]
        while not done:
            action = np.random.randint(0, 2)
            _, reward, done, _ = env.step(action)
            total_reward += reward
            imgs.append(env.render('rgb_array'))
    finally:
        # Release the environment even if stepping or rendering fails.
        env.close()
    print(f'reward={total_reward}')
    return imgs


# Episode seed; it is forwarded to env.reset() inside rand_control.
seed = 42  # @param
images = rand_control(seed)
# Last expression of the cell, so the returned HTML player is rendered.
play_video(images)

In [None]:
# @title Neural network definition
# @markdown We use a feedforward neural network (a.k.a MLP) to control the cart.

from typing import List, Optional


class CartPoleControl(object):
    """Bias-free tanh MLP whose weights are supplied at call time."""

    def __init__(self,
                 input_dim: int,
                 hidden_dims: List[int]):
        """Record the layer shapes and the total number of weights.

        Arguments:
          input_dim     - Size of the observation vector (4 for CartPole).
          hidden_dims   - Width of every hidden layer, e.g. [32, 32].
        """
        self.input_dim = input_dim
        self.output_dim = 1

        # Consecutive pairs of layer widths give the weight matrix shapes,
        # hidden layers first and the single-unit output layer last.
        layer_dims = [input_dim, *hidden_dims, self.output_dim]
        self.w_sizes = list(zip(layer_dims[:-1], layer_dims[1:]))
        self.num_params = sum(rows * cols for rows, cols in self.w_sizes)
        print(f'#params={self.num_params}')

    def seed(self, seed):
        """No-op: this policy keeps no internal random state."""
        return None

    def __call__(self,
                 params: np.ndarray,
                 obs: np.ndarray):
        """Run a single observation through the network.

        Arguments:
          params    - Flat array holding all layer weights back to back;
                      its size must equal self.num_params.
          obs       - One observation of shape (input_dim,).

        Returns:
          0 if the output unit is negative, 1 otherwise.
        """
        assert params.size == self.num_params, 'Inconsistent params sizes.'

        activation = obs
        offset = 0
        for shape in self.w_sizes:
            # Carve this layer's weight matrix out of the flat vector.
            n_weights = shape[0] * shape[1]
            weights = params[offset:offset + n_weights].reshape(shape)
            activation = np.tanh(np.einsum('i,ij->j', activation, weights))
            offset += n_weights
        assert offset == params.size

        if activation < 0:
            return 0
        return 1


def nn_control(policy, params, seed, render=True, break_in_middle=False):
    """Roll out one CartPole-v1 episode controlled by `policy`.

    Arguments:
      policy            - Object called as policy(params, obs) to get an
                          action. It must also provide seed(seed=...). When
                          break_in_middle is True it must additionally expose
                          a list attribute `mlp_params_hist`.
      params            - Passed unchanged to `policy` at every step.
      seed              - Seed given to both env.reset() and policy.seed().
      render            - If True, collect a frame per step and return them.
      break_in_middle   - If True, re-seed the policy once, at a step drawn
                          from [100, 300), keeping the parameter history
                          recorded before the reset.

    Returns:
      The list of rendered frames if `render` is True, otherwise the total
      episode reward.
    """
    env = gym.make('CartPole-v1')
    try:
        obs = env.reset(seed=seed)
        policy.seed(seed=seed)
        total_reward = 0
        done = False
        # Drawn on every call, so the global NumPy random stream advances the
        # same way whether or not break_in_middle is set.
        step_to_break = np.random.randint(100, 300)
        step = 0
        if render:
            imgs = [env.render('rgb_array')]
        while not done:
            if break_in_middle and step_to_break == step:
                print(f'Reset MLP parameters at step {step_to_break}')
                # policy.seed() starts a new history; stitch the old one
                # in front of it so the full trajectory is preserved.
                param_hist = policy.mlp_params_hist
                policy.seed(seed=np.random.randint(1<<20))
                policy.mlp_params_hist = param_hist + policy.mlp_params_hist
            action = policy(params, obs)
            obs, reward, done, info = env.step(action)
            total_reward += reward
            step += 1
            if render:
                imgs.append(env.render('rgb_array'))
    finally:
        # Release the environment even if the policy or the simulator raises.
        env.close()
    if render:
        print(f'reward={total_reward}')
        return imgs
    else:
        return total_reward


# Build the MLP controller and roll it out with untrained weights.
hidden_dims = [32, 32]  # @param
model = CartPoleControl(input_dim=4, hidden_dims=hidden_dims)
# Untrained parameters: standard-normal samples scaled by 0.01.
rand_params = np.random.randn(model.num_params) * 0.01
images = nn_control(model, rand_params, seed)
print('Control the cart with a random policy')
# Last expression of the cell, so the returned HTML player is rendered.
play_video(images)

In [None]:
# @title Train with CMA-ES

# Ignore DeprecationWarning messages so they do not clutter the cell outputs.
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Install the `cma` package on demand when it cannot be imported.
try:
    import cma
except ModuleNotFoundError:
    !pip install cma
    clear_output()
    import cma


# @markdown Feel free to play with the following parameters.
# n_repeats   - number of rollouts averaged per candidate in eval_params.
# num_gen     - number of CMA-ES generations in the optimization loop.
# pop_size    - passed to CMA-ES as "popsize" (candidates per generation).
# init_stdev  - passed to CMA-ES as sigma0.
# num_worker  - number of processes in the multiprocessing Pool.
n_repeats = 10     # @param
num_gen = 10      # @param
pop_size = 32      # @param
init_stdev = 0.1  # @param
num_worker = 8     # @param


def eval_params(args):
    """Score one candidate parameter vector for the MLP policy.

    Arguments:
      args  - Tuple (params, seed): the flat parameter vector and the base
              seed for the rollouts.

    Returns:
      Mean total reward over `n_repeats` unrendered episodes.
    """
    params, seed = args
    rewards = []
    # Several episodes with distinct seeds reduce the noise of the estimate.
    for i in range(n_repeats):
        episode_seed = seed + 37*i
        rewards.append(nn_control(model, params, episode_seed, render=False))
    return np.mean(rewards)


# Set up the CMA-ES solver; the search starts from the all-zeros vector.
algo = cma.CMAEvolutionStrategy(
    x0=np.zeros(model.num_params),
    sigma0=init_stdev,
    inopts=dict(popsize=pop_size, seed=seed, randn=np.random.randn),
)


# Ask / evaluate / tell loop; the rollouts run in a process pool.
with Pool(num_worker) as p:
    for i in range(num_gen):
        population = algo.ask()
        # All candidates of one generation share the same base rollout seed.
        rollout_seed = np.random.randint(0, 10000000)
        jobs = [(np.array(population[k]), rollout_seed)
                for k in range(pop_size)]
        scores = p.map(eval_params, jobs)
        # CMA-ES minimizes, so it is given the negated rewards.
        costs = [-score for score in scores]
        algo.tell(population, costs)
        print(f'Gen={i+1}, reward.max={np.max(scores)}')


# Roll out the solver's favorite solution and show the episode.
best_params = np.array(algo.result.xfavorite)
images = nn_control(model, best_params, seed)
play_video(images)

# Hebbian Learning

In the code above, we use an evolutionary algorithm to directly optimize the parameters of the control policy. In Hebbian Learning, we instead optimize the Hebbian learning rule, which in turn optimizes the control policy online.

The Hebbian Learning rule is as follows, where $w_{i,j}$ is the weight between neurons $i$ and $j$, and $o_i$ and $o_j$ are the pre- and post-synaptic activations, respectively.
$$\Delta w_{i,j} = \eta \cdot (A_{i,j} o_i o_j + B_{i,j} o_i + C_{i,j} o_j + D_{i,j} )$$

The rule above defines how the weights of a control network should change.
Assuming $w \in \mathbb{R}^{M \times N}$, the learnable Hebbian parameters are $\eta \in \mathbb{R}$, $A \in \mathbb{R}^{M \times N}$, $B \in \mathbb{R}^{M \times N}$, $C \in \mathbb{R}^{M \times N}$ and $D \in \mathbb{R}^{M \times N}$. \\

In [None]:
# @title Neural network with Hebbian learning (Implement this)


class MetaCartPoleControl(object):
    """A feedforward neural network, with Hebbian learning.

    The MLP weights are not optimized directly. They start from random values
    and are rewritten on every forward pass by the Hebbian rule

        delta_w[i, j] = eta * (A[i, j] * o_i * o_j + B[i, j] * o_i
                               + C[i, j] * o_j + D[i, j])

    where o_i / o_j are the pre-/post-synaptic activations of a layer.

    Layout of the flat `hebbian_params` vector: one segment per layer, and for
    a layer with weights of shape (M, N) the segment is
        [eta, A (M*N values), B (M*N values), C (M*N values), D (M*N values)].
    """

    def __init__(self,
                 input_dim: int,
                 hidden_dims: List[int]):
        """Initialization.

        Arguments:
          input_dim     - Input dimension. 4 in our CartPole example.
          hidden_dims   - Hidden dimension. E.g., [32, 32].
        """
        self.input_dim = input_dim
        self.output_dim = 1

        self.w_sizes = []
        self.num_params = 0
        self.hebbian_w_sizes = []
        self.num_hebbian_params = 0

        dim_in = input_dim
        # Hidden layers first, then the single-unit output layer.
        for dim_out in list(hidden_dims) + [self.output_dim]:
            self.w_sizes.append((dim_in, dim_out))
            self.num_params += dim_in * dim_out
            # A, B, C and D hold one coefficient per connection; the extra
            # "+ 1" is the layer's scalar learning rate eta.
            self.hebbian_w_sizes.append((4, dim_in, dim_out))
            self.num_hebbian_params += 4 * dim_in * dim_out + 1
            dim_in = dim_out

        print(f'#params={self.num_params}')
        print(f'#hebbian_params={self.num_hebbian_params}')

        # MLP parameters are randomly sampled from Uniform(-1, 1)
        self.mlp_params = np.random.rand(self.num_params) * 2 - 1
        self.mlp_params_hist = []

    def seed(self, seed):
        """Re-sample the MLP weights from `seed` and restart the history."""
        np_random = np.random.RandomState(seed)
        self.mlp_params = np_random.rand(self.num_params) * 2 - 1
        self.mlp_params_hist = [self.mlp_params.copy()]

    def __call__(self,
                 hebbian_params: np.ndarray,
                 obs: np.ndarray):
        """Apply Hebbian learning rule to the mlp_params and also return action.

        Arguments:
          hebbian_params    - Hebbian learning parameters of shape (M,), where M
                              is self.num_hebbian_params (layout described in
                              the class docstring).
          obs               - Network input data of shape (input_dim,).

        Returns:
          0 if the output unit is negative, 1 otherwise.
        """
        assert hebbian_params.size == self.num_hebbian_params, (
            'Inconsistent params sizes.'
        )

        x = np.asarray(obs)
        ss = 0
        h_ss = 0
        for w_size, hebbian_w_size in zip(self.w_sizes, self.hebbian_w_sizes):

            # Pass data through the MLP layer
            ee = ss + int(np.prod(w_size))
            w = self.mlp_params[ss:ee].reshape(w_size)
            pre_synaptic_activation = x
            x = np.tanh(np.einsum('i,ij->j', x, w))
            post_synaptic_activation = x

            # Apply the Hebbian learning rule
            h_ee = h_ss + 1 + int(np.prod(hebbian_w_size))
            eta = hebbian_params[h_ss]
            coef_a, coef_b, coef_c, coef_d = (
                hebbian_params[h_ss + 1:h_ee].reshape(hebbian_w_size)
            )
            delta_w = eta * (
                coef_a * np.outer(pre_synaptic_activation,
                                  post_synaptic_activation)
                + coef_b * pre_synaptic_activation[:, None]
                + coef_c * post_synaptic_activation[None, :]
                + coef_d
            )
            # The activation above used the weights from before this update;
            # the new weights take effect from the next forward pass.
            self.mlp_params[ss:ee] += delta_w.ravel()

            ss = ee
            h_ss = h_ee

        assert ss == self.num_params
        assert h_ss == self.num_hebbian_params
        self.mlp_params_hist.append(self.mlp_params.copy())

        return 0 if x < 0 else 1


# Build the Hebbian controller and roll it out with untrained rule coefficients.
hebbian_model = MetaCartPoleControl(input_dim=4, hidden_dims=hidden_dims)
# Untrained Hebbian coefficients: standard-normal samples scaled by 0.01.
rand_hebbian_params = np.random.randn(hebbian_model.num_hebbian_params) * 0.01
images = nn_control(hebbian_model, rand_hebbian_params, seed)
print('Control the cart with a random Hebbian policy')
# Last expression of the cell, so the returned HTML player is rendered.
play_video(images)

In [None]:
# @title Train Hebbian with CMA-ES


def eval_hebbian_params(args):
    """Score one candidate vector of Hebbian coefficients.

    Arguments:
      args  - Tuple (params, seed): the flat Hebbian parameter vector and the
              base seed for the rollouts.

    Returns:
      Mean total reward over `n_repeats` unrendered episodes.
    """
    params, seed = args

    def rollout(i):
        # Each repetition gets its own seed, offset from the base seed.
        return nn_control(hebbian_model, params, seed + 37*i, render=False)

    # Averaging several episodes reduces the noise of the estimate.
    rewards = [rollout(i) for i in range(n_repeats)]
    return np.mean(rewards)


# Set up the CMA-ES solver over the Hebbian coefficients, starting at zero.
algo = cma.CMAEvolutionStrategy(
    x0=np.zeros(hebbian_model.num_hebbian_params),
    sigma0=init_stdev,
    inopts=dict(popsize=pop_size, seed=seed, randn=np.random.randn),
)


# Ask / evaluate / tell loop; the rollouts run in a process pool.
with Pool(num_worker) as p:
    for i in range(num_gen * 5):
        population = algo.ask()
        # All candidates of one generation share the same base rollout seed.
        rollout_seed = np.random.randint(0, 10000000)
        jobs = [(np.array(population[k]), rollout_seed)
                for k in range(pop_size)]
        scores = p.map(eval_hebbian_params, jobs)
        # CMA-ES minimizes, so it is given the negated rewards.
        costs = [-score for score in scores]
        algo.tell(population, costs)
        # Report progress only every tenth generation.
        if i % 10 == 0:
            print(f'Gen={i+1}, reward.max={np.max(scores)}')


# Roll out the solver's favorite solution and show the episode.
best_params = np.array(algo.result.xfavorite)
images = nn_control(hebbian_model, best_params, seed)
play_video(images)

In [None]:
# @title Visualize how the MLP parameters have changed
# @markdown In this plot, the y-axis shows the time steps in an episode. \\
# @markdown You can see a clear boundary where the parameters are reset to random values. \\
# @markdown Although we initialized the parameters in (-1, 1), their final values have grown much larger in absolute values.

from IPython.display import display

# Roll out the trained rule and re-randomize the MLP weights once mid-episode.
images = nn_control(hebbian_model, best_params, seed,
                    render=True, break_in_middle=True)
# Only the last expression of a cell is rendered automatically, and this cell
# ends with the plot, so the video player must be displayed explicitly.
display(play_video(images))

# One row per recorded snapshot of the flat MLP parameter vector.
mlp_params_hist = np.array(hebbian_model.mlp_params_hist)
plt.figure(figsize=(10, 5))
plt.imshow(mlp_params_hist)
plt.xlabel('MLP parameter index')
plt.ylabel('Time step')
_ = plt.colorbar()