MODEL#

Classes#

class models.zscl_utils.clip.model.AttentionPool2d(spacial_dim, embed_dim, num_heads, output_dim=None)[source]#

Bases: Module

forward(x)[source]#
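A minimal usage sketch for AttentionPool2d, assuming an RN50-style head where a 7x7, 2048-channel feature map is pooled into a single 1024-d embedding; all shapes below are assumptions, not values fixed by this module:

    import torch
    from models.zscl_utils.clip.model import AttentionPool2d

    # Pool a 7x7 feature map with 2048 channels down to one 1024-d vector
    # via QKV attention over the spatial positions.
    pool = AttentionPool2d(spacial_dim=7, embed_dim=2048, num_heads=32, output_dim=1024)
    feats = torch.randn(4, 2048, 7, 7)   # (batch, channels, height, width)
    pooled = pool(feats)                 # expected shape: (4, 1024)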
class models.zscl_utils.clip.model.Bottleneck(inplanes, planes, stride=1)[source]#

Bases: Module

expansion = 4#
forward(x)[source]#
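Because expansion = 4, a block built with planes channels expands to planes * 4 output channels. A minimal sketch, with shapes chosen purely for illustration:

    import torch
    from models.zscl_utils.clip.model import Bottleneck

    # planes=64 with expansion=4 gives 256 output channels, matching inplanes here,
    # so no downsampling branch is needed at stride 1.
    block = Bottleneck(inplanes=256, planes=64, stride=1)
    x = torch.randn(2, 256, 56, 56)
    y = block(x)   # expected shape: (2, 256, 56, 56)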
class models.zscl_utils.clip.model.CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, baseline=False)[source]#

Bases: Module

build_attention_mask()[source]#
property dtype#
encode_image(image)[source]#
encode_text(text)[source]#
forward(image, text)[source]#
initialize_parameters()[source]#
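A usage sketch for the CLIP class with hyperparameters roughly matching a ViT-B/32 configuration; the concrete values and token ids are assumptions, and the exact return value of forward should be checked in the source:

    import torch
    from models.zscl_utils.clip.model import CLIP

    model = CLIP(
        embed_dim=512,
        image_resolution=224, vision_layers=12, vision_width=768, vision_patch_size=32,
        context_length=77, vocab_size=49408,
        transformer_width=512, transformer_heads=8, transformer_layers=12,
    )
    images = torch.randn(2, 3, 224, 224)
    tokens = torch.randint(0, 49408, (2, 77))        # placeholder token ids
    with torch.no_grad():
        image_features = model.encode_image(images)  # (2, 512)
        text_features = model.encode_text(tokens)    # (2, 512)
    # Normalize and compare in the joint embedding space.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = image_features @ text_features.t()  # (2, 2) cosine similarities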
class models.zscl_utils.clip.model.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None)[source]#

Bases: LayerNorm

Subclass of torch’s LayerNorm that handles fp16 inputs.

forward(x)[source]#
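The point of the subclass is that half-precision activations can be normalized without manual casting; the usual CLIP pattern is to run the normalization in fp32 and cast the result back to the input dtype, which is assumed to be the behaviour here as well:

    import torch
    from models.zscl_utils.clip.model import LayerNorm

    ln = LayerNorm(512)                  # parameters stay in fp32
    x = torch.randn(2, 77, 512).half()   # fp16 activations
    y = ln(x)                            # same shape and dtype as x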
class models.zscl_utils.clip.model.ModifiedResNet(layers, output_dim, heads, input_resolution=224, width=64)[source]#

Bases: Module

A ResNet class that is similar to torchvision’s but contains the following changes:

- There are now 3 “stem” convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1.
- The final pooling layer is a QKV attention instead of an average pool.

forward(x)[source]#
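A construction sketch with an RN50-like configuration (the layer counts and widths are assumptions); because the final pooling layer is a QKV attention, the output is already a flat embedding:

    import torch
    from models.zscl_utils.clip.model import ModifiedResNet

    model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                           input_resolution=224, width=64)
    images = torch.randn(2, 3, 224, 224)
    feats = model(images)   # expected shape: (2, 1024)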
class models.zscl_utils.clip.model.QuickGELU(*args, **kwargs)[source]#

Bases: Module

forward(x)[source]#
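QuickGELU is commonly defined in CLIP-style code as x * sigmoid(1.702 * x), a cheap approximation of GELU; treating that definition as an assumption, a quick check looks like:

    import torch
    from models.zscl_utils.clip.model import QuickGELU

    act = QuickGELU()
    x = torch.linspace(-3, 3, 7)
    y = act(x)
    y_ref = x * torch.sigmoid(1.702 * x)   # the approximation assumed above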
class models.zscl_utils.clip.model.ResidualAttentionBlock(d_model, n_head, attn_mask=None)[source]#

Bases: Module

attention(x)[source]#
forward(x)[source]#
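A shape-level sketch of a single block; CLIP-style transformers typically run with a (seq_len, batch, width) layout, which is assumed here:

    import torch
    from models.zscl_utils.clip.model import ResidualAttentionBlock

    block = ResidualAttentionBlock(d_model=512, n_head=8)   # no attention mask
    x = torch.randn(77, 2, 512)   # (seq_len, batch, width)
    y = block(x)                  # same shape as x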
class models.zscl_utils.clip.model.Transformer(width, layers, heads, attn_mask=None)[source]#

Bases: Module

forward(x)[source]#
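A sketch of the full stack with a causal attention mask like the one produced by CLIP.build_attention_mask; the mask construction and the sizes here are assumptions:

    import torch
    from models.zscl_utils.clip.model import Transformer

    context_length, width = 77, 512
    causal_mask = torch.full((context_length, context_length), float("-inf")).triu(1)
    encoder = Transformer(width=width, layers=12, heads=8, attn_mask=causal_mask)
    x = torch.randn(context_length, 2, width)   # (seq_len, batch, width)
    y = encoder(x)                               # same shape as x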
class models.zscl_utils.clip.model.VisualTransformer(input_resolution, patch_size, width, layers, heads, output_dim)[source]#

Bases: Module

forward(x)[source]#
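A construction sketch with a ViT-B/32-like configuration (the values are assumptions):

    import torch
    from models.zscl_utils.clip.model import VisualTransformer

    vit = VisualTransformer(input_resolution=224, patch_size=32, width=768,
                            layers=12, heads=12, output_dim=512)
    images = torch.randn(2, 3, 224, 224)
    feats = vit(images)   # expected shape: (2, 512)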

Functions#

models.zscl_utils.clip.model.build_model(state_dict)[source]#
models.zscl_utils.clip.model.convert_weights(model)[source]#

Convert applicable model parameters to fp16.
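A loading sketch that combines the two functions; the checkpoint path is hypothetical and the snippet assumes a plain state-dict checkpoint rather than a TorchScript archive:

    import torch
    from models.zscl_utils.clip.model import build_model, convert_weights

    state_dict = torch.load("checkpoint.pt", map_location="cpu")  # hypothetical path
    model = build_model(state_dict)   # architecture is inferred from the state dict

    # Cast the applicable parameters to fp16, e.g. before GPU inference.
    convert_weights(model)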