"""
Compute all-vs-all Bayesian colocalisation analysis for all Genetics Portal credible sets.
This script calculates posterior probabilities of different causal variant
configurations under the assumption of a single causal variant for each trait.
Logic reproduced from: https://github.com/chr1swallace/coloc/blob/main/R/claudia.R
"""
from functools import reduce
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.ml.linalg import VectorUDT, Vectors
import numpy as np
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pyspark.ml.functions as Fml
import hydra
from omegaconf import DictConfig
# CREDSET22PATH = (
# "gs://genetics-portal-dev-analysis/dsuveges/test_credible_set_chr22.parquet"
# )
# # CREDSETPATH = (
# # "gs://genetics-portal-dev-staging/finemapping/220228_merged/credset/"
# # )
# PHENOTYPEIDGENEPATH = "gs://ot-team/dochoa/phenotype_id_gene_luts/"
# INSUMSTATSPATH = (
# "gs://genetics-portal-dev-sumstats/filtered/significant_window_2mb_union"
# )
# OUTPUT = "gs://ot-team/dochoa/coloc_ml_test.parquet"
def getLogsum(logABF: VectorUDT):
"""
This function calculates the log of the sum of the exponentiated
logs, taking out the max, i.e. ensuring that the sum does not overflow to Inf.
"""
themax = np.max(logABF)
result = themax + np.log(np.sum(np.exp(logABF - themax)))
return float(result)
logsum = F.udf(getLogsum, T.DoubleType())
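# Worked example of the log-sum-exp trick above: for logABF = [1.0, 2.0, 3.0],
# logsum = 3 + log(exp(-2) + exp(-1) + exp(0)) = log(e**1 + e**2 + e**3) ≈ 3.408,
# computed without exponentiating any large value directly.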
@F.udf(returnType=VectorUDT())
def posteriors(allAbfs: VectorUDT):
"""
Calculates the posterior probability of each hypothesis given the evidence.
"""
diff = allAbfs - getLogsum(allAbfs)
abfsPosteriors = np.exp(diff)
return Vectors.dense(abfsPosteriors)
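# Note: the posteriors above are a softmax over the five hypothesis log-ABFs,
# i.e. posterior_k = exp(lHk - logsumexp(lH0..lH4)), so they sum to 1.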
@hydra.main(config_path="config", config_name="config")
def main(cfg: DictConfig):
"""
Run colocalisation analysis
"""
sparkConf = SparkConf()
sparkConf = sparkConf.set("spark.hadoop.fs.gs.requester.pays.mode", "AUTO")
sparkConf = sparkConf.set(
"spark.hadoop.fs.gs.requester.pays.project.id", "open-targets-eu-dev"
)
# establish spark connection
spark = SparkSession.builder.config(conf=sparkConf).master("local[*]").getOrCreate()
credSet = (
spark.read.parquet(cfg.coloc.credible_set)
.distinct()
.withColumn(
"studyKey",
F.concat_ws("_", *["type", "study_id", "phenotype_id", "bio_feature"]),
)
)
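# studyKey uniquely identifies a signal source: the study plus, for molecular
# traits, the phenotype and tissue/bio_feature it was measured in. Signals are
# paired at this level.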
# Priors
# priorc1 Prior on variant being causal for trait 1
# priorc2 Prior on variant being causal for trait 2
# priorc12 Prior on variant being causal for traits 1 and 2
priors = spark.createDataFrame(
[(1e-4, 1e-4, 1e-5)], ("priorc1", "priorc2", "priorc12")
)
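# These match the default priors of the R coloc package: p1 = p2 = 1e-4 (a
# variant is causal for one trait), p12 = 1e-5 (a variant is causal for both).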
# TODO: calculate logABF from the summary statistics, because FinnGen studies (sumstats only) don't have logABF
# https://github.com/tobyjohnson/gtx/blob/9afa9597a51d0ff44536bc5c8eddd901ab3e867c/R/abf.R#L53
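# A minimal sketch of that calculation (not implemented here): under the
# Wakefield approximation, with z = beta / se and r = W / (W + se**2), where W
# is the prior effect variance (e.g. 0.15**2 for quantitative traits in coloc),
#   logABF = 0.5 * (log(1 - r) + r * z**2)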
columnsToJoin = ["studyKey", "tag_variant_id", "lead_variant_id", "type", "logABF"]
renameColumns = ["studyKey", "lead_variant_id", "type", "logABF"]
# Overlapping signals (exploded at the tag variant level)
leftDf = reduce(
lambda DF, col: DF.withColumnRenamed(col, "left_" + col),
renameColumns,
credSet.select(columnsToJoin).distinct(),
)
rightDf = reduce(
lambda DF, col: DF.withColumnRenamed(col, "right_" + col),
renameColumns,
credSet.select(columnsToJoin).distinct(),
)
overlappingPeaks = (
leftDf
# GWAS studies are kept on the left side, so molecular traits can only appear on the right
.filter(F.col("left_type") == "gwas")
# Get all study/peak pairs where at least one tagging variant overlaps:
.join(rightDf, on="tag_variant_id", how="inner")
.filter(
# Remove rows with identical study:
(F.col("left_studyKey") != F.col("right_studyKey"))
)
# Keep only the upper triangle when both studies are GWAS (avoids duplicated pairs)
.filter(
(F.col("right_type") != "gwas")
| (F.col("left_studyKey") > F.col("right_studyKey"))
)
# remove overlapping tag variant info
.drop("left_logABF", "right_logABF", "tag_variant_id")
# distinct to get study-pair info
.distinct()
.persist()
)
overlappingLeft = overlappingPeaks.join(
leftDf.select(
"left_studyKey", "left_lead_variant_id", "tag_variant_id", "left_logABF"
),
on=["left_studyKey", "left_lead_variant_id"],
how="inner",
)
overlappingRight = overlappingPeaks.join(
rightDf.select(
"right_studyKey", "right_lead_variant_id", "tag_variant_id", "right_logABF"
),
on=["right_studyKey", "right_lead_variant_id"],
how="inner",
)
overlappingSignals = overlappingLeft.alias("a").join(
overlappingRight.alias("b"),
on=[
"tag_variant_id",
"left_lead_variant_id",
"right_lead_variant_id",
"left_studyKey",
"right_studyKey",
"right_type",
"left_type",
],
how="outer",
)
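# The outer join keeps tag variants that are credible for only one of the two
# signals; their missing logABF on the other side is filled with 0 below.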
signalPairsCols = ["studyKey", "lead_variant_id", "type"]
coloc = (
overlappingSignals
# Before summarising, nulls in the logABF columns need to be filled with 0:
.fillna(0, subset=["left_logABF", "right_logABF"])
# Sum the left and right logABF for each tag variant (used for the shared-variant hypothesis):
.withColumn("sum_logABF", F.col("left_logABF") + F.col("right_logABF"))
# Group by overlapping peak pair and generate dense vectors of logABF:
.groupBy(
*["left_" + col for col in signalPairsCols]
+ ["right_" + col for col in signalPairsCols]
)
.agg(
F.count("*").alias("coloc_n_vars"),
Fml.array_to_vector(F.collect_list(F.col("left_logABF"))).alias(
"left_logABF"
),
Fml.array_to_vector(F.collect_list(F.col("right_logABF"))).alias(
"right_logABF"
),
Fml.array_to_vector(F.collect_list(F.col("sum_logABF"))).alias(
"sum_logABF"
),
)
# Log sums
.withColumn("logsum1", logsum(F.col("left_logABF")))
.withColumn("logsum2", logsum(F.col("right_logABF")))
.withColumn("logsum12", logsum(F.col("sum_logABF")))
.drop("left_logABF", "right_logABF", "sum_logABF")
#
# Add priors
.crossJoin(priors)
# h0-h2
.withColumn("lH0abf", F.lit(0))
.withColumn("lH1abf", F.log(F.col("priorc1")) + F.col("logsum1"))
.withColumn("lH2abf", F.log(F.col("priorc2")) + F.col("logsum2"))
# h3
.withColumn("sumlogsum", F.col("logsum1") + F.col("logsum2"))
# exclude pairs where H3/H4 would be undefined because sumlogsum == logsum12
.filter(F.col("sumlogsum") != F.col("logsum12"))
.withColumn("max", F.greatest("sumlogsum", "logsum12"))
.withColumn(
"logdiff",
(
F.col("max")
+ F.log(
F.exp(F.col("sumlogsum") - F.col("max"))
- F.exp(F.col("logsum12") - F.col("max"))
)
),
)
.withColumn(
"lH3abf",
F.log(F.col("priorc1")) + F.log(F.col("priorc2")) + F.col("logdiff"),
)
.drop("right_logsum", "left_logsum", "sumlogsum", "max", "logdiff")
# h4
.withColumn("lH4abf", F.log(F.col("priorc12")) + F.col("logsum12"))
# cleaning
.drop("priorc1", "priorc2", "priorc12", "logsum1", "logsum2", "logsum12")
# posteriors
.withColumn(
"allABF",
Fml.array_to_vector(
F.array(
F.col("lH0abf"),
F.col("lH1abf"),
F.col("lH2abf"),
F.col("lH3abf"),
F.col("lH4abf"),
)
),
)
.withColumn("posteriors", Fml.vector_to_array(posteriors(F.col("allABF"))))
.withColumn("coloc_h0", F.col("posteriors").getItem(0))
.withColumn("coloc_h1", F.col("posteriors").getItem(1))
.withColumn("coloc_h2", F.col("posteriors").getItem(2))
.withColumn("coloc_h3", F.col("posteriors").getItem(3))
.withColumn("coloc_h4", F.col("posteriors").getItem(4))
.withColumn("coloc_h4_h3", F.col("coloc_h4") / F.col("coloc_h3"))
.withColumn("coloc_log2_h4_h3", F.log2(F.col("coloc_h4_h3")))
# clean up
.drop("posteriors", "allABF", "lH0abf", "lH1abf", "lH2abf", "lH3abf", "lH4abf")
)
phenotypeIdGene = (
spark.read.option("header", "true")
.option("sep", "\t")
.csv(cfg.coloc.phenotype_id_gene)
)
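# Lookup table mapping molecular-trait phenotype_ids to Ensembl gene ids,
# used below to annotate the right-hand (molecular) study with right_gene_id.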
# Adding study, variant and study-variant metadata from credible set
credSetStudyMeta = credSet.select(
"studyKey",
F.col("study_id").alias("study"),
"bio_feature",
F.col("phenotype_id").alias("phenotype"),
).distinct()
credSetVariantMeta = credSet.select(
F.col("lead_variant_id"),
F.col("lead_chrom").alias("chrom"),
F.col("lead_pos").alias("pos"),
F.col("lead_ref").alias("ref"),
F.col("lead_alt").alias("alt"),
).distinct()
sumstatsLeftVarRightStudyInfo = (
spark.read.parquet(cfg.coloc.sumstats_filtered)
.withColumn(
"right_studyKey",
F.concat_ws("_", *["type", "study_id", "phenotype_id", "bio_feature"]),
)
.withColumn(
"left_variantId",
F.concat_ws("_", F.col("chrom"), F.col("pos"), F.col("ref"), F.col("alt")),
)
.withColumnRenamed("beta", "left_var_right_study_beta")
.withColumnRenamed("se", "left_var_right_study_se")
.withColumnRenamed("pval", "left_var_right_study_pval")
.withColumnRenamed("is_cc", "left_var_right_isCC")
# Only keep required columns
.select(
"left_variantId",
"right_studyKey",
"left_var_right_study_beta",
"left_var_right_study_se",
"left_var_right_study_pval",
"left_var_right_isCC",
)
)
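# Summary statistics of the left lead variant in the right study, carried
# through as annotation only (see the "Pending" fields near the end); they do
# not enter the colocalisation calculation.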
colocWithMetadata = (
coloc.join(
reduce(
lambda DF, col: DF.withColumnRenamed(col, "left_" + col),
credSetStudyMeta.columns,
credSetStudyMeta,
),
on="left_studyKey",
how="left",
)
.join(
reduce(
lambda DF, col: DF.withColumnRenamed(col, "right_" + col),
credSetStudyMeta.columns,
credSetStudyMeta,
),
on="right_studyKey",
how="left",
)
.drop("left_studyKey", "right_studyKey")
.join(
reduce(
lambda DF, col: DF.withColumnRenamed(col, "left_" + col),
credSetVariantMeta.columns,
credSetVariantMeta,
),
on="left_lead_variant_id",
how="left",
)
.join(
reduce(
lambda DF, col: DF.withColumnRenamed(col, "right_" + col),
credSetVariantMeta.columns,
credSetVariantMeta,
),
on="right_lead_variant_id",
how="left",
)
.join(
sumstatsLeftVarRightStudyInfo,
on=["left_variant_id", "right_studyKey"],
how="left",
)
.drop("left_lead_variant_id", "right_lead_variant_id")
.drop("left_bio_feature", "left_phenotype")
.join(
phenotypeIdGene.select(
F.col("phenotype_id").alias("right_phenotype"),
F.col("gene_id").alias("right_gene_id"),
),
on="right_phenotype",
how="left",
)
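# Where the phenotype itself encodes the gene, derive right_gene_id directly:
# phenotype_ids that are already Ensembl ids, and GTEx sQTL phenotypes whose
# gene id is the trailing ":ENSG..." component.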
.withColumn(
"right_gene_id",
F.when(
F.col("right_phenotype").startswith("ENSG"), F.col("right_phenotype")
).otherwise(F.col("right_gene_id")),
)
.withColumn(
"right_gene_id",
F.when(
F.col("right_study") == "GTEx-sQTL",
F.regexp_extract(F.col("right_phenotype"), ":(ENSG.*)$", 1),
).otherwise(F.col("right_gene_id")),
)
)
# Pending
# |-- is_flipped: boolean (nullable = true)
# |-- right_gene_id: string (nullable = true)
# |-- left_var_right_study_beta: double (nullable = true)
# |-- left_var_right_study_se: double (nullable = true)
# |-- left_var_right_study_pval: double (nullable = true)
# |-- left_var_right_isCC: boolean (nullable = true)
(
colocWithMetadata.write.partitionBy("left_chrom")
.mode("overwrite")
.parquet(cfg.coloc.output)
)
# TODO: compute model averaged effect size ratios
# https://github.com/tobyjohnson/gtx/blob/9afa9597a51d0ff44536bc5c8eddd901ab3e867c/R/coloc.R#L91
# For debugging
# (
# coloc
# .filter(
# (F.col("left_studyKey") == "gwas_NEALE2_20003_1140909872") &
# (F.col("right_studyKey") ==
# "sqtl_GTEx-sQTL_chr22:17791301:17806239:clu_21824:ENSG00000243156_Ovary") &
# (F.col("left_lead_variant_id") == "22:16590692:CAA:C") &
# (F.col("right_lead_variant_id") == "22:17806438:G:A"))
# .show(vertical = True)
# )
return 0
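# Illustrative Hydra config layout (config/config.yaml); the keys below are the
# ones read by this script, the paths are placeholders:
#
# coloc:
#   credible_set: gs://<bucket>/credset/
#   phenotype_id_gene: gs://<bucket>/phenotype_id_gene_luts/
#   sumstats_filtered: gs://<bucket>/significant_window_2mb_union/
#   output: gs://<bucket>/coloc.parquet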
if __name__ == "__main__":
main()