Commit 414e397 ("rebase: fixes")
drsk0 committed Jan 18, 2024, 1 parent: 177e137
Showing 2 changed files with 16 additions and 63 deletions.
src/discretize.jl: 8 additions & 58 deletions
@@ -613,31 +613,30 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
     function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
         function full_loss_function(θ, p)
             # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
-            pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
-            bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]
+            # we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions.
+            pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
+            asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
+            bc_losses = num_bc_losses == 0 ? adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]

             # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized
             # that's why we prefer the user to maintain the increment in the outer loop callback during optimization
             ChainRulesCore.@ignore_derivatives if self_increment
                 iteration[1] += 1
             end
-            # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
-            # we need to type annotate the empty vector for autodiff to succeed in the case of empty equations/additional symbolic loss/boundary conditions.
-            pde_losses = num_pde_losses == 0 ? adaloss_T[] : [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
-            asl_losses = num_asl_losses == 0 ? adaloss_T[] : [asl_loss_function(θ) for asl_loss_function in asl_loss_functions]
-            bc_losses = num_bc_losses == 0 ? adaloss_T[] : [bc_loss_function(θ) for bc_loss_function in bc_loss_functions]

             ChainRulesCore.@ignore_derivatives begin
                 reweight_losses_func(θ, pde_losses,
-                    bc_losses)
+                    asl_losses, bc_losses)
             end

             weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
+            weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
             weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses

             sum_weighted_pde_losses = sum(weighted_pde_losses)
+            sum_weighted_asl_losses = sum(weighted_asl_losses)
             sum_weighted_bc_losses = sum(weighted_bc_losses)
-            weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses
+            weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses

             full_weighted_loss = if additional_loss isa Nothing
                 weighted_loss_before_additional
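
The ternary guards added at the top of full_loss_function are the substance of this hunk: a comprehension over an empty collection can produce a vector with eltype Any, which breaks sum (zero(Any) is undefined) and trips up reverse-mode autodiff, so the empty case is pinned to the concrete element type adaloss_T. A minimal sketch of the failure mode and the fix, under illustrative names (losses_for is not a package function):

    # Why the empty branch needs a concrete element type: sum(Any[]) throws
    # because zero(Any) is undefined, and autodiff hits the same wall when a
    # whole class of loss terms is empty.
    losses_for(fs, θ, ::Type{T}) where {T} = isempty(fs) ? T[] : [f(θ) for f in fs]

    θ = [1.0, 2.0]
    pde_fns = [θ -> sum(abs2, θ)]   # one toy residual function
    bc_fns = Function[]             # no boundary conditions in this toy case

    pde_losses = losses_for(pde_fns, θ, Float64)   # [5.0]
    bc_losses = losses_for(bc_fns, θ, Float64)     # Float64[], so sum(...) == 0.0
    total = sum(pde_losses) + sum(bc_losses)       # 5.0
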
@@ -694,21 +693,12 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
                         iteration[1])
                 end
             end
-            ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses,
-                asl_losses, bc_losses) end

             return full_weighted_loss
         end
-        weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses
-        weighted_asl_losses = adaloss.asl_loss_weights .* asl_losses
-        weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses

         return full_loss_function
     end
-        sum_weighted_pde_losses = sum(weighted_pde_losses)
-        sum_weighted_asl_losses = sum(weighted_asl_losses)
-        sum_weighted_bc_losses = sum(weighted_bc_losses)
-        weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_asl_losses + sum_weighted_bc_losses

     function get_likelihood_estimate_function(discretization::BayesianPINN)
         dataset_pde, dataset_bc = discretization.dataset
@@ -796,46 +786,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
         end

         return full_loss_function
-        ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0
-            logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses",
-                iteration[1])
-            logvector(pinnrep.logger, asl_losses, "unweighted_loss/asl_losses",
-                iteration[1])
-            logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1])
-            logvector(pinnrep.logger, weighted_pde_losses,
-                "weighted_loss/weighted_pde_losses",
-                iteration[1])
-            logvector(pinnrep.logger, weighted_asl_losses,
-                "weighted_loss/weighted_asl_losses",
-                iteration[1])
-            logvector(pinnrep.logger, weighted_bc_losses,
-                "weighted_loss/weighted_bc_losses",
-                iteration[1])
-            if !(additional_loss isa Nothing)
-                logscalar(pinnrep.logger, weighted_additional_loss_val,
-                    "weighted_loss/weighted_additional_loss", iteration[1])
-            end
-            logscalar(pinnrep.logger, sum_weighted_pde_losses,
-                "weighted_loss/sum_weighted_pde_losses", iteration[1])
-            logscalar(pinnrep.logger, sum_weighted_bc_losses,
-                "weighted_loss/sum_weighted_bc_losses", iteration[1])
-            logscalar(pinnrep.logger, sum_weighted_asl_losses,
-                "weighted_loss/sum_weighted_asl_losses", iteration[1])
-            logscalar(pinnrep.logger, full_weighted_loss,
-                "weighted_loss/full_weighted_loss",
-                iteration[1])
-            logvector(pinnrep.logger, adaloss.pde_loss_weights,
-                "adaptive_loss/pde_loss_weights",
-                iteration[1])
-            logvector(pinnrep.logger, adaloss.asl_loss_weights,
-                "adaptive_loss/asl_loss_weights",
-                iteration[1])
-            logvector(pinnrep.logger, adaloss.bc_loss_weights,
-                "adaptive_loss/bc_loss_weights",
-                iteration[1])
-        end end
-
-        return full_weighted_loss
     end

     full_loss_function = get_likelihood_estimate_function(discretization)
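
The comment kept in the first hunk notes that self_increment bumps the counter on every evaluation of the loss, even outside optimization, and that the preferred home for the increment is the outer-loop callback. A hedged sketch of that pattern; the callback signature follows recent Optimization.jl and may differ across versions:

    # Maintain the iteration counter from the solver callback instead of
    # letting full_loss_function self-increment on every evaluation.
    iteration = [0]
    callback = function (state, loss)
        iteration[1] += 1
        return false    # false means "keep optimizing"
    end
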
src/pinn_types.jl: 8 additions & 5 deletions
@@ -103,7 +103,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPINN
                    derivative = nothing,
                    param_estim = false,
                    additional_loss = nothing,
-                   additional_symb_loss = [],
+                   additional_symb_loss = [],
                    adaptive_loss = nothing,
                    logger = nothing,
                    log_options = LogOptions(),
@@ -136,14 +136,14 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, K} <: AbstractPINN

         new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative),
             typeof(param_estim),
-            typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain,
+            typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain,
             strategy,
             init_params,
             _phi,
             _derivative,
             param_estim,
             additional_loss,
-            additional_symb_loss,
+            additional_symb_loss,
             adaptive_loss,
             logger,
             log_options,
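
Both typeof lists grow for a mechanical reason: a Julia inner constructor's new{...} must supply every type parameter declared on the struct, in declaration order, so once ASL joins the signature, new{...} without typeof(additional_symb_loss) no longer matches. A toy illustration, not the package code:

    # Omitting a type parameter from new{...} is an error once the struct
    # declares it, which is why typeof(additional_symb_loss) joins the list.
    struct Wrapper{A, B}
        a::A
        b::B
        Wrapper(a, b) = new{typeof(a), typeof(b)}(a, b)
    end

    Wrapper(1, "x")   # Wrapper{Int64, String}(1, "x")
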
@@ -162,6 +162,7 @@ BayesianPINN(chain,
             phi = nothing,
             param_estim = false,
             additional_loss = nothing,
+            additional_symb_loss = nothing,
             adaptive_loss = nothing,
             logger = nothing,
             log_options = LogOptions(),
@@ -211,14 +212,15 @@ methodology.
 * `iteration`: used to control the iteration counter???
 * `kwargs`: Extra keyword arguments.
 """
-struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
+struct BayesianPINN{T, P, PH, DER, PE, AL, ASL, ADA, LOG, D, K} <: AbstractPINN
     chain::Any
     strategy::T
     init_params::P
     phi::PH
     derivative::DER
     param_estim::PE
     additional_loss::AL
+    additional_symb_loss::ASL
     adaptive_loss::ADA
     logger::LOG
     log_options::LogOptions
@@ -235,6 +237,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
                    derivative = nothing,
                    param_estim = false,
                    additional_loss = nothing,
+                   additional_symb_loss = nothing,
                    adaptive_loss = nothing,
                    logger = nothing,
                    log_options = LogOptions(),
@@ -272,7 +275,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN

         new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative),
             typeof(param_estim),
-            typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset),
+            typeof(additional_loss), typeof(additional_symb_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset),
             typeof(kwargs)}(chain,
             strategy,
             init_params,
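
With both constructors and both structs updated, additional_symb_loss flows from the user-facing keyword down to the discretized loss. A hedged construction sketch, assuming the keyword added in this commit; the Lux network and training strategy are illustrative, not from the diff:

    using NeuralPDE, Lux

    # Toy network; any Lux chain with matching input/output dimensions would do.
    chain = Lux.Chain(Lux.Dense(1, 16, tanh), Lux.Dense(16, 1))

    # additional_symb_loss carries extra symbolic loss terms and defaults to [];
    # it is passed explicitly here only to show where the new keyword lives.
    discretization = PhysicsInformedNN(chain, QuadratureTraining();
        additional_symb_loss = [],
        additional_loss = nothing)
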
