Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Persistent topic models (NMF, LDA, Ensemble LDA) #84

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 56 additions & 27 deletions litstudy/nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,43 +274,53 @@ def best_topic_for_documents(self) -> List[int]:
return np.argmax(self.doc2topic, axis=1)


def train_nmf_model(
    corpus: Corpus, num_topics: int, seed=0, max_iter=500, filename=None
) -> TopicModel:
    """Train a topic model using NMF, optionally persisting it to disk.

    If `filename` is given and that file already exists, the gensim model
    is loaded from it instead of being trained. If `filename` is given but
    the file does not exist, the model is trained and then saved there.
    If `filename` is `None`, the model is trained and nothing is saved.

    :param corpus: The corpus providing the dictionary and the token
                   frequencies of the documents.
    :param num_topics: The number of topics to train.
    :param seed: The seed used for random number generation.
    :param max_iter: The maximum number of iterations to use for training.
                     More iterations mean better results, but longer training
                     times.
    :param filename: Name of gensim model to save, or to load if file exists.
    :returns: A `TopicModel` built from the corpus dictionary, the
              document-to-topic matrix and the topic-to-token matrix.
    """
    import gensim.models.nmf
    from os.path import isfile

    dic = corpus.dictionary
    freqs = corpus.frequencies

    if filename is not None and isfile(filename):
        # Reuse the previously saved model instead of retraining.
        # NOTE(review): the loaded model is not checked against `num_topics`,
        # `seed` or `max_iter` -- confirm callers pass values matching the
        # ones the file was trained with.
        model = gensim.models.nmf.Nmf.load(filename)
    else:
        # The model is fitted on TF-IDF weighted frequencies.
        tfidf = gensim.models.tfidfmodel.TfidfModel(dictionary=dic)
        model = gensim.models.nmf.Nmf(
            list(tfidf[freqs]),
            num_topics=num_topics,
            passes=max_iter,
            random_state=seed,
            w_stop_condition=1e-9,
            h_stop_condition=1e-9,
            w_max_iter=50,
            h_max_iter=50,
        )
        if filename is not None:
            model.save(filename)

    # Transposed so that rows are documents and columns are topics.
    doc2topic = corpus2dense(model[freqs], num_topics).T
    topic2token = model.get_topics()

    return TopicModel(dic, doc2topic, topic2token)


def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel:
def train_lda_model(corpus: Corpus, num_topics, seed=0, filename=None, **kwargs) -> TopicModel:
"""Train a topic model using LDA.

:param num_topics: The number of topics to train.
:param seed: The seed used for random number generation.
:param filename: Name of gensim model to save, or to load if file exists.
:param kwargs: Arguments passed to `gensim.models.lda.LdaModel` (gensim3)
or `gensim.models.ldamodel.LdaModel` (gensim4).
"""
Expand All @@ -319,17 +329,26 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel:
freqs = corpus.frequencies

from importlib.metadata import version
from os.path import isfile

gensim_mayor = int(version("gensim").split(".")[0])

if gensim_mayor == 3:
from gensim.models.lda import LdaModel

model = LdaModel(list(corpus), **kwargs)
if filename == None or isfile(filename) == False:
model = LdaModel(list(corpus), **kwargs)
if filename != None:
model.save(filename)
elif gensim_mayor == 4:
from gensim.models.ldamodel import LdaModel

model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs)
if filename == None or isfile(filename) == False:
model = LdaModel(freqs, id2word=dic, num_topics=num_topics, **kwargs)
if filename != None:
model.save(filename)
else:
model = LdaModel.load(filename)

else:
from sys import exit

Expand All @@ -341,19 +360,22 @@ def train_lda_model(corpus: Corpus, num_topics, seed=0, **kwargs) -> TopicModel:
return TopicModel(dic, doc2topic, topic2token)


def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs) -> TopicModel:
def train_elda_model(
corpus: Corpus, num_topics, num_models=4, seed=0, filename=None, **kwargs
) -> TopicModel:
"""Train a topic model using ensemble LDA.

:param num_topics: The number of topics to train.
:param num_models: The number of models to train.
:param seed: The seed used for random number generation.
:param filename: Name of gensim model to save, or to load if file exists.
:param kwargs: Arguments passed to `gensim.models.ensemblelda.EnsembleLda` (gensim4).
"""

from importlib.metadata import version
from os.path import isfile

gensim_mayor = int(version("gensim").split(".")[0])

if gensim_mayor <= 3:
from sys import exit

Expand All @@ -364,14 +386,21 @@ def train_elda_model(corpus: Corpus, num_topics, num_models=4, seed=0, **kwargs)

from gensim.models.ensemblelda import EnsembleLda

model = EnsembleLda(
topic_model_class="ldamulticore",
corpus=freqs,
id2word=dic,
num_topics=num_topics,
num_models=num_models,
**kwargs
)
if filename == None or isfile(filename) == False:
model = EnsembleLda(
topic_model_class="ldamulticore",
corpus=freqs,
id2word=dic,
num_topics=num_topics,
num_models=num_models,
**kwargs
)
if filename != None:
model.save(filename)
else:
model = EnsembleLda.load(filename)

model = model.generate_gensim_representation()

doc2topic = corpus2dense(model[freqs], num_topics).T
topic2token = model.get_topics()
Expand Down
Loading