feat: add silly markov thing
This commit is contained in:
parent
e2623702ec
commit
93abddc3be
3 changed files with 88 additions and 0 deletions
1
silly/markov/.gitignore
vendored
Normal file
1
silly/markov/.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
||||||
|
data/
|
51
silly/markov/src/build_pairs.py
Normal file
51
silly/markov/src/build_pairs.py
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
from collections import defaultdict
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) < 3:
|
||||||
|
print(f"Usage: {sys.argv[0]} <messages dir> <output file>")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
messages_path = sys.argv[1]
|
||||||
|
output_path = sys.argv[2]
|
||||||
|
|
||||||
|
index_path = os.path.join(messages_path, "index.json")
|
||||||
|
|
||||||
|
with open(index_path) as f:
|
||||||
|
message_index = json.load(f)
|
||||||
|
|
||||||
|
transitions = defaultdict(list)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
for id, channel_name in message_index.items():
|
||||||
|
print(f"Loading messages from {channel_name}...")
|
||||||
|
with open(os.path.join(messages_path, f"c{id}", "messages.csv"), newline='') as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for line in reader:
|
||||||
|
content = line["Contents"]
|
||||||
|
messages.append(content)
|
||||||
|
|
||||||
|
print(f"Loaded {len(messages)} messages!")
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
words = re.split(' +', message)
|
||||||
|
for w0, w1, w2 in zip(words[0:], words[1:], words[2:]):
|
||||||
|
transitions[w0, w1].append(w2)
|
||||||
|
|
||||||
|
data = []
|
||||||
|
for state, transition in transitions.items():
|
||||||
|
data.append({
|
||||||
|
'state': state,
|
||||||
|
'transition': transition,
|
||||||
|
})
|
||||||
|
|
||||||
|
with open(output_path, "w") as f:
|
||||||
|
json.dump(data, f)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
36
silly/markov/src/generate.py
Normal file
36
silly/markov/src/generate.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def main():
|
||||||
|
if len(sys.argv) < 2:
|
||||||
|
print("Usage: {sys.argv[0]} <chain file>")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
chain_path = sys.argv[1]
|
||||||
|
|
||||||
|
transitions = {}
|
||||||
|
with open(chain_path) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
for item in data:
|
||||||
|
state = item['state']
|
||||||
|
transition = item['transition']
|
||||||
|
transitions[state[0], state[1]] = transition
|
||||||
|
|
||||||
|
start_states = list(transitions.keys())
|
||||||
|
|
||||||
|
print("Press enter to generate another")
|
||||||
|
while True:
|
||||||
|
w0, w1 = random.choice(start_states)
|
||||||
|
print(w0, w1, end='')
|
||||||
|
w2 = random.choice(transitions[w0, w1])
|
||||||
|
for _ in range(500):
|
||||||
|
print(' ' + w2, end='')
|
||||||
|
if (w1, w2) not in transitions:
|
||||||
|
break
|
||||||
|
w0, w1, w2 = w1, w2, random.choice(transitions[w1, w2])
|
||||||
|
print()
|
||||||
|
input()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in a new issue