最新要闻

广告

手机

iphone11大小尺寸是多少?苹果iPhone11和iPhone13的区别是什么?

iphone11大小尺寸是多少?苹果iPhone11和iPhone13的区别是什么?

警方通报辅警执法直播中被撞飞:犯罪嫌疑人已投案

警方通报辅警执法直播中被撞飞:犯罪嫌疑人已投案

家电

TF-GNN踩坑记录(四) 全球报道

来源:博客园


(资料图)

目录
  • 引言
    • 题外话(MapFeatures使用)
      • 节点特征变换
      • 边特征变换
      • 传入额外参数
  • 问题
    • 问题demo
  • 解决方案

引言

由于图数据结构问题,直接使用Tensorflow的一些层是无法直接处理图数据的,需要借用TF-GNN框架下的MapFeatures对图数据中的节点特征或是边特征进行变换。

题外话(MapFeatures使用)

节点特征变换

from tensorflow.keras.layers import BatchNormalizationfrom tensorflow_gnn.keras.layers import MapFeatures# map node featuresdef node_sets_fn(node_set, *, node_set_name):    features = node_set.features    return BatchNormalization()(features["hidden_state"])graph = MapFeatures(node_sets_fn=node_sets_fn)(graph)

边特征变换

from tensorflow_gnn.keras.layers import MapFeatures# Hashes edge features called "id", leaves others unchanged:def edge_sets_fn(edge_set, *, edge_set_name):    features = edge_set.get_features_dict()    ids = features.pop("id")    num_bins = 100_000 if edge_set_name == "views" else 20_000    hashed_ids = tf.keras.layers.Hashing(num_bins=num_bins)(ids)    features["hashed_id"] = hashed_ids    return featuresgraph = MapFeatures(edge_sets_fn=edge_sets_fn)(graph)

传入额外参数

from functools import partialfrom tensorflow.keras.layers import Densefrom tensorflow_gnn.keras.layers import MapFeatures# map node featuresdef node_sets_fn(node_set, *, node_set_name, dim):    features = node_set.features    return Dense(dim)(features["hidden_state"])graph = MapFeatures(node_sets_fn=partial(node_sets_fn, dim=64))(graph)

问题

就是在使用MapFeatures时,如果循环使用则会在存储模型的时候报错:ValueError: Unable to create dataset (name already exists)

问题demo

from functools import partialfrom tensorflow.keras.layers import Densefrom tensorflow_gnn.keras.layers import MapFeatures# map node featuresdef node_sets_fn(node_set, *, node_set_name, dim):    features = node_set.features    return Dense(dim)(features["hidden_state"])for ln in range(layer_num):    graph = MapFeatures(node_sets_fn=partial(node_sets_fn, dim=64))(graph)

解决方案

最后发现是在使用MapFeatures时,使用层时如Dense需要区分每一次变换时的层名

from functools import partialfrom tensorflow.keras.layers import Densefrom tensorflow_gnn.keras.layers import MapFeatures# map node featuresdef node_sets_fn(node_set, *, node_set_name, dim,name):    features = node_set.features    return Dense(dim, name=f"Dense_{name}")(features["hidden_state"])for ln in range(layer_num):    graph = MapFeatures(node_sets_fn=partial(node_sets_fn, dim=64,name=ln))(graph)

关键词: