Last active: August 23, 2016 18:14
-
-
Save loristns/8a09da2ad0e0ab0fcecece7a2aecc594 to your computer and use it in GitHub Desktop.
Markov using GraphInPy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter | |
def append_a_graph(graph_dict: dict, key: str, value) -> dict:
    """
    Add connections to an existing point of a graph, in place.

    A set ``value`` is merged with ``set.update``; a list ``value`` is
    concatenated with ``+=``. The container already stored under ``key``
    must support that operation, and ``key`` must already exist.

    :param graph_dict: the Graph.dict value (modified in place)
    :param key: the name of the point
    :param value: the name(s) of the point(s) to add
    :type value: set or list
    :return: the same ``graph_dict`` object, for convenience
    :raises TypeError: if ``value`` is neither a set nor a list
    """
    if not isinstance(value, (set, list)):
        raise TypeError('"value" argument must be a set or a list')

    if isinstance(value, set):
        graph_dict[key].update(value)
    else:
        graph_dict[key] += value
    return graph_dict
class Graph(object):
    """A simple graph object, with a simple API based on default Python's methods.

    ``self.dict`` maps each point name to the container of points it is
    connected to. The *type* of ``connection_memory`` selects that container:

    * ``set`` (the default): each connection is stored once, so
      :meth:`count_connection` answers with a bool;
    * ``list``: repeated connections are kept, so :meth:`count_connection`
      returns how many times the connection was added.
    """

    def __init__(self, graph_dict=None, connection_memory=None):
        """
        :param graph_dict: initial point -> connections mapping. It is stored
            as-is (not copied); a new empty dict is used when omitted.
        :param connection_memory: a list or set whose type selects the
            connection container. Defaults to a new ``set()``.
        """
        if graph_dict is None:
            graph_dict = dict()
        if connection_memory is None:
            # Built per instance: a ``set()`` default in the signature would be
            # one object shared by every Graph created without this argument.
            connection_memory = set()
        self.dict = graph_dict
        self.connection_memory = connection_memory  # Support list and set

    def __setitem__(self, key: str, value):
        """
        Add a point in a graph and define which other points it is connected to.

        Connections are added to any the point already has; they never replace them.

        :param key: Define the name of the point.
        :param value: Other points it is connected to.
        :type value: set or list
        """
        if key not in self.dict:
            # New point: start from an empty container of the configured type.
            if isinstance(self.connection_memory, list):
                self.dict[key] = list()
            elif isinstance(self.connection_memory, set):
                self.dict[key] = set()
        append_a_graph(self.dict, key, value)

    def __delitem__(self, key):
        """Remove a point and its outgoing connections (incoming ones are left in place)."""
        del self.dict[key]

    def __getitem__(self, item):
        """Return the connection container stored for ``item`` (not a copy)."""
        return self.dict[item]

    @property
    def weights(self):
        """
        Count how many connections point to each point.

        :return: A dict mapping every point seen (as a key or as a connection
            target) to the number of connections that target it; points with
            no incoming connection map to 0.
        """
        counted_values = Counter()
        # Counter.update keeps a running tally in one pass; the previous
        # ``sum(list_of_lists, [])`` re-copied the merged list on every step.
        for connections in self.dict.values():
            counted_values.update(connections)
        counted_values = dict(counted_values)
        for key in self.dict:  # Add keys who don't have any connections
            counted_values.setdefault(key, 0)
        return counted_values

    def count_connection(self, word_to_analyse, word_to_count):
        """
        Count the connections from one word to another word.

        :return: With list connection memory, an integer (number of times the
            connection was added); with set memory, a boolean (whether it
            exists). ``None`` if ``connection_memory`` is neither.
        :raises KeyError: in list/set mode, if ``word_to_analyse`` is not a point
        """
        if isinstance(self.connection_memory, list):
            return self.dict[word_to_analyse].count(word_to_count)
        elif isinstance(self.connection_memory, set):
            return word_to_count in self.dict[word_to_analyse]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random
from collections import Counter

