# Implementing a CPPN for Creative Pattern Generation

In this exercise, we explore Compositional Pattern Producing Networks (CPPNs), neural networks designed to generate complex, often artistic patterns. Instead of mapping an input example to a label, a CPPN maps the coordinates of each pixel to an intensity. Because its neurons can use different activation functions (for example, sine for repetition and absolute value for symmetry), the resulting images often resemble patterns found in nature. Our objective is to train CPPNs to:
1. Generate patterns based on textual descriptions.
2. Replicate patterns from given images.
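To make "compositional" concrete, here is a hand-written CPPN sketch. Nothing is trained: it is a fixed composition of the kinds of activation functions used later (sine, absolute value, tanh), applied to every pixel's coordinates. The formula, the `[-1, 1]` coordinate scaling, and the name `toy_cppn` are illustrative assumptions, not part of the exercise.

```python
import jax.numpy as jnp


def toy_cppn(size: int = 64) -> jnp.ndarray:
    """Hand-written CPPN: maps each pixel's (x, y, d) to a gray level in [0, 255]."""
    # Coordinates scaled to [-1, 1] so the formula does not depend on image size.
    coords = jnp.linspace(-1.0, 1.0, size)
    x, y = jnp.meshgrid(coords, coords)
    d = jnp.sqrt(x ** 2 + y ** 2)  # distance from the centre gives radial structure
    # Compose simple functions: abs() mirrors the pattern, sin()/cos() repeat it.
    h = jnp.sin(6.0 * jnp.abs(x)) + jnp.sin(6.0 * jnp.abs(y)) + jnp.cos(10.0 * d)
    h = jnp.tanh(h)  # squash to (-1, 1)
    return (h + 1.0) * 127.5  # rescale to the pixel range [0, 255]


img = toy_cppn(64)
print(img.shape)  # (64, 64)
```

Because every term uses `abs(x)`, `abs(y)`, or `d`, the image is mirror-symmetric about both axes. The exercise below replaces this fixed formula with one whose structure is discovered by evolution.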




In [None]:
# @title Import Libraries and Load CLIP

import requests
import time
import numpy as np
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel
import jax
import jax.numpy as jnp
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
try:
    import cma
except ModuleNotFoundError:
    !pip install cma
    import cma


model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')
clear_output()

In [None]:
# @title Fitness Generating Function

image_size = 128  # @param {type:"integer"}
prompt = 'a pattern of snowflakes'  # @param {type:"string"}
match_target = 'prompt' # @param ["image", "prompt"]
match_text = match_target == 'prompt'

url = 'https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTllSTuTTX30C13Gf16ah64qBp5BdBeC2ygMBpE1WECcw&s'
ref_img = Image.open(requests.get(url, stream=True).raw).convert('L').resize(
    (image_size, image_size))


def gen_fitness(batch_data):
    with torch.inference_mode():
        if match_text:
            inputs = processor(
                text=[prompt],
                images=[Image.fromarray(np.uint8(np.round(data)))
                        for data in batch_data],
                return_tensors="pt", padding=True)
            outputs = model(**inputs)
            ref_em = torch.nn.functional.normalize(outputs.text_embeds, dim=-1)
            img_em = torch.nn.functional.normalize(outputs.image_embeds, dim=-1)
            similarities = (img_em @ ref_em.T).squeeze(-1)  # shape (batch,)
            return similarities.numpy()
        else:
            ref_img_array = np.array(ref_img).astype(float)[None, ...]
            mse = np.mean(np.square(batch_data - ref_img_array).reshape(
                [-1, image_size**2]), axis=-1)
            return -mse


test_batch_size = 3
# Random grayscale images in [0, 255], used only to sanity-check gen_fitness.
rand_images = np.random.rand(test_batch_size, image_size, image_size) * 255
scores = gen_fitness(rand_images)

if match_text:
    print(f'Train CPPN to match the prompt: {prompt}')
    fig, axes = plt.subplots(
        1, test_batch_size, figsize=(3 * test_batch_size, 3))
    for i in range(test_batch_size):
        axes[i].imshow(np.uint8(np.round(rand_images[i])), cmap='gray')
        axes[i].axis('off')
        axes[i].set_title(f'Fitness: {scores[i]:.4f}')
    plt.tight_layout()
else:
    print('Train CPPN to match this reference image')
    display(ref_img)
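`gen_fitness` reduces both modes to one score per image where higher is better: cosine similarity between the image and text embeddings in prompt mode, and negative mean squared error against the reference in image mode. The sketch below reproduces only that arithmetic on small hand-made arrays, so it runs without CLIP; the helper names `cosine_fitness` and `neg_mse_fitness` are hypothetical.

```python
import numpy as np


def cosine_fitness(img_emb: np.ndarray, txt_emb: np.ndarray) -> np.ndarray:
    """Cosine similarity of each image embedding (B, D) with one text embedding (1, D)."""
    img_emb = img_emb / np.linalg.norm(img_emb, axis=-1, keepdims=True)
    txt_emb = txt_emb / np.linalg.norm(txt_emb, axis=-1, keepdims=True)
    return (img_emb @ txt_emb.T)[:, 0]  # shape (B,)


def neg_mse_fitness(batch: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """Negative MSE of each image (B, H, W) against a reference image (H, W)."""
    diff = batch - ref[None, ...]
    return -np.mean(np.square(diff).reshape(len(batch), -1), axis=-1)


txt = np.array([[1.0, 0.0]])
imgs = np.array([[2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]])
# Same direction, orthogonal, opposite: scores 1.0, 0.0, -1.0.
print(cosine_fitness(imgs, txt))

ref = np.zeros((2, 2))
batch = np.stack([np.zeros((2, 2)), np.full((2, 2), 2.0)])
# Exact match scores 0.0; an image off by 2 everywhere scores -4.0.
print(neg_mse_fitness(batch, ref))
```

Normalizing before the dot product means only the direction of an embedding matters, so the prompt-mode score always lies in [-1, 1], while the image-mode score is unbounded below.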

# Approximate CPPN with MLP-like grids

Evolving a CPPN with NeuroEvolution of Augmenting Topologies (NEAT) can be time-intensive, because NEAT searches over network topologies as well as connection weights. This exercise shortens that process by using a specially designed MLP whose fixed structure can mimic many of the networks NEAT might produce.

The goal is to implement a modified MLP that approximates a CPPN with a much smaller search space. It keeps two sources of flexibility, per-neuron activation functions and evolvable connectivity, both tuned by an evolutionary algorithm (CMA-ES in this notebook), while the network's depth and width stay fixed.
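The size of that search space can be counted directly. Following the `param_size` formula used later in this notebook, each neuron gets `act_fn_size` parameters for choosing its activation, and each possible connection gets two parameters. The helper `cppn_param_size` below is a hypothetical name that repeats that formula for any hidden width.

```python
def cppn_param_size(input_dim: int, hidden_dim: int, output_dim: int,
                    n_act: int) -> int:
    """Genome length of the fixed two-hidden-layer CPPN approximation."""
    return (
        n_act * input_dim              # activation choice for the input nodes
        + 2 * input_dim * hidden_dim   # input -> hidden: 2 params per connection
        + n_act * hidden_dim           # activation choice, hidden layer 1
        + 2 * hidden_dim * hidden_dim  # hidden -> hidden
        + n_act * hidden_dim           # activation choice, hidden layer 2
        + 2 * hidden_dim * output_dim  # hidden -> output
    )


for h in (8, 16, 32):
    print(h, cppn_param_size(input_dim=4, hidden_dim=h, output_dim=1, n_act=7))
```

With the notebook defaults (4 inputs, 16 hidden units, 7 activation functions) the genome has 924 entries, the same value the `param_size` formula gives. Every candidate network is a vector of exactly this length, and the count grows quadratically in `hidden_dim` because of the hidden-to-hidden block.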


## Key Features of the Special MLP
- **Varying Activation Functions**: Each neuron in this MLP can use a different activation function, whereas a traditional MLP usually applies the same activation function to every neuron in a layer. This lets the network imitate the diverse functional building blocks of a CPPN.

- **Evolutionary Connections**: In a standard MLP, every neuron is connected to every neuron in the next layer. Here, evolution decides which connections are active over the course of training, similar to the way NEAT adds and disables links, so the network's effective structure adapts based on performance.

- **Fixed Architecture**: To avoid the open-ended complexity of NEAT-evolved CPPNs, this MLP has a predetermined depth and width. The number of layers and the number of neurons per layer are fixed from the start, which greatly narrows the search space and can shorten the time needed for evolution.
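The three features above can be sketched as a single layer. This is one possible encoding under stated assumptions, not necessarily the one the exercise expects: each connection gets a weight gene and a gate gene (the connection counts only when its gate gene is positive), and each neuron gets one logit per candidate activation, with the argmax picking its function. The names `gated_layer` and `ACTS` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Candidate activation functions (a subset of the notebook's act_fn list).
ACTS = [lambda x: x, jnp.tanh, jnp.abs, jnp.sin]


def gated_layer(x, genes, in_dim, out_dim):
    """Apply one fixed-size layer whose connections and activations are evolved.

    x:     (num_pixels, in_dim) inputs
    genes: flat vector of length 2 * in_dim * out_dim + len(ACTS) * out_dim
    """
    n_w = in_dim * out_dim
    weights = genes[:n_w].reshape(in_dim, out_dim)
    gates = genes[n_w:2 * n_w].reshape(in_dim, out_dim)
    act_logits = genes[2 * n_w:].reshape(out_dim, len(ACTS))

    # Evolutionary connections (assumed encoding): active only if gate gene > 0.
    h = x @ (weights * (gates > 0))

    # Varying activations: each output neuron uses its own argmax-selected function.
    act_ix = jnp.argmax(act_logits, axis=-1)      # shape (out_dim,)
    out = jnp.zeros_like(h)
    for k, fn in enumerate(ACTS):
        out = jnp.where(act_ix == k, fn(h), out)  # broadcasts over pixels
    return out


key = jax.random.PRNGKey(0)
in_dim, out_dim = 4, 3
genes = jax.random.normal(key, (2 * in_dim * out_dim + len(ACTS) * out_dim,))
x = jnp.ones((5, in_dim))
print(gated_layer(x, genes, in_dim, out_dim).shape)  # (5, 3)
```

Because the architecture never changes, every genome has the same length, which is what lets a fixed-dimension optimizer such as CMA-ES search over it. Thresholding the gate at zero is an assumption; a soft gate such as a sigmoid is another option.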


To speed up evaluation, we use JAX in this exercise. Don't worry if you are new to JAX: its `jax.numpy` module closely mirrors NumPy, and the next cell includes sample code for reference.
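Beyond `jax.numpy`, the main JAX tool this notebook relies on is `jax.vmap`: `graph_forward` is written for a single genome, and the `@jax.vmap` decorator turns it into a function over a whole population. Below is a toy version of that idea, where the made-up `render` function stands in for `graph_forward`.

```python
import jax
import jax.numpy as jnp

coords = jnp.linspace(-1.0, 1.0, 8)  # shared pixel coordinates (illustrative)


def render(genome):
    """Toy per-genome 'network': a sine wave whose frequency and phase are the genes."""
    freq, phase = genome[0], genome[1]
    return jnp.sin(freq * coords + phase)  # shape (8,)


# vmap maps over the leading (population) axis; jit compiles the batched function.
render_population = jax.jit(jax.vmap(render))

population = jnp.array([[1.0, 0.0], [2.0, 0.5], [3.0, 1.0]])
images = render_population(population)
print(images.shape)  # (3, 8)
```

Each row of `images` equals `render(population[i])`, but without an explicit Python loop, and the compiled function is reused on later calls with the same input shape, which matters when the same population size is evaluated every generation.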

In [None]:
# @title Encode Computational Graph

x1, x2 = jnp.meshgrid(jnp.arange(image_size), jnp.arange(image_size))
xy = jnp.stack([x1.ravel(), x2.ravel()], axis=-1).astype(jnp.float32)
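# Distance of each pixel from the image centre; indices run 0..image_size-1.
d = jnp.sqrt(jnp.sum(jnp.square(xy - (image_size - 1) / 2), axis=-1))
bias = jnp.ones_like(d)
# Per-pixel CPPN inputs (x, y, distance, bias), shape (image_size**2, 4).
xydb = jnp.concatenate([xy, d[..., None], bias[..., None]], axis=-1)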


input_dim = 4
output_dim = 1
hidden_dim = 16  # @param {type:"integer"}


act_fn = [
    lambda x: x,
    lambda x: -x,
    lambda x: jnp.tanh(x),
    lambda x: jnp.maximum(x, 0),
    lambda x: jnp.abs(x),
    lambda x: jnp.sin(x),
    lambda x: jnp.square(x),
]
act_fn_size = len(act_fn)


def apply_act_fn(data, ix):
    return jax.lax.switch(ix, act_fn, data)

# About vmap:
#   We use jax.vmap to automatically "vectorize" the function,
#   since data.shape = (batch_size, hidden_dim) and ix.shape = (hidden_dim,),
#   the vectorization should be applied to the 2nd and 1st dim of data and ix,
#   hence in_axes=(1, 0).
#   We want the output to be of shape (batch_size, hidden_dim),
#   therefore out_axes=1.
apply_act_func = jax.vmap(apply_act_fn, in_axes=(1, 0), out_axes=1)


@jax.jit
@jax.vmap
def graph_forward(params):
    """Simple implementation of a feedforward NEAT network."""

    # Apply activation to inputs.
    x = xydb
    ss = 0
    ee = ss + input_dim * act_fn_size
    act_ix = jnp.argmax(params[ss:ee].reshape(input_dim, act_fn_size), axis=-1)
    x = apply_act_func(x, act_ix)

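    # Input to hidden.
    # Your code here: read the next (2 * input_dim) * hidden_dim params for the
    # connections, then act_fn_size * hidden_dim params to choose each hidden
    # neuron's activation (see param_size below for the full layout).

    # Hidden to hidden.
    # Your code here: same pattern, with (2 * hidden_dim) * hidden_dim
    # connection params and act_fn_size * hidden_dim activation params.

    # Hidden to output.
    # Your code here: read the last (2 * hidden_dim) * output_dim params;
    # x must end up with shape (image_size**2, output_dim).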

    return (255. / (1 + jnp.exp(-x))).squeeze(-1)


param_size = (
    act_fn_size * input_dim +
    (2 * input_dim) * hidden_dim +
    act_fn_size * hidden_dim +
    (2 * hidden_dim) * hidden_dim +
    act_fn_size * hidden_dim +
    (2 * hidden_dim) * output_dim
)
key = jax.random.PRNGKey(0)
rand_params = jax.random.normal(key, shape=(test_batch_size, param_size))
outputs = graph_forward(rand_params)
outputs = np.array(outputs).reshape([-1, image_size, image_size])
scores = gen_fitness(outputs)
print(f'param_size={param_size}\n')

if match_text:
    fig, axes = plt.subplots(
        1, test_batch_size, figsize=(3 * test_batch_size, 3))
    for i in range(test_batch_size):
        axes[i].imshow(np.uint8(np.round(outputs[i])), cmap='gray')
        axes[i].set_title(f'Fitness: {scores[i]:.4f}')
        axes[i].axis('off')
    plt.tight_layout()

In [None]:
# @title Train CPPN


pop_size = 128  # @param {type:"integer"}
init_stdev = 0.05  # @param {type:"number"}
seed = 42  # @param {type:"integer"}
total_gen = 500  # @param {type:"integer"}
log_interval = 50  # @param {type:"integer"}


algo = cma.CMAEvolutionStrategy(
    x0=np.zeros(param_size),
    sigma0=init_stdev,
    inopts={
        "popsize": pop_size,
        "seed": seed,
        "randn": np.random.randn,
    }
)

start_time = time.perf_counter()
for i in range(total_gen):
    population = algo.ask()
    outputs = graph_forward(jnp.array(population))
    outputs = np.array(outputs).reshape([-1, image_size, image_size])
    scores = gen_fitness(outputs)
    algo.tell(population, -scores)
    if i % log_interval == 0 or i + 1 == total_gen:
        t_cost = time.perf_counter() - start_time
        print(f'Gen {i}: score.max={np.max(scores):.4f}, ' +
              f'score.avg={np.mean(scores):.4f}, time={t_cost:.2f}s')
        start_time = time.perf_counter()
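`cma` minimizes its objective, which is why the loop passes `-scores` to `tell`: maximizing fitness is the same as minimizing negative fitness. The ask/tell pattern itself does not depend on CMA-ES. The sketch below uses a much simpler stand-in, an isotropic Gaussian evolution strategy with truncation selection and no covariance or step-size adaptation, purely to show the loop shape. The class `SimpleES` is hypothetical and is not part of the `cma` package.

```python
import numpy as np


class SimpleES:
    """Minimal ask/tell ES (not CMA-ES): isotropic Gaussian around a mean, minimizing."""

    def __init__(self, x0, sigma, popsize, elite_frac=0.25, seed=0):
        self.mean = np.asarray(x0, dtype=float)
        self.sigma = sigma
        self.popsize = popsize
        self.n_elite = max(1, int(popsize * elite_frac))
        self.rng = np.random.default_rng(seed)

    def ask(self):
        noise = self.rng.standard_normal((self.popsize, self.mean.size))
        return self.mean + self.sigma * noise

    def tell(self, population, losses):
        # Lower loss is better: average the best n_elite candidates into the new mean.
        elite = np.argsort(losses)[:self.n_elite]
        self.mean = population[elite].mean(axis=0)


target = np.array([1.0, -2.0, 0.5])


def fitness(pop):
    """Higher is better: negative squared distance to a hidden target."""
    return -np.sum((pop - target) ** 2, axis=-1)


es = SimpleES(x0=np.zeros(3), sigma=0.3, popsize=32, seed=0)
for gen in range(200):
    pop = es.ask()
    scores = fitness(pop)
    es.tell(pop, -scores)  # same sign flip as algo.tell(population, -scores)
print(np.round(es.mean, 2))
```

After the loop, `es.mean` plays the role that `algo.result.xfavorite` (the CMA-ES distribution mean) plays in the visualization cell below. CMA-ES additionally adapts the sampling covariance and step size, which is what makes it practical for the roughly thousand-dimensional genomes used here.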

In [None]:
# @title Visualize the Image

best_param = jnp.array(algo.result.xfavorite)[None, ...]
outputs = graph_forward(best_param)
outputs = np.array(outputs).reshape([-1, image_size, image_size])

if match_text:
    img = Image.fromarray(np.uint8(np.round(outputs[0]))).resize((256, 256))
    print(f'"{prompt}"')
    display(img)
else:
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    ax = axes[0]
    ax.imshow(np.uint8(np.round(outputs[0])), cmap='gray')
    ax.set_title('CPPN Image')
    ax.axis('off')
    ax = axes[1]
    ax.imshow(ref_img, cmap='gray')
    ax.set_title('Reference Image')
    ax.axis('off')
    plt.tight_layout()