Source code for easygraph.functions.graph_embedding.deepwalk

import random

from easygraph.functions.graph_embedding.node2vec import (
    _get_embedding_result_from_gensim_skipgram_model,
)
from easygraph.functions.graph_embedding.node2vec import learn_embeddings
from easygraph.utils import *
from tqdm import tqdm


__all__ = ["deepwalk"]


[docs] @not_implemented_for("multigraph") def deepwalk(G, dimensions=128, walk_length=80, num_walks=10, **skip_gram_params): """Graph embedding via DeepWalk. Parameters ---------- G : easygraph.Graph or easygraph.DiGraph dimensions : int Embedding dimensions, optional(default: 128) walk_length : int Number of nodes in each walk, optional(default: 80) num_walks : int Number of walks per node, optional(default: 10) skip_gram_params : dict Parameters for gensim.models.Word2Vec - do not supply `size`, it is taken from the `dimensions` parameter Returns ------- embedding_vector : dict The embedding vector of each node most_similar_nodes_of_node : dict The most similar nodes of each node and its similarity Examples -------- >>> deepwalk(G, ... dimensions=128, # The graph embedding dimensions. ... walk_length=80, # Walk length of each random walks. ... num_walks=10, # Number of random walks. ... skip_gram_params = dict( # The skip_gram parameters in Python package gensim. ... window=10, ... min_count=1, ... batch_words=4, ... iter=15 ... )) References ---------- .. [1] https://arxiv.org/abs/1403.6652 """ G_index, index_of_node, node_of_index = G.to_index_node_graph() walks = simulate_walks(G_index, walk_length=walk_length, num_walks=num_walks) model = learn_embeddings(walks=walks, dimensions=dimensions, **skip_gram_params) ( embedding_vector, most_similar_nodes_of_node, ) = _get_embedding_result_from_gensim_skipgram_model( G=G, index_of_node=index_of_node, node_of_index=node_of_index, model=model ) del G_index return embedding_vector, most_similar_nodes_of_node
def simulate_walks(G, walk_length, num_walks): walks = [] nodes = list(G.nodes) print("Walk iteration:") for walk_iter in tqdm(range(num_walks)): random.shuffle(nodes) for node in nodes: walks.append(_deepwalk_walk(G, walk_length=walk_length, start_node=node)) return walks def _deepwalk_walk(G, walk_length, start_node): """ Simulate a random walk starting from start node. """ walk = [start_node] while len(walk) < walk_length: cur = walk[-1] cur_nbrs = sorted(G.neighbors(cur)) if len(cur_nbrs) > 0: pick_node = random.choice(cur_nbrs) walk.append(pick_node) else: break return walk