from graph import Graph
def random_pick(some_list, probabilities):
    """
    Pick one item at random, weighted by ``probabilities``.

    :param some_list: the candidate items (any iterable, e.g. ``dict.keys()``)
    :param probabilities: non-negative weights, paired with ``some_list`` by
        position (extra entries on either side are ignored, as with ``zip``).
        They do not need to sum to 1: the draw is scaled by their total.
    :return: the chosen item
    :raises ValueError: if there is no (item, probability) pair to pick from
    """
    pairs = list(zip(some_list, probabilities))
    if not pairs:
        raise ValueError("random_pick() needs at least one item with a probability")
    # Draw in [0, total) instead of [0, 1) so unnormalised weights work; for
    # weights that already sum to 1 this is the same draw as before.
    x = random.uniform(0, sum(weight for _, weight in pairs))
    cumulative_probability = 0.0
    for item, item_probability in pairs:
        cumulative_probability += item_probability
        if x < cumulative_probability:
            return item
    # Float rounding can leave x at or above the final running total:
    # fall back to the last paired item.
    return pairs[-1][0]
def count_graph_next_probability(graph, word: str):
    """
    Estimate the probability of each point that can follow ``word``.

    Every distinct successor gets (times it appears in ``graph.dict[word]``)
    divided by (total number of stored connections of ``word``). With list
    connection memory this is the observed transition frequency; with set
    memory each successor is stored once, so all successors are equally likely.

    :param graph: a Graph-like object exposing ``.dict``
    :param word: the point whose successors are evaluated
    :return: dict successor -> probability, in first-seen order; empty when
        ``word`` has no successors
    :raises KeyError: if ``word`` is not a point of the graph
    """
    successors = graph.dict[word]
    total = len(successors)
    # One Counter pass instead of calling graph.count_connection() for every
    # stored element, which re-scanned the whole list each time (quadratic).
    # An empty container yields an empty Counter, so no division by zero occurs.
    counts = Counter(successors)
    return {possibility: count / total for possibility, count in counts.items()}
class MarkovGenerator(object):
    """Word-level Markov chain text generator backed by a Graph.

    Training records, for every token, the token that followed it. The
    sentinels ``'[START]'`` and ``'[END]'`` mark sentence boundaries.
    Connections are kept in lists, so a transition seen more often is
    proportionally more likely to be generated.
    """

    def __init__(self, image=None):
        """
        :param image: an existing transition dict (e.g. a previous
            ``graph.dict``) to continue from. It is used as-is, not copied.
            When omitted, a new empty dict is created for this generator.
        """
        if image is None:
            # The former ``image=dict()`` default was evaluated once, so every
            # generator built without an image shared (and mixed) training data.
            image = dict()
        self.graph = Graph(graph_dict=image, connection_memory=list())

    def train(self, object_list: list):
        """
        Learn the transitions of one sequence of tokens.

        :param object_list: the tokens of one sentence, in order. It is not
            modified; the sentinels are added to a private copy.
        """
        sequence = ['[START]'] + list(object_list) + ['[END]']
        for current, following in zip(sequence, sequence[1:]):
            self.graph[current] = [following]
        # The final token ('[END]') is registered with no successors.
        self.graph[sequence[-1]] = []

    def generator(self, char_limit: int):
        """
        Generate a sentence by walking the chain from ``'[START]'``.

        Words are added while the sentence is shorter than ``char_limit``, so
        the last word may take it past the limit. Generation also stops when
        ``'[END]'`` is drawn or the current word has no recorded successor.

        :param char_limit: length (in characters) at which to stop adding words
        :return: the generated words joined by single spaces; ``""`` if the
            generator has not been trained
        """
        if '[START]' not in self.graph.dict:
            # Nothing trained yet: there is no starting point to walk from.
            return ""
        words = []
        length = 0  # length of " ".join(words), tracked incrementally
        current = '[START]'
        while length < char_limit:
            probabilities = count_graph_next_probability(self.graph, current)
            if not probabilities:
                break
            current = random_pick(probabilities.keys(), list(probabilities.values()))
            if current == '[END]':
                # The end-of-sentence sentinel is not part of the generated text.
                break
            length += len(current) + (1 if words else 0)
            words.append(current)
        return " ".join(words)
if __name__ == "__main__":
    # Demo: put the tweet texts to learn from in this list; each one is split
    # into words on single spaces.
    tweets = []
    dzed = MarkovGenerator()
    for tweet in tweets:
        dzed.train(tweet.split(" "))
    print(dzed.graph.dict)
    if tweets:
        print(dzed.generator(140))
    else:
        # Nothing was trained, so there is no '[START]' point to walk from.
        print("No tweets to learn from: fill the `tweets` list first.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment