diff --git a/src/render/scene.rs b/src/render/scene.rs index b8322d5..852d049 100644 --- a/src/render/scene.rs +++ b/src/render/scene.rs @@ -150,10 +150,18 @@ impl Scene { refracted(incident_ray, normal, diffraction_index, index).map_or_else( // Total reflection || self.reflection(point, 1., reflected, reflection_limit, diffraction_index), - // Refraction - |r| { - self.refraction(point, coef, r, reflection_limit, index) - + lighting * (1. - coef) + // Refraction (refracted ray, amount of *reflection*) + |(r, refl_t)| { + let refr_light = self.refraction(point, coef, r, reflection_limit, index) + * (1. - refl_t) + + self.reflection( + point, + refl_t, + reflected, + reflection_limit, + diffraction_index, + ) * refl_t; + refr_light * coef + lighting * (1. - coef) }, ) } @@ -197,7 +205,6 @@ impl Scene { reflection_limit: u32, diffraction_index: f32, ) -> LinearColor { - // FIXME: use fresnel reflection too if reflectivity > 1e-5 && reflection_limit > 0 { let reflection_start = point + reflected * 0.001; if let Some((t, obj)) = self.cast_ray(Ray::new(reflection_start, reflected)) { @@ -269,16 +276,23 @@ fn reflected(incident: Vector, normal: Vector) -> Vector { incident - delt } -fn refracted(incident: Vector, normal: Vector, n_1: f32, n_2: f32) -> Option { - let cos = incident.dot(&normal); - let normal = if cos < 0. { normal } else { -normal }; +/// Returns None if the ray was totally reflected, Some(refracted_ray, reflected_amount) if not +fn refracted(incident: Vector, normal: Vector, n_1: f32, n_2: f32) -> Option<(Vector, f32)> { + let cos1 = incident.dot(&normal); + let normal = if cos1 < 0. { normal } else { -normal }; let eta = n_1 / n_2; - let k = 1. - eta * eta * (1. - cos * cos); + let k = 1. - eta * eta * (1. - cos1 * cos1); if k < 0. { - None - } else { - Some(eta * incident + (eta * cos.abs() - f32::sqrt(k)) * normal) + return None; } + let cos1 = cos1.abs(); + let refracted = eta * incident + (eta * cos1 - f32::sqrt(k)) * normal; + let cos2 = -refracted.dot(&normal); // Take the negation because we're on the other side + let f_r = (n_2 * cos1 - n_1 * cos2) / (n_2 * cos1 + n_1 * cos2); + let f_t = (n_1 * cos2 - n_2 * cos1) / (n_1 * cos2 + n_2 * cos1); + let refl_t = (f_r * f_r + f_t * f_t) / 2.; + //Some((refracted, 0.)) + Some((refracted, refl_t)) } #[derive(Debug, PartialEq, Deserialize)]