import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose
# Define the U-Net network
def unet(input_size=(256, 256, 3), out_channels=3):
    """Build a small two-level U-Net as an uncompiled Keras functional model.

    Architecture, in layer-creation order:

    * Encoder: two stages of paired 3x3 ReLU convolutions (64, then 128
      filters), each stage followed by 2x2 max pooling.
    * Bottleneck: a pair of 3x3 ReLU convolutions with 256 filters.
    * Decoder: two stages, each a 2x2 stride-2 transposed convolution whose
      output is concatenated with the matching encoder feature map along
      axis 3 (the channel axis under Keras' default ``channels_last``
      layout), followed by a pair of 3x3 ReLU convolutions (128, then 64
      filters).
    * Output: a 1x1 convolution with sigmoid activation.

    Args:
        input_size: ``(height, width, channels)`` of one input sample,
            without the batch dimension. ``height`` and ``width`` must each
            be a multiple of 4, or ``None`` for a size that is only known
            at run time.
        out_channels: Number of filters of the final 1x1 convolution, i.e.
            the number of channels in the model output. Defaults to 3.

    Returns:
        A ``tf.keras.Model`` mapping the input tensor to the sigmoid output.
        The model is not compiled.

    Raises:
        ValueError: If ``input_size`` does not have exactly three entries,
            or if a known height/width is not a multiple of 4.
    """
    dims = tuple(input_size)
    if len(dims) != 3:
        raise ValueError(
            f"input_size must be (height, width, channels), got {input_size!r}"
        )
    # Each 2x2 max pooling floor-halves a spatial dimension while each
    # stride-2 transposed convolution exactly doubles it, so the skip
    # connections only line up when height and width are multiples of 4.
    # Fail early with a clear message instead of a shape error raised from
    # deep inside the concatenate layers.
    for label, dim in zip(("height", "width"), dims[:2]):
        if dim is not None and dim % 4 != 0:
            raise ValueError(
                f"input_size {label} must be a multiple of 4 so the skip "
                f"connections match in shape, got {dim}"
            )

    inputs = Input(input_size)

    # Encoder
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    # Bottleneck
    conv3 = Conv2D(256, 3, activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same')(conv3)

    # Decoder: upsample, then merge with the encoder features of the same
    # spatial resolution (skip connection) before convolving again.
    up4 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv3)
    up4 = concatenate([up4, conv2], axis=3)
    conv4 = Conv2D(128, 3, activation='relu', padding='same')(up4)
    conv4 = Conv2D(128, 3, activation='relu', padding='same')(conv4)
    up5 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv4)
    up5 = concatenate([up5, conv1], axis=3)
    conv5 = Conv2D(64, 3, activation='relu', padding='same')(up5)
    conv5 = Conv2D(64, 3, activation='relu', padding='same')(conv5)

    # Output: 1x1 convolution squashes each pixel's features to
    # `out_channels` values in [0, 1].
    outputs = Conv2D(out_channels, 1, activation='sigmoid')(conv5)
    return tf.keras.Model(inputs=inputs, outputs=outputs